Integrate LLVM at llvm/llvm-project@6c66b089bcd7

Updates LLVM usage to match
[6c66b089bcd7](https://github.com/llvm/llvm-project/commit/6c66b089bcd7)

PiperOrigin-RevId: 466522991
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
index 0eaac81..707a34a 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
@@ -846,7 +846,7 @@
     // Const op can have a result of dynamic shaped type (e.g. due to constant
     // folding), but we can still derive the shape of a constant tensor for
     // its attribute type.
-    mlir::Attribute tensor_attr = inst->getAttr("value");
+    auto tensor_attr = inst->getAttr("value").cast<mlir::TypedAttr>();
     llvm::ArrayRef<int64_t> shape_ref =
         tensor_attr.getType().cast<TensorType>().getShape();
     if (mlir::failed(check_shape(shape_ref))) return llvm::None;
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index df75625..88e1ec2 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -128,9 +128,21 @@
   }
   return parent_of_defining_op;
 }
-
 }  // namespace
 
+// Returns true when the given type lists contain a single element of shaped
+// type with compatible shapes (unranked shape is compatible with any ranked
+// shape, ranked shapes are compatible if their respective dimensions are
+// compatible, dynamic dimensions are compatible with any size, static
+// dimensions must be equal to be compatible) and identical element types.
+bool VerifyCompatibleShapesSameElementType(TypeRange lhs, TypeRange rhs) {
+  if (lhs.size() != rhs.size() || lhs.size() != 1) return false;
+  if (failed(mlir::verifyCompatibleShape(lhs[0], rhs[0]))) return false;
+  auto lhsShaped = lhs[0].cast<ShapedType>();
+  auto rhsShaped = rhs[0].cast<ShapedType>();
+  return lhsShaped.getElementType() == rhsShaped.getElementType();
+}
+
 // Returns true when the given operand arguments have the same shape or
 // broadcastable shape within the given rank. If any given shapes are
 // non-static and maximum rank is within the given rank, this method returns
@@ -836,8 +848,9 @@
   int64_t out = 0;
   for (int64_t outer = 0; outer < outer_size; ++outer) {
     for (auto op : operands) {
+      auto typed_attr = op.cast<TypedAttr>();
       const int64_t dim_size =
-          op.getType().cast<RankedTensorType>().getDimSize(axis);
+          typed_attr.getType().cast<RankedTensorType>().getDimSize(axis);
       const int64_t inner_size = dim_size * base_inner_size;
 
       auto input_attrs = op.cast<DenseElementsAttr>().getValues<Attribute>();
@@ -911,15 +924,10 @@
 // CustomOp
 //===----------------------------------------------------------------------===//
 
+// TODO(b/241745316): Confirm that this is always valid
 mlir::LogicalResult CustomOp::verify() {
-  CustomOp op = *this;
-  ConstBytesAttr const_bytes_attr = op.custom_option().cast<ConstBytesAttr>();
-  const int attribute_size = const_bytes_attr.getValue().size();
-  if (const_bytes_attr.getType().isa<ShapedType>() &&
-      attribute_size !=
-          const_bytes_attr.getType().cast<ShapedType>().getDimSize(0))
-    return op.emitOpError(
-        "custom_option should have the same length of content with shape.");
+  // Currently, this is always valid as it is a wrapper around a StringRef of 0
+  // or more characters.
   return success();
 }
 
@@ -4012,7 +4020,8 @@
   // If this is a constant bytes attribute or the result type doesn't match the
   // attribute type, then generate a tfl.pseudo_const.
   if (value.isa<ConstBytesAttr>() ||
-      (value.isa<ElementsAttr>() && value.getType() != type))
+      (value.isa<ElementsAttr>() &&
+       value.cast<ElementsAttr>().getType() != type))
     return builder.create<ConstOp>(loc, type, value.cast<ElementsAttr>());
   if (arith::ConstantOp::isBuildableWith(value, type))
     return builder.create<arith::ConstantOp>(loc, type, value);
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index b835767..e67247d 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -855,7 +855,7 @@
   let hasCanonicalizer = 1;
 
   let builders = [
-    OpBuilder<(ins "Attribute":$value),
+    OpBuilder<(ins "TypedAttr":$value),
     [{
       $_state.addAttribute("value", value);
       $_state.addTypes(value.getType());
@@ -886,7 +886,7 @@
   let results = (outs AnyTensor:$output);
 
   let builders = [
-    OpBuilder<(ins "Attribute":$value, "SparsityParameterAttr":$s_param,
+    OpBuilder<(ins "TypedAttr":$value, "SparsityParameterAttr":$s_param,
       "Attribute":$compressed_data),
     [{
       $_state.addTypes(value.getType());
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
index 5b0518a..92d7c50 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
@@ -184,11 +184,11 @@
 }
 
 func.func @const() -> tensor<2xi32> {
-  %0 = "tf.Const"() {device = "", name = "weights_quant/min", dtype = "tfdtype$DT_INT32", value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> (tensor<2xi32>)
+  %0 = "tf.Const"() {device = "", name = "weights_quant/min", dtype = "tfdtype$DT_INT32", value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> (tensor<2xi32>)
   func.return %0: tensor<2xi32>
 
 // CHECK-LABEL: @const
-// CHECK: "tfl.pseudo_const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> tensor<2xi32>
+// CHECK: "tfl.pseudo_const"() {value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> tensor<2xi32>
 }
 
 func.func @shape(%arg0: tensor<?x1001xf32>) -> tensor<2xi32> {
diff --git a/tensorflow/compiler/mlir/lite/utils/validators.cc b/tensorflow/compiler/mlir/lite/utils/validators.cc
index cfe1ad6..d2059c3 100644
--- a/tensorflow/compiler/mlir/lite/utils/validators.cc
+++ b/tensorflow/compiler/mlir/lite/utils/validators.cc
@@ -19,6 +19,7 @@
 
 #include "mlir/Dialect/Traits.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributeInterfaces.h"  // from @llvm-project
 
 namespace mlir {
 namespace TFL {
@@ -100,7 +101,7 @@
   });
 }
 
-bool IsBroadcastableElementsAttrs(mlir::Attribute a, mlir::Attribute b) {
+bool IsBroadcastableElementsAttrs(mlir::TypedAttr a, mlir::TypedAttr b) {
   // This would return false if we had unranked tensors (where they should
   // probably be considered as broadcastable), but given we are working with
   // attributes here that shouldn't be an issue,
@@ -116,7 +117,7 @@
   return true;
 }
 
-bool IsDimensionsDegenerateExceptLastOne(Attribute val) {
+bool IsDimensionsDegenerateExceptLastOne(TypedAttr val) {
   if (auto ranked_type = val.getType().dyn_cast<RankedTensorType>()) {
     return IsDimensionsDegenerateExceptLastOne(ranked_type.getShape());
   }
diff --git a/tensorflow/compiler/mlir/lite/utils/validators.h b/tensorflow/compiler/mlir/lite/utils/validators.h
index da1bf17..e16feb9 100644
--- a/tensorflow/compiler/mlir/lite/utils/validators.h
+++ b/tensorflow/compiler/mlir/lite/utils/validators.h
@@ -103,9 +103,9 @@
 
 /// Returns whether the given `a` and `b` have broadcast-compatible
 /// types.
-bool IsBroadcastableElementsAttrs(mlir::Attribute a, mlir::Attribute b);
+bool IsBroadcastableElementsAttrs(mlir::TypedAttr a, mlir::TypedAttr b);
 // Returns true if every dimension of the attribute is 1 except the last one.
-bool IsDimensionsDegenerateExceptLastOne(mlir::Attribute val);
+bool IsDimensionsDegenerateExceptLastOne(mlir::TypedAttr val);
 // Returns true if every element is 1 except the last one.
 bool IsDimensionsDegenerateExceptLastOne(ArrayRef<int64_t> elements_shape);
 
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize.mlir
index 68452de..e7e485a 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize.mlir
@@ -15,7 +15,7 @@
 // RUN: tf-quant-opt %s -split-input-file -quant-lift-quantizable-spots-as-functions -quant-quantize -verify-each=false | FileCheck %s
 
 func.func private @conv(%input: tensor<1x3x4x3xf32> {tf._user_specified_name = "input_tensor"}) -> tensor<*xf32> attributes {tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf_type.shape<1x3x4x3>]} {
-  %weight = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<2x3x3x2xf32>
+  %weight = arith.constant dense_resource<__elided__> : tensor<2x3x3x2xf32>
   %bias = arith.constant dense<[7.11401462, 7.05456924]> : tensor<2xf32>
 
   %q_input= "quantfork.qcast"(%input) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform<i8:f32, 0.58810077742034317:-128>>
@@ -34,7 +34,7 @@
 }
 
 // CHECK-DAG: [[bias:%.+]] = "arith.constant"() {value = dense<[7.11401462, 7.05456924]> : tensor<2xf32>} : () -> tensor<2xf32>
-// CHECK-DAG: [[weight:%.+]] = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2x!quant.uniform<i8:f32, 0.074855112561992565:-1>>
+// CHECK-DAG: [[weight:%.+]] = "arith.constant"() {value = dense_resource<__elided__> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2x!quant.uniform<i8:f32, 0.074855112561992565:-1>>
 // CHECK: [[q_input:%.+]] = "quantfork.qcast"(%arg0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform<i8:f32, 0.58810077742034317:-128>>
 // CHECK-NEXT: [[q_bias:%.+]] = "quantfork.qcast"([[bias]]) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i32:f32, 0.044022349891595126>>
 // CHECK-NEXT: [[conv:%.+]] = "tf.PartitionedCall"([[q_input]], [[weight]], [[q_bias]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @[[composite_fn:composite_conv2d_with_bias_and_relu6_fn.*]]} : (tensor<1x3x4x3x!quant.uniform<i8:f32, 0.58810077742034317:-128>>, tensor<2x3x3x2x!quant.uniform<i8:f32, 0.074855112561992565:-1>>, tensor<2x!quant.uniform<i32:f32, 0.044022349891595126>>) -> tensor<*x!quant.uniform<i8:f32, 0.023529411764705882:-128>>
diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
index d391ffb..348f579 100644
--- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
+++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
@@ -369,7 +369,7 @@
           return InvalidArgument("Missing attribute '", output_arg.type_attr(),
                                  "' required for output '", output_arg.name(),
                                  "'");
-        TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
+        TypedAttr type_attr = attr.dyn_cast<TypedAttr>();
         if (!type_attr)
           return InvalidArgument("Attribute '", output_arg.type_attr(),
                                  "' required for output '", output_arg.name(),
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index 8edd286..35136a9 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -1539,7 +1539,8 @@
     // we want to provide more flexibility by allowing attributes of scalar
     // types. But we need to wrap it up with ElementsAttr to construct
     // valid TensorFlow constants.
-    type = RankedTensorType::get(/*shape=*/{}, value.getType());
+    auto typed_attr = value.cast<TypedAttr>();
+    type = RankedTensorType::get(/*shape=*/{}, typed_attr.getType());
     return ConstOp::build(builder, result, DenseElementsAttr::get(type, value));
   }
   // TODO(jpienaar): support other TensorFlow specific types.
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
index 4af02a2..abb18f5 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
@@ -63,8 +63,8 @@
 
 LogicalResult GlobalTensorOp::verify() {
   GlobalTensorOp global_tensor = *this;
-  if (failed(VerifyTensorTypesCompatible(
-          global_tensor.type(), global_tensor.value().Attribute::getType()))) {
+  if (failed(VerifyTensorTypesCompatible(global_tensor.type(),
+                                         global_tensor.value().getType()))) {
     return global_tensor.emitError() << "'type' and 'value' attributes should "
                                         "have compatible tensor types";
   }
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index 40474ae..c8e6127 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -3491,7 +3491,7 @@
           builder.getUnknownLoc(),
           builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
           value_attr,
-          /*type=*/mlir::TypeAttr::get(value_attr.Attribute::getType()),
+          /*type=*/mlir::TypeAttr::get(value_attr.getType()),
           /*is_mutable=*/nullptr);
       op->setAttr(
           "tf_saved_model.exported_names",
diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc
index a72aaea..f9e636c 100644
--- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc
+++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc
@@ -393,7 +393,9 @@
     if (matchPattern(cst_tensor_op.arg(), m_Constant(&array))) {
       llvm::DenseSet<Type> all_types;
       for (auto it : array) {
-        all_types.insert(it.getType());
+        TypedAttr typed_attr = it.dyn_cast<TypedAttr>();
+        if (!typed_attr) return failure();
+        all_types.insert(typed_attr.getType());
       }
       if (all_types.size() != 1) return failure();
       ShapedType new_out_type = RankedTensorType::get(
@@ -408,7 +410,7 @@
       return success();
     }
 
-    Attribute scalar;
+    TypedAttr scalar;
     if (matchPattern(cst_tensor_op.arg(), m_Constant(&scalar))) {
       Type new_out_type = RankedTensorType::get({}, scalar.getType());
       new_cst = rewriter.create<TF::ConstOp>(loc, new_out_type, scalar);
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir
index aa88158..bd5ef12 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir
@@ -19,9 +19,8 @@
 
 // Compute the size of an individual element.
 // CHECK: [[NULL:%.*]] = llvm.mlir.null : !llvm.ptr<f32>
-// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK: [[GEP:%.*]] = llvm.getelementptr [[NULL]]{{\[}}[[C1]]]
-// CHECK-SAME:            (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: [[GEP:%.*]] = llvm.getelementptr [[NULL]]{{\[}}1]
+// CHECK-SAME:            (!llvm.ptr<f32>) -> !llvm.ptr<f32>
 // CHECK: [[SIZE_OF_FLOAT:%.*]] = llvm.ptrtoint [[GEP]]
 // CHECK-SAME:            !llvm.ptr<f32> to i64
 
@@ -163,8 +162,7 @@
 func.func @jit_compile_from_str(%ctx: !tf_framework.op_kernel_context)
     -> !tf_framework.jit_callable {
   // CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @[[CODE]]
-  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index)
-  // CHECK: %[[CODE_PTR:.*]] = llvm.getelementptr %[[ADDR]][%[[C0]], %[[C0]]]
+  // CHECK: %[[CODE_PTR:.*]] = llvm.getelementptr %[[ADDR]][0, 0]
 
   // Create stack-allocated array for the tile sizes.
   // CHECK: %[[NUM_TILE_SIZES:.*]] = llvm.mlir.constant(3 : i64)
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_kernel_gpu_launch_to_llvm.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_kernel_gpu_launch_to_llvm.mlir
index 61b40cc..cf5b8f9 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_kernel_gpu_launch_to_llvm.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_kernel_gpu_launch_to_llvm.mlir
@@ -19,13 +19,11 @@
 func.func @launch(%ctx: !tf_framework.op_kernel_context, %memref: memref<?x10xf32>) {
   // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
   // CHECK: %[[BLOB:.*]] = llvm.mlir.addressof @kernel_module_blob : !llvm.ptr<array<5 x i8>>
-  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-  // CHECK: %[[BLOB_PTR:.*]] = llvm.getelementptr %[[BLOB]][%[[C0]], %[[C0]]] : (!llvm.ptr<array<5 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  // CHECK: %[[BLOB_PTR:.*]] = llvm.getelementptr %[[BLOB]][0, 0] : (!llvm.ptr<array<5 x i8>>) -> !llvm.ptr<i8>
   // CHECK: %[[NAME:.*]] = llvm.mlir.addressof @kernel_module_the_kernel_kernel_name : !llvm.ptr<array<11 x i8>>
-  // CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : index) : i64
-  // CHECK: %[[NAME_PTR:.*]] = llvm.getelementptr %[[NAME]][%[[C0_1]], %[[C0_1]]] : (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  // CHECK: %[[NAME_PTR:.*]] = llvm.getelementptr %[[NAME]][0, 0] : (!llvm.ptr<array<11 x i8>>) -> !llvm.ptr<i8>
   // CHECK: %[[C7:.*]] = llvm.mlir.constant(7 : i32) : i32
-  // CHECK: %[[ARGS:.*]] = llvm.alloca %24 x !llvm.ptr<i8> : (i32) -> !llvm.ptr<ptr<i8>>
+  // CHECK: %[[ARGS:.*]] = llvm.alloca %22 x !llvm.ptr<i8> : (i32) -> !llvm.ptr<ptr<i8>>
   // CHECK: llvm.call @_mlir_ciface_tf_launch_kernel(%[[CTX]], %[[BLOB_PTR]], %[[NAME_PTR]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[ARGS]])
   %c1 = arith.constant 1 : index
   gpu.launch_func  @kernel_module::@the_kernel
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index a25a78b..eeea097 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -1342,7 +1342,7 @@
 // CHECK-LABEL: @opaque_const
 func.func @opaque_const() -> tensor<!tf_type.variant<tensor<2xi32>>> {
   // CHECK-NOT: mhlo.constant
-  %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<!tf_type.variant>} : () -> tensor<!tf_type.variant<tensor<2xi32>>>
+  %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<!tf_type.variant>} : () -> tensor<!tf_type.variant<tensor<2xi32>>>
   func.return %0 : tensor<!tf_type.variant<tensor<2xi32>>>
 }
 
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/missing_main.mlir b/tensorflow/compiler/mlir/xla/tests/translate/missing_main.mlir
index edd8d5f..976dd72 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/missing_main.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/missing_main.mlir
@@ -2,6 +2,6 @@
 
 // CHECK: conversion requires module with `main`
 func.func @non_main() {
-  %0 = "mhlo.constant"() {value = opaque<"mhlo", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32>
+  %0 = "mhlo.constant"() {value = dense_resource<__elided__> : tensor<4xf32>} : () -> tensor<4xf32>
   func.return
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/opaque_elements_attr.mlir b/tensorflow/compiler/mlir/xla/tests/translate/opaque_elements_attr.mlir
index b3ef4c1..8bb1377 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/opaque_elements_attr.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/opaque_elements_attr.mlir
@@ -2,7 +2,7 @@
 
 // CHECK: Only dense elements attr are supported
 func.func @main() {
-  %0 = "mhlo.constant"() {value = opaque<"mhlo", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32>
+  %0 = "mhlo.constant"() {value = dense_resource<__elided__> : tensor<4xf32>} : () -> tensor<4xf32>
   func.return
 }
 
diff --git a/tensorflow/compiler/xla/mlir_hlo/BUILD b/tensorflow/compiler/xla/mlir_hlo/BUILD
index 7507b60..d9d37f6 100644
--- a/tensorflow/compiler/xla/mlir_hlo/BUILD
+++ b/tensorflow/compiler/xla/mlir_hlo/BUILD
@@ -29,6 +29,7 @@
     compatible_with = get_compatible_with_cloud(),
     includes = ["include"],
     deps = [
+        "@llvm-project//mlir:BuiltinDialectTdFiles",
         "@llvm-project//mlir:ControlFlowInterfacesTdFiles",
         "@llvm-project//mlir:CopyOpInterfaceTdFiles",
         "@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
index 6cb8b05..f0990a2 100644
--- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
+++ b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
@@ -30,6 +30,7 @@
 #define CHLO_OPS
 
 include "mlir/IR/OpBase.td"
+include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -619,7 +620,7 @@
   }];
 
   // TODO(jpienaar): value's type could be tightened.
-  let arguments = (ins AnyAttr:$value, HLO_Tensor:$operand);
+  let arguments = (ins TypedAttrInterface:$value, HLO_Tensor:$operand);
   let results = (outs HLO_Tensor);
 
   let hasFolder = 1;
diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 4749859..db350ec 100644
--- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -19,10 +19,10 @@
 #define HLO_OPS
 
 include "mlir/Dialect/Shape/IR/ShapeBase.td"
-include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/OpBase.td"
 include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
 include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
 
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index 98260cc..23e9b06 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -722,7 +722,8 @@
     // All XLA types must be tensor types. In the build() method, we want to
     // provide more flexibility by allowing attributes of scalar types. But we
     // need to wrap it up with ElementsAttr to construct valid XLA constants.
-    type = RankedTensorType::get(/*shape=*/{}, value.getType());
+    type =
+        RankedTensorType::get(/*shape=*/{}, value.cast<TypedAttr>().getType());
     value = DenseElementsAttr::get(type.cast<TensorType>(), value);
   }
 
@@ -733,9 +734,11 @@
 }
 
 LogicalResult ConstantOp::inferReturnTypes(
-    MLIRContext*, Optional<Location>, ValueRange, DictionaryAttr attributes,
-    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
-  Type type = attributes.get("value").getType();
+    MLIRContext*, Optional<Location>, ValueRange operands,
+    DictionaryAttr attributes, RegionRange,
+    SmallVectorImpl<Type>& inferredReturnTypes) {
+  ConstantOpAdaptor adaptor(operands, attributes);
+  Type type = adaptor.value().getType();
   inferredReturnTypes.push_back(type);
   return success();
 }
@@ -9385,13 +9388,14 @@
 
 Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value,
                                             Type type, Location loc) {
-  // HLO dialect constants require the type of value and result to match.
-  if (type != value.getType()) return nullptr;
+  auto elementsAttr = value.dyn_cast<ElementsAttr>();
   // HLO dialect constants only support ElementsAttr unlike standard dialect
   // constant which supports all attributes.
-  if (auto elementsAttr = value.dyn_cast<ElementsAttr>())
-    return builder.create<mhlo::ConstantOp>(loc, type, elementsAttr);
-  return nullptr;
+  if (!elementsAttr) return nullptr;
+  // HLO dialect constants require the type of value and result to match.
+  if (type != elementsAttr.getType()) return nullptr;
+
+  return builder.create<mhlo::ConstantOp>(loc, type, elementsAttr);
 }
 
 LogicalResult MhloDialect::verifyRegionArgAttribute(Operation* op,
diff --git a/tensorflow/core/ir/types/attributes.td b/tensorflow/core/ir/types/attributes.td
index ba3b314..e924fda 100644
--- a/tensorflow/core/ir/types/attributes.td
+++ b/tensorflow/core/ir/types/attributes.td
@@ -462,7 +462,7 @@
 // TensorProtoAttr
 //===----------------------------------------------------------------------===//
 
-def TF_TensorProtoAttr : TFType_Attr<"TensorProto", [ElementsAttrInterface]> {
+def TF_TensorProtoAttr : TFType_Attr<"TensorProto", [ElementsAttrInterface, TypedAttrInterface]> {
   let mnemonic = "tensor_proto";
 
   let summary = "Attribute that stores TensorFlow TensorProto debug string";
diff --git a/tensorflow/core/transforms/constant_folding/pass.cc b/tensorflow/core/transforms/constant_folding/pass.cc
index f3cc91f..1627fc8 100644
--- a/tensorflow/core/transforms/constant_folding/pass.cc
+++ b/tensorflow/core/transforms/constant_folding/pass.cc
@@ -29,6 +29,7 @@
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/Twine.h"
 #include "mlir/Dialect/Traits.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributeInterfaces.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
@@ -110,7 +111,7 @@
 
 static FailureOr<TFOp> CreateConstantTensorOp(
     OpBuilder &builder, Location loc, StringRef name_prefix, Type type,
-    ValueRange control_operands, Attribute tensor_value,
+    ValueRange control_operands, TypedAttr tensor_value,
     ArrayRef<NamedAttribute> other_attrs = llvm::None) {
   if (type.isa<VariantType>()) return failure();
   // TODO(chiahungduan): Reuse ConstOp Like
@@ -641,7 +642,7 @@
       }
     }
 
-    SmallVector<Attribute> result;
+    SmallVector<TypedAttr> result;
     if (failed(util::EvaluateOperation(cpu_device_.get(), resource_mgr_.get(),
                                        op, const_operands, result))) {
       return failure();
@@ -655,7 +656,7 @@
     StringAttr device_attr = TFOp(op).deviceAttr();
     SmallVector<TFOp> const_ops;
     for (auto &it : llvm::enumerate(result)) {
-      Attribute attr = it.value();
+      TypedAttr attr = it.value();
       FailureOr<TFOp> const_op = CreateConstantTensorOp(
           rewriter, op->getLoc(),
           (Twine(TFOp(op).name(), "/eval_") + Twine(it.index())).str(),
diff --git a/tensorflow/core/transforms/utils/eval_utils.cc b/tensorflow/core/transforms/utils/eval_utils.cc
index 8296856..011f9e1 100644
--- a/tensorflow/core/transforms/utils/eval_utils.cc
+++ b/tensorflow/core/transforms/utils/eval_utils.cc
@@ -78,7 +78,7 @@
 LogicalResult EvaluateOperation(tensorflow::DeviceBase *cpu_device,
                                 tensorflow::ResourceMgr *resource_mgr, TFOp op,
                                 ArrayRef<ElementsAttr> operands,
-                                SmallVectorImpl<Attribute> &results) {
+                                SmallVectorImpl<TypedAttr> &results) {
   assert(cpu_device && "cpu device can't be null");
   assert(resource_mgr && "ResourceMgr can't be null");
 
diff --git a/tensorflow/core/transforms/utils/eval_utils.h b/tensorflow/core/transforms/utils/eval_utils.h
index 7ea2287..39cc303 100644
--- a/tensorflow/core/transforms/utils/eval_utils.h
+++ b/tensorflow/core/transforms/utils/eval_utils.h
@@ -57,11 +57,11 @@
 
 // Attempts to evaluates an MLIR Operation with the op registered kernel. The op
 // is always executed on the local host CPU irrespective of the device attribute
-// of the given op. The results will be filled in the results vecotr.
+// of the given op. The results will be filled in the results vector.
 LogicalResult EvaluateOperation(tensorflow::DeviceBase* cpu_device,
                                 tensorflow::ResourceMgr* resource_mgr, TFOp op,
                                 ArrayRef<ElementsAttr> operands,
-                                SmallVectorImpl<Attribute>& results);
+                                SmallVectorImpl<TypedAttr>& results);
 }  // namespace util
 }  // namespace tfg
 }  // namespace mlir
diff --git a/tensorflow/core/transforms/utils/eval_utils_test.cc b/tensorflow/core/transforms/utils/eval_utils_test.cc
index 35c4fe5..01d33c8 100644
--- a/tensorflow/core/transforms/utils/eval_utils_test.cc
+++ b/tensorflow/core/transforms/utils/eval_utils_test.cc
@@ -60,7 +60,7 @@
   auto cpu_device = std::make_unique<util::SimpleDevice>();
   auto resource_mgr = std::make_unique<tensorflow::ResourceMgr>();
 
-  llvm::SmallVector<Attribute> result;
+  llvm::SmallVector<TypedAttr> result;
 
   // The operand 1 of SwitchOp is not scalar.
   EXPECT_TRUE(failed(
@@ -96,7 +96,7 @@
   auto cpu_device = std::make_unique<util::SimpleDevice>();
   auto resource_mgr = std::make_unique<tensorflow::ResourceMgr>();
 
-  llvm::SmallVector<Attribute> result;
+  llvm::SmallVector<TypedAttr> result;
 
   ASSERT_TRUE(succeeded(util::EvaluateOperation(
       cpu_device.get(), resource_mgr.get(), const_0,
@@ -159,7 +159,7 @@
   auto cpu_device = std::make_unique<util::SimpleDevice>();
   auto resource_mgr = std::make_unique<tensorflow::ResourceMgr>();
 
-  llvm::SmallVector<Attribute> result;
+  llvm::SmallVector<TypedAttr> result;
 
   ASSERT_TRUE(succeeded(
       util::EvaluateOperation(cpu_device.get(), resource_mgr.get(), switch_op,
diff --git a/tensorflow/lite/python/analyzer_test.py b/tensorflow/lite/python/analyzer_test.py
index 8b6dd03..c8ca6ce 100644
--- a/tensorflow/lite/python/analyzer_test.py
+++ b/tensorflow/lite/python/analyzer_test.py
@@ -70,7 +70,7 @@
           model_path=model_path, experimental_use_mlir=True)
     mlir = mock_stdout.getvalue()
     self.assertIn(
-        '%1 = "tfl.pseudo_const"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : '
+        '%1 = "tfl.pseudo_const"() {value = dense_resource<__elided__> : '
         'tensor<3x3x3x8xf32>} : () -> tensor<3x3x3x8xf32>', mlir)
 
   def testTxtWithFlatBufferModel(self):
diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl
index 2ed0e9a..872fe4a 100644
--- a/third_party/llvm/workspace.bzl
+++ b/third_party/llvm/workspace.bzl
@@ -4,8 +4,8 @@
 
 def repo(name):
     """Imports LLVM."""
-    LLVM_COMMIT = "af1328ef452b9eaa4e9f0bc115aeda8f40c4bbff"
-    LLVM_SHA256 = "acd25ed3aeb522437093398a4368a9b9e20a5290349442f17bcba0d5b7d59a71"
+    LLVM_COMMIT = "6c66b089bcd7f6dcdb8e2a3a14428a29c4c3da2b"
+    LLVM_SHA256 = "26c4c3ee742f1f4864a5596fa9809f409cfe81e50fcabe21d36175db950e254c"
 
     tf_http_archive(
         name = name,