Define TensorProto attribute to replace the usage of OpaqueElementsAttr
The legacy OpaqueElementsAttr has been removed from the builtin dialect now that dialects can define their own attributes.
The TensorProto attribute stores the debug string of a TensorFlow TensorProto as a string. Like OpaqueElementsAttr, it implements the ElementsAttr interface.
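
For illustration, a minimal sketch of how producers and consumers are expected to use the new attribute, mirroring flatbuffer_import.cc and lower_static_tensor_list.cc below. The helper names MakeTensorProtoAttr and UnpackTensorProtoAttr are hypothetical; the underlying calls (MangleTensor, TensorProtoAttr::get, ConvertToTensor) are the ones exercised in this change:

    #include <string>

    #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
    #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
    #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
    #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
    #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/framework/tensor.pb.h"
    #include "tensorflow/core/platform/errors.h"

    // Producer side: mangle the proto and wrap it in the new dialect attribute
    // (previously an OpaqueElementsAttr keyed on the "tf" dialect pointer).
    mlir::ElementsAttr MakeTensorProtoAttr(mlir::ShapedType shaped_type,
                                           const tensorflow::TensorProto& proto) {
      std::string mangled = tensorflow::mangling_util::MangleTensor(proto);
      return mlir::ElementsAttr(
          mlir::TF::TensorProtoAttr::get(shaped_type, mangled));
    }

    // Consumer side: dyn_cast to TF::TensorProtoAttr instead of
    // OpaqueElementsAttr and reuse the existing ConvertToTensor path.
    tensorflow::Status UnpackTensorProtoAttr(mlir::ElementsAttr attr,
                                             tensorflow::Tensor* out) {
      auto proto_attr = attr.dyn_cast<mlir::TF::TensorProtoAttr>();
      if (!proto_attr)
        return tensorflow::errors::InvalidArgument("Expected tensor_proto attr");
      return tensorflow::ConvertToTensor(proto_attr, out);
    }
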
PiperOrigin-RevId: 465513015
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 2a49672..abe282e 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -386,6 +386,7 @@
deps = [
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
+ "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:status",
"//tensorflow/stream_executor/lib",
@@ -969,6 +970,7 @@
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
+ "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
index 1baa9fe..6cc1ff5 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
@@ -69,6 +69,7 @@
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
@@ -662,11 +663,10 @@
value = mlir::DenseStringElementsAttr::get(shaped_type, refs);
} else if (elem_type.isa<mlir::ComplexType, mlir::TF::TensorFlowType>()) {
- auto dialect = elem_type.getContext()->getLoadedDialect("tf");
tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
- value = mlir::OpaqueElementsAttr::get(dialect, shaped_type, mangled);
+ value = mlir::TF::TensorProtoAttr::get(shaped_type, mangled);
} else {
return errors::Unimplemented("Constant of unsupported type");
}
diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir
index 9c04580..3d2550b 100644
--- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir
@@ -10,15 +10,15 @@
func.func @complex64() -> tensor<4xcomplex<f32>> {
// CHECK-LABEL: @complex64
- // CHECK: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C455836342074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3230303F5C3030305C3030305C3230303F5C3030305C3030305C303030405C3030305C3030305C303030405C3030305C30303040405C3030305C30303040405C3030305C3030305C323030405C3030305C3030305C3230304022"> : tensor<4xcomplex<f32>>
- %0 = "tfl.pseudo_const"() { value = opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C455836342074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3230303F5C3030305C3030305C3230303F5C3030305C3030305C303030405C3030305C3030305C303030405C3030305C30303040405C3030305C30303040405C3030305C3030305C323030405C3030305C3030305C3230304022"> : tensor<4xcomplex<f32>> } : () -> tensor<4xcomplex<f32>>
+ // CHECK: value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F434F4D504C455836342074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3230303F5C3030305C3030305C3230303F5C3030305C3030305C303030405C3030305C3030305C303030405C3030305C30303040405C3030305C30303040405C3030305C3030305C323030405C3030305C3030305C3230304022"> : tensor<4xcomplex<f32>>
+ %0 = "tfl.pseudo_const"() { value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F434F4D504C455836342074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3230303F5C3030305C3030305C3230303F5C3030305C3030305C303030405C3030305C3030305C303030405C3030305C30303040405C3030305C30303040405C3030305C3030305C323030405C3030305C3030305C3230304022"> : tensor<4xcomplex<f32>> } : () -> tensor<4xcomplex<f32>>
func.return %0 : tensor<4xcomplex<f32>>
}
func.func @complex128() -> tensor<4xcomplex<f64>> {
// CHECK-LABEL: @complex128
- // CHECK: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C45583132382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3030305C3030305C3030305C3030305C3336303F5C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303130405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303230405C3030305C3030305C3030305C3030305C3030305C3030305C3030304022"> : tensor<4xcomplex<f64>>
- %0 = "tfl.pseudo_const"() { value = opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C45583132382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3030305C3030305C3030305C3030305C3336303F5C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303130405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303230405C3030305C3030305C3030305C3030305C3030305C3030305C3030304022"> : tensor<4xcomplex<f64>> } : () -> tensor<4xcomplex<f64>>
+ // CHECK: value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F434F4D504C45583132382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3030305C3030305C3030305C3030305C3336303F5C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303130405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303230405C3030305C3030305C3030305C3030305C3030305C3030305C3030304022"> : tensor<4xcomplex<f64>>
+ %0 = "tfl.pseudo_const"() { value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F434F4D504C45583132382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3030305C3030305C3030305C3030305C3336303F5C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303130405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303230405C3030305C3030305C3030305C3030305C3030305C3030305C3030304022"> : tensor<4xcomplex<f64>> } : () -> tensor<4xcomplex<f64>>
func.return %0 : tensor<4xcomplex<f64>>
}
diff --git a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
index b03fcc7..3bfecee 100644
--- a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
@@ -7,7 +7,7 @@
// CHECK-DAG: %[[ELEMENT0:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK-DAG: %[[ELEMENT1:.*]] = "tf.Const"() {value = dense<[3, 4, 5]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK: %[[LIST:.*]] = "tf.Pack"(%[[ELEMENT0]], %[[ELEMENT1]]) {axis = 0 : i64} : (tensor<3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
- %0 = "tf.Const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F56415249414E542074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A2022485C6E5C30323674656E736F72666C6F773A3A54656E736F724C6973745C3032325C3032305C3030305C3030335C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3030315C3032325C3030325C3031305C3030335C3033325C725C3031305C3030335C3032325C3030345C3032325C3030325C3031305C3030333A5C3030335C3030305C3030315C3030325C3033325C725C3031305C3030335C3032325C3030345C3032325C3030325C3031305C3030333A5C3030335C3030335C3030345C30303522"> : tensor<!tf_type.variant>} : () -> tensor<!tf_type.variant<tensor<3xi32>>>
+ %0 = "tf.Const"() {value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F56415249414E542074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A2022485C6E5C30323674656E736F72666C6F773A3A54656E736F724C6973745C3032325C3032305C3030305C3030335C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3030315C3032325C3030325C3031305C3030335C3033325C725C3031305C3030335C3032325C3030345C3032325C3030325C3031305C3030333A5C3030335C3030305C3030315C3030325C3033325C725C3031305C3030335C3032325C3030345C3032325C3030325C3031305C3030333A5C3030335C3030335C3030345C30303522"> : tensor<!tf_type.variant>} : () -> tensor<!tf_type.variant<tensor<3xi32>>>
// CHECK: return %[[LIST]]
%1 = "tf.TensorListStack"(%0, %arg0) : (tensor<!tf_type.variant<tensor<3xi32>>>, tensor<1xi32>) -> tensor<2x3xi32>
@@ -18,7 +18,7 @@
// CHECK-LABEL: func @emptyTensorlistConst
func.func @emptyTensorlistConst(%arg0 : tensor<1xi32>) -> tensor<0x3xi32> {
- %0 = "tf.Const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F56415249414E542074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A20222A5C6E5C30323674656E736F72666C6F773A3A54656E736F724C6973745C3032325C3032305C3030305C3030335C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3030315C3032325C3030325C3031305C30303322"> : tensor<!tf_type.variant>} : () -> tensor<!tf_type.variant<tensor<3xi32>>>
+ %0 = "tf.Const"() {value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F56415249414E542074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A20222A5C6E5C30323674656E736F72666C6F773A3A54656E736F724C6973745C3032325C3032305C3030305C3030335C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3030315C3032325C3030325C3031305C30303322"> : tensor<!tf_type.variant>} : () -> tensor<!tf_type.variant<tensor<3xi32>>>
// CHECK: "tf.Const"() {value = dense<> : tensor<0x3xi32>} : () -> tensor<0x3xi32>
// CHECK-NOT: tf.TensorListStack
diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/variant_type_on_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/variant_type_on_op.mlir
index 0c84c62..02d3260 100644
--- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/variant_type_on_op.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/variant_type_on_op.mlir
@@ -34,6 +34,6 @@
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: }
func.func @main() -> tensor<!tf_type.variant<tensor<2xi32>>> {
- %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<!tf_type.variant>} : () -> tensor<!tf_type.variant<tensor<2xi32>>>
+ %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<!tf_type.variant>} : () -> tensor<!tf_type.variant<tensor<2xi32>>>
func.return %0 : tensor<!tf_type.variant<tensor<2xi32>>>
}
diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
index 7abacf8..f634afa 100644
--- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
@@ -246,12 +246,12 @@
LogicalResult matchAndRewrite(
TF::ConstOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // Verify that the opaque elements attribute contains tensor of type variant
- // and scalar shape. The variant type should hold a TensorList.
- auto opaque_attr = op.value().dyn_cast<OpaqueElementsAttr>();
- if (!opaque_attr) return failure();
+ // Verify that the tensor proto contains tensor of type variant and scalar
+ // shape. The variant type should hold a TensorList.
+ auto proto_attr = op.value().dyn_cast<TF::TensorProtoAttr>();
+ if (!proto_attr) return failure();
tensorflow::Tensor tensor;
- if (!tensorflow::ConvertToTensor(opaque_attr, &tensor).ok())
+ if (!tensorflow::ConvertToTensor(proto_attr, &tensor).ok())
return failure();
if (tensor.dtype() != tensorflow::DT_VARIANT) return failure();
if (!tensorflow::TensorShapeUtils::IsScalar(tensor.shape()))
diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc
index 87bb95a..c06338a 100644
--- a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc
+++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc
@@ -18,6 +18,7 @@
#include <string>
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
@@ -49,7 +50,6 @@
} else if (auto complex_type = element_type.dyn_cast<mlir::ComplexType>()) {
auto etype = complex_type.getElementType();
if (etype.isF32()) {
- auto dialect = etype.getContext()->getLoadedDialect("tf");
tensorflow::TensorProto repr;
repr.set_dtype(tensorflow::DT_COMPLEX64);
@@ -63,7 +63,7 @@
repr.set_tensor_content(content);
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
- attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
+ attr = mlir::TF::TensorProtoAttr::get(scalar_type, mangled);
} else {
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
"Unsupported type");
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_quant_ops_to_mhlo.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_quant_ops_to_mhlo.cc
index 1db35ed..531834b 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_quant_ops_to_mhlo.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_quant_ops_to_mhlo.cc
@@ -90,8 +90,8 @@
return failure();
// Check whether the rhs operand has constant op.
- OpaqueElementsAttr opaque_attr;
- if (!matchPattern(this_op.rhs(), m_Constant(&opaque_attr)))
+ TF::TensorProtoAttr tensor_proto_attr;
+ if (!matchPattern(this_op.rhs(), m_Constant(&tensor_proto_attr)))
return failure();
// Check whether the rhs_scales operand has constant op.
@@ -126,7 +126,7 @@
Type rhs_type = getSameShapeTensorType(
this_op.rhs().getType().cast<TensorType>(), rhs_elem_ty);
- llvm::StringRef mangled_tensor = opaque_attr.getValue();
+ llvm::StringRef mangled_tensor = tensor_proto_attr.getValue();
absl::string_view tensor_view(mangled_tensor.data(), mangled_tensor.size());
tensorflow::TensorProto tensor_proto;
tensorflow::Status status =
@@ -141,14 +141,14 @@
}
auto arr = t.flat<tensorflow::qint8>();
- auto new_opaque_attr = ElementsAttr(mlir::DenseElementsAttr::get(
+ auto dense_attr = ElementsAttr(mlir::DenseElementsAttr::get(
getSameShapeTensorType(rhs_type.cast<TensorType>(), storage_type),
llvm::makeArrayRef(arr.data(), arr.size())));
Value lhs = this_op.lhs();
rewriter.setInsertionPointAfterValue(this_op.rhs());
Value rhs = rewriter.create<mhlo::ConstantOp>(rewriter.getUnknownLoc(),
- rhs_type, new_opaque_attr);
+ rhs_type, dense_attr);
rewriter.setInsertionPoint(op);
rewriter.replaceOpWithNewOp<mhlo::DotOp>(op, lhs, rhs,
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td
index 48c7638..f8eb431 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td
@@ -23,7 +23,7 @@
// Converts arith.constant ops from freezing passes back to tf.Const ops.
def ConvertArithConstToTfConst : Pat<
- (Arith_ConstantOp:$res NonOpaqueElementsAttr:$value),
+ (Arith_ConstantOp:$res DenseElementsAttr:$value),
(TF_ConstOp $value),
[(AnyStaticShapeTensor $res)]>;
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.td
index 5623065..88864df 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.td
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.td
@@ -23,6 +23,6 @@
// Converts tf.Const to arith.constant for statically shaped, non-opaque constants.
// Needed for QuantizationDriver to recognize constants.
def ConvertTfConstToArithConst : Pat<
- (TF_ConstOp:$res NonOpaqueElementsAttr:$value),
+ (TF_ConstOp:$res DenseElementsAttr:$value),
(Arith_ConstantOp $value),
[(AnyStaticShapeTensor $res)], (addBenefit 10)>;
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc
index 3442a03..8dbd547 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc
@@ -601,8 +601,8 @@
}
ShapedType tensor_qtype = q_op.getResult().getType().cast<ShapedType>();
- Attribute quantized_attr = Quantize(attr, tensor_qtype);
- if (!quantized_attr) {
+ Attribute tensor_proto_attr = Quantize(attr, tensor_qtype);
+ if (!tensor_proto_attr) {
return failure();
}
@@ -616,23 +616,21 @@
new_type = ConvertIntToQint(new_type, rewriter.getContext());
tensor_qtype = ConvertIntToQint(tensor_qtype, rewriter.getContext());
- // TODO(b/225793355): It adds OpaqueElementsAttr to the constant as a
+ // TODO(b/225793355): It adds TensorProtoAttr to the constant as a
// workaround.
tensorflow::TensorProto tensor_proto;
- if (!mlir::tfg::ConvertToTensorProto(quantized_attr, &tensor_proto)
+ if (!mlir::tfg::ConvertToTensorProto(tensor_proto_attr, &tensor_proto)
.ok()) {
return failure();
}
tensor_proto.set_dtype(tensorflow::DT_QINT8);
- Dialect* dialect = rewriter.getContext()->getLoadedDialect("tf");
-
- quantized_attr = ElementsAttr(OpaqueElementsAttr::get(
- dialect, new_type,
- tensorflow::mangling_util::MangleTensor(tensor_proto)));
+ tensor_proto_attr = ElementsAttr(TF::TensorProtoAttr::get(
+ new_type, tensorflow::mangling_util::MangleTensor(tensor_proto)));
}
- auto const_op = rewriter.create<TF::ConstOp>(loc, new_type, quantized_attr);
+ auto const_op =
+ rewriter.create<TF::ConstOp>(loc, new_type, tensor_proto_attr);
// Add scast op to match quantize -> composition pattern. The added scast
// is then removed by canonicalization. ([scast - scast] -> [])
auto scast_op = rewriter.create<quantfork::StorageCastOp>(
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.td
index 5d59735..fbe7b62 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.td
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.td
@@ -23,6 +23,6 @@
// Converts remaining arith.constant ops from quantization passes back to
// tf.Const ops.
def ConvertArithConstToTfConst : Pat<
- (Arith_ConstantOp:$res NonOpaqueElementsAttr:$value),
+ (Arith_ConstantOp:$res DenseElementsAttr:$value),
(TF_ConstOp $value),
[(AnyStaticShapeTensor $res)], (addBenefit 20)>;
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td
index 90e7262..30ce9c6 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td
@@ -16,8 +16,8 @@
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
include "mlir/IR/PatternBase.td"
-def NonOpaqueElementsAttr : ElementsAttrBase<
- CPred<"!$_self.isa<OpaqueElementsAttr>()">,
+def DenseElementsAttr : ElementsAttrBase<
+ CPred<"$_self.isa<DenseElementsAttr>()">,
"non-opaque constant tensor">;
// Checks if the data format is "NHWC".
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_tf_quant_ops_to_mhlo.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_tf_quant_ops_to_mhlo.mlir
index 825438e..96c82c2 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_tf_quant_ops_to_mhlo.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_tf_quant_ops_to_mhlo.mlir
@@ -15,7 +15,7 @@
// RUN: tf-quant-opt %s -quant-convert-tf-quant-ops-to-mhlo | FileCheck %s
func.func @quantized_matmul_fn(%input: tensor<*xf32>) -> tensor<*xf32> {
- %weight = "tf.Const"() { value = opaque<"tf", "0x746674656E736F722464747970653A2044545F51494E54382074656E736F725F7368617065207B2064696D207B2073697A653A2032207D2064696D207B2073697A653A2032207D207D2074656E736F725F636F6E74656E743A20225C3030315C3030325C3030335C30303422"> : tensor<2x2x!tf_type.qint8> } : () -> tensor<2x2x!tf_type.qint8>
+ %weight = "tf.Const"() { value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F51494E54382074656E736F725F7368617065207B2064696D207B2073697A653A2032207D2064696D207B2073697A653A2032207D207D2074656E736F725F636F6E74656E743A20225C3030315C3030325C3030335C30303422"> : tensor<2x2x!tf_type.qint8> } : () -> tensor<2x2x!tf_type.qint8>
%weight_scales = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
%weight_zps = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h
index ec7480d..d522387 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h
@@ -28,6 +28,7 @@
using mlir::tf_type::FuncAttr; // NOLINT
using mlir::tf_type::PlaceholderAttr; // NOLINT
using mlir::tf_type::ShapeAttr; // NOLINT
+using mlir::tf_type::TensorProtoAttr; // NOLINT
} // end namespace TF
} // end namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h
index a3ff7fa..93e9de1 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h
@@ -91,23 +91,12 @@
return failure();
}
- using DecodeConstantHook = LogicalResult (*)(OpaqueElementsAttr input,
- ElementsAttr &output);
- static void RegisterDecodeConstantHook(DecodeConstantHook fn) {
- decode_constant_hook_ = std::move(fn);
- }
- static LogicalResult decode(OpaqueElementsAttr input, ElementsAttr &output) {
- if (decode_constant_hook_) return decode_constant_hook_(input, output);
- return failure();
- }
-
// Provides a hook for op interface.
void *getRegisteredInterfaceForOp(mlir::TypeID interface,
mlir::OperationName opName) override;
private:
static ConstantFoldHook constant_fold_hook_;
- static DecodeConstantHook decode_constant_hook_;
// Storage for a custom fallback interface.
TensorFlowRegistryEffectInterfaceFallback *fallback_effect_op_interface_;
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index 3b651b6..9c72980d 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -57,7 +57,6 @@
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
-#include "mlir/Interfaces/DecodeAttributesInterfaces.h" // from @llvm-project
#include "mlir/Interfaces/FoldInterfaces.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "mlir/Parser/Parser.h" // from @llvm-project
@@ -93,15 +92,6 @@
}
};
-struct TFDecodeAttributesInterface : public DialectDecodeAttributesInterface {
- TFDecodeAttributesInterface(Dialect *dialect)
- : DialectDecodeAttributesInterface(dialect) {}
- LogicalResult decode(OpaqueElementsAttr input,
- ElementsAttr &output) const override {
- return TensorFlowDialect::decode(input, output);
- }
-};
-
// Helper function that implements the multi-device inlining policy behavior
// for the inliner hook. In particular, for all function body nodes set unset
// placement attributes to match the function call node.
@@ -311,7 +301,6 @@
}
TensorFlowDialect::ConstantFoldHook TensorFlowDialect::constant_fold_hook_;
-TensorFlowDialect::DecodeConstantHook TensorFlowDialect::decode_constant_hook_;
TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
: Dialect(/*name=*/"tf", context, TypeID::get<TensorFlowDialect>()) {
@@ -324,8 +313,7 @@
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc.inc"
>();
- addInterfaces<TFInlinerInterface, TFDecodeAttributesInterface,
- TFConstantFoldInterface>();
+ addInterfaces<TFInlinerInterface, TFConstantFoldInterface>();
fallback_effect_op_interface_ =
new TensorFlowRegistryEffectInterfaceFallback();
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
index 1d1eda3..8af56e9 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
@@ -148,9 +148,9 @@
// CHECK-LABEL: func @tfConst
func.func @tfConst() -> (tensor<4xf32>, tensor<1x1x6x2xf32>) {
- %0 = "tf.Const"() {device = "", name = "Const", dtype = "tfdtype$DT_FLOAT", value = opaque<"tf", "0x746674656E736F722464747970653A2044545F464C4F41540A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20340A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3030305C3030305C323430405C3030305C30303020405C3030305C303030205C3330315C3030305C3030305C3230305C323737220A"> : tensor<4xf32>} : () -> tensor<4xf32>
+ %0 = "tf.Const"() {device = "", name = "Const", dtype = "tfdtype$DT_FLOAT", value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F464C4F41540A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20340A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3030305C3030305C323430405C3030305C30303020405C3030305C303030205C3330315C3030305C3030305C3230305C323737220A"> : tensor<4xf32>} : () -> tensor<4xf32>
%21 = "tf.Const"() {device = "", name = "Const_143", dtype = "tfdtype$DT_FLOAT", value = dense<0.24288677062973696> : tensor<1x1x6x2xf32>} : () -> tensor<1x1x6x2xf32>
- // CHECK-DAG: value = opaque<"tf"
+ // CHECK-DAG: value = #tf_type<tensor_proto
// CHECK-DAG: tf.Const{{.*}} dense<0.242886767> : tensor<1x1x6x2xf32>
func.return %0, %21 : tensor<4xf32>, tensor<1x1x6x2xf32>
}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt
index 1016ae8..a690805 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt
@@ -28,5 +28,5 @@
}
# CHECK: tf.Const
-# CHECK-SAME: value = opaque<"tf", "{{0[xX][0-9a-fA-F]*}}"> : tensor<!tf_type.quint8>
+# CHECK-SAME: value = #tf_type<tensor_proto : "{{0[xX][0-9a-fA-F]*}}"> : tensor<!tf_type.quint8>
# CHECK-SAME: loc(fused["Const:", "Quantized_Constant"])
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/simple.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/simple.mlir
index 8ec0e3f..61a2dc5 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/simple.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/simple.mlir
@@ -22,7 +22,7 @@
// CHECK-NEXT: }
// CHECK-NEXT: tensor_content: "\200\000\000\000\200\000\000\000"
tf_executor.graph {
- %0:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", dtype = "tfdtype$DT_INT32", value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> tensor<2xi32> loc("Empty/shape")
+ %0:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", dtype = "tfdtype$DT_INT32", value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> tensor<2xi32> loc("Empty/shape")
tf_executor.fetch
}
func.return
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/simple_tf_dialect_op.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/simple_tf_dialect_op.mlir
index 7214cdb..780406e 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/simple_tf_dialect_op.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/simple_tf_dialect_op.mlir
@@ -25,7 +25,7 @@
// CHECK-NEXT: original_node_names: "n1"
// CHECK-NEXT: original_func_names: "f1"
// CHECK-NEXT: }
- %0 = "tf.Const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> (tensor<2xi32>) loc(fused[callsite("n1@f1" at callsite("node_name" at "file_loc"))])
+ %0 = "tf.Const"() {value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> (tensor<2xi32>) loc(fused[callsite("n1@f1" at callsite("node_name" at "file_loc"))])
func.return
}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index 1d8e768..9424dfc 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -5,23 +5,23 @@
// TODO(hinsu): Remove tests for ops without custom verifiers. These tests were
// added along with manual op definition and are obsolete now that the op
// definitions are auto-generated.
+
// TODO(hinsu): Move attribute and type tests to types.mlir file.
-
//===--------------------------------------------------------------------===//
-// Test TF opaque attributes
+// Test TF TensorProto attributes
//===--------------------------------------------------------------------===//
-// CHECK-LABEL: func @opaquetensorattr
-func.func @opaquetensorattr() -> () {
+// CHECK-LABEL: func @tensorProtoAttr
+func.func @tensorProtoAttr() -> () {
^bb0:
-// CHECK: "tf.opaqueIntTensor"() {bar = opaque<"tf", "0x68656C6C6F"> : tensor<2x1x4xi32>} : () -> ()
- "tf.opaqueIntTensor"(){bar = opaque<"tf", "0x68656C6C6F"> : tensor<2x1x4xi32>} : () -> ()
-// CHECK: "tf.opaqueFloatTensor"() {bar = opaque<"tf", "0x68656C6C6F"> : tensor<2x1x4xf32>} : () -> ()
- "tf.opaqueFloatTensor"(){bar = opaque<"tf", "0x68656C6C6F"> : tensor<2x1x4xf32>} : () -> ()
-// CHECK: "tf.opaqueStringTensor"() {bar = opaque<"tf", "0x68656C6C6F"> : tensor<2x1x4x!tf_type.string>} : () -> ()
- "tf.opaqueStringTensor"(){bar = opaque<"tf", "0x68656C6C6F"> : tensor<2x1x4x!tf_type.string>} : () -> ()
-// CHECK: "tf.opaqueResourceTensor"() {bar = opaque<"tf", "0x68656C6C6F"> : tensor<2x1x4x!tf_type.resource>} : () -> ()
- "tf.opaqueResourceTensor"(){bar = opaque<"tf", "0x68656C6C6F"> : tensor<2x1x4x!tf_type.resource>} : () -> ()
+// CHECK: "tf.TensorProtoIntTensor"() {bar = #tf_type<tensor_proto : "0x68656C6C6F"> : tensor<2x1x4xi32>} : () -> ()
+ "tf.TensorProtoIntTensor"(){bar = #tf_type<tensor_proto : "0x68656C6C6F"> : tensor<2x1x4xi32>} : () -> ()
+// CHECK: "tf.TensorProtoFloatTensor"() {bar = #tf_type<tensor_proto : "0x68656C6C6F"> : tensor<2x1x4xf32>} : () -> ()
+ "tf.TensorProtoFloatTensor"(){bar = #tf_type<tensor_proto : "0x68656C6C6F"> : tensor<2x1x4xf32>} : () -> ()
+// CHECK: "tf.TensorProtoStringTensor"() {bar = #tf_type<tensor_proto : "0x68656C6C6F"> : tensor<2x1x4x!tf_type.string>} : () -> ()
+ "tf.TensorProtoStringTensor"(){bar = #tf_type<tensor_proto : "0x68656C6C6F"> : tensor<2x1x4x!tf_type.string>} : () -> ()
+// CHECK: "tf.TensorProtoResourceTensor"() {bar = #tf_type<tensor_proto : "0x68656C6C6F"> : tensor<2x1x4x!tf_type.resource>} : () -> ()
+ "tf.TensorProtoResourceTensor"(){bar = #tf_type<tensor_proto : "0x68656C6C6F"> : tensor<2x1x4x!tf_type.resource>} : () -> ()
func.return
}
@@ -3624,7 +3624,7 @@
// CHECK-LABEL: func @testParseExampleV2DenseOnlyValid
func.func @testParseExampleV2DenseOnlyValid(%serialized: tensor<32x!tf_type.string>, %names : tensor<32x!tf_type.string>, %dense_keys : tensor<2x!tf_type.string>, %dense_default_0 : tensor<?xf32>, %dense_default_1 : tensor<?xf32>) -> (tensor<32xf32>) {
- %empty_str_vector = "tf.Const"() {dtype = !tf_type.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf_type.string>} : () -> tensor<0x!tf_type.string>
+ %empty_str_vector = "tf.Const"() {dtype = !tf_type.string, value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf_type.string>} : () -> tensor<0x!tf_type.string>
%result:2 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %dense_keys, %empty_str_vector, %dense_default_0, %dense_default_1) {dense_shapes = [#tf_type.shape<>, #tf_type.shape<>], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 2, 0, 0]> : vector<6xi32>} : (tensor<32x!tf_type.string>, tensor<32x!tf_type.string>, tensor<0x!tf_type.string>, tensor<2x!tf_type.string>, tensor<0x!tf_type.string>, tensor<?xf32>, tensor<?xf32>) -> (tensor<32xf32>, tensor<32xf32>)
func.return %result#0 : tensor<32xf32>
}
@@ -3632,7 +3632,7 @@
// -----
func.func @testParseExampleV2DenseMismatchedInputOutput(%serialized: tensor<32x!tf_type.string>, %names : tensor<32x!tf_type.string>, %dense_keys : tensor<2x!tf_type.string>, %dense_default_0 : tensor<?xf32>, %dense_default_1 : tensor<?xf32>) -> (tensor<32xf32>) {
- %empty_str_vector = "tf.Const"() {dtype = !tf_type.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf_type.string>} : () -> tensor<0x!tf_type.string>
+ %empty_str_vector = "tf.Const"() {dtype = !tf_type.string, value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf_type.string>} : () -> tensor<0x!tf_type.string>
// expected-error @+1 {{output 'dense_values' should have same length as attribute 'Tdense'}}
%result:3 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %dense_keys, %empty_str_vector, %dense_default_0, %dense_default_1) {dense_shapes = [#tf_type.shape<>, #tf_type.shape<>], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 3, 0, 0]> : vector<6xi32>} : (tensor<32x!tf_type.string>, tensor<32x!tf_type.string>, tensor<0x!tf_type.string>, tensor<2x!tf_type.string>, tensor<0x!tf_type.string>, tensor<?xf32>, tensor<?xf32>) -> (tensor<32xf32>, tensor<32xf32>, tensor<32xi64>)
func.return %result#0 : tensor<32xf32>
@@ -3642,7 +3642,7 @@
// CHECK-LABEL: func @testParseExampleV2SparseOnlyValid
func.func @testParseExampleV2SparseOnlyValid(%serialized: tensor<32x!tf_type.string>, %names : tensor<32x!tf_type.string>, %sparse_keys : tensor<2x!tf_type.string>) -> (tensor<?x2xi64>) {
- %empty_str_vector = "tf.Const"() {dtype = !tf_type.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf_type.string>} : () -> tensor<0x!tf_type.string>
+ %empty_str_vector = "tf.Const"() {dtype = !tf_type.string, value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf_type.string>} : () -> tensor<0x!tf_type.string>
%result:6 = "tf.ParseExampleV2"(%serialized, %names, %sparse_keys, %empty_str_vector, %empty_str_vector) {dense_shapes = [], num_sparse = 2 : i64, result_segment_sizes = dense<[2, 2, 2, 0, 0, 0]> : vector<6xi32>} : (tensor<32x!tf_type.string>, tensor<32x!tf_type.string>, tensor<2x!tf_type.string>, tensor<0x!tf_type.string>, tensor<0x!tf_type.string>) -> (tensor<?x2xi64>, tensor<?x2xi64>, tensor<?x!tf_type.string>, tensor<?xi64>, tensor<2xi64>, tensor<2xi64>)
func.return %result#0 : tensor<?x2xi64>
}
@@ -3650,7 +3650,7 @@
// -----
func.func @testParseExampleV2SparseInvalidNumSparse(%serialized: tensor<32x!tf_type.string>, %names : tensor<32x!tf_type.string>, %sparse_keys : tensor<2x!tf_type.string>) -> (tensor<?x2xi64>) {
- %empty_str_vector = "tf.Const"() {dtype = !tf_type.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf_type.string>} : () -> tensor<0x!tf_type.string>
+ %empty_str_vector = "tf.Const"() {dtype = !tf_type.string, value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf_type.string>} : () -> tensor<0x!tf_type.string>
// expected-error @+1 {{attribute 'num_sparse' should be the same as the length of attribute 'sparse_types'}}
%result:6 = "tf.ParseExampleV2"(%serialized, %names, %sparse_keys, %empty_str_vector, %empty_str_vector) {dense_shapes = [], num_sparse = 3 : i64, result_segment_sizes = dense<[2, 2, 2, 0, 0, 0]> : vector<6xi32>} : (tensor<32x!tf_type.string>, tensor<32x!tf_type.string>, tensor<2x!tf_type.string>, tensor<0x!tf_type.string>, tensor<0x!tf_type.string>) -> (tensor<?x2xi64>, tensor<?x2xi64>, tensor<?x!tf_type.string>, tensor<?xi64>, tensor<2xi64>, tensor<2xi64>)
func.return %result#0 : tensor<?x2xi64>
@@ -3659,7 +3659,7 @@
// -----
func.func @testParseExampleV2SparseInvalidSparseIndicesOutput(%serialized: tensor<32x!tf_type.string>, %names : tensor<32x!tf_type.string>, %sparse_keys : tensor<2x!tf_type.string>) -> (tensor<?x2xi64>) {
- %empty_str_vector = "tf.Const"() {dtype = !tf_type.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf_type.string>} : () -> tensor<0x!tf_type.string>
+ %empty_str_vector = "tf.Const"() {dtype = !tf_type.string, value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf_type.string>} : () -> tensor<0x!tf_type.string>
// expected-error @+1 {{output 'sparse_indices' should have same length as attribute 'sparse_types'}}
%result:5 = "tf.ParseExampleV2"(%serialized, %names, %sparse_keys, %empty_str_vector, %empty_str_vector) {dense_shapes = [], num_sparse = 2 : i64, result_segment_sizes = dense<[1, 2, 2, 0, 0, 0]> : vector<6xi32>} : (tensor<32x!tf_type.string>, tensor<32x!tf_type.string>, tensor<2x!tf_type.string>, tensor<0x!tf_type.string>, tensor<0x!tf_type.string>) -> (tensor<?x2xi64>, tensor<?x!tf_type.string>, tensor<?xi64>, tensor<2xi64>, tensor<2xi64>)
func.return %result#0 : tensor<?x2xi64>
@@ -3668,7 +3668,7 @@
// -----
func.func @testParseExampleV2SparseOnlyValid(%serialized: tensor<32x!tf_type.string>, %names : tensor<32x!tf_type.string>, %sparse_keys : tensor<2x!tf_type.string>) -> (tensor<?x2xi64>) {
- %empty_str_vector = "tf.Const"() {dtype = !tf_type.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf_type.string>} : () -> tensor<0x!tf_type.string>
+ %empty_str_vector = "tf.Const"() {dtype = !tf_type.string, value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf_type.string>} : () -> tensor<0x!tf_type.string>
// expected-error @+1 {{output 'sparse_shapes' should have same length as attribute 'sparse_types'}}
%result:5 = "tf.ParseExampleV2"(%serialized, %names, %sparse_keys, %empty_str_vector, %empty_str_vector) {dense_shapes = [], num_sparse = 2 : i64, result_segment_sizes = dense<[2, 2, 1, 0, 0, 0]> : vector<6xi32>} : (tensor<32x!tf_type.string>, tensor<32x!tf_type.string>, tensor<2x!tf_type.string>, tensor<0x!tf_type.string>, tensor<0x!tf_type.string>) -> (tensor<?x2xi64>, tensor<?x2xi64>, tensor<?x!tf_type.string>, tensor<?xi64>, tensor<2xi64>)
func.return %result#0 : tensor<?x2xi64>
@@ -3678,7 +3678,7 @@
// CHECK-LABEL: func @testParseExampleV2RaggedOnlyValid
func.func @testParseExampleV2RaggedOnlyValid(%serialized: tensor<32x!tf_type.string>, %names : tensor<32x!tf_type.string>, %ragged_keys : tensor<2x!tf_type.string>) -> (tensor<?xf32>) {
- %empty_str_vector = "tf.Const"() {dtype = !tf_type.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf_type.string>} : () -> tensor<0x!tf_type.string>
+ %empty_str_vector = "tf.Const"() {dtype = !tf_type.string, value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf_type.string>} : () -> tensor<0x!tf_type.string>
%result:4 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %empty_str_vector, %ragged_keys) {dense_shapes = [], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 0, 2, 2]> : vector<6xi32>} : (tensor<32x!tf_type.string>, tensor<32x!tf_type.string>, tensor<0x!tf_type.string>, tensor<0x!tf_type.string>, tensor<2x!tf_type.string>) -> (tensor<?xf32>, tensor<?x!tf_type.string>, tensor<?xi32>, tensor<?xi64>)
func.return %result#0 : tensor<?xf32>
}
@@ -3686,7 +3686,7 @@
// -----
func.func @testParseExampleV2RaggedMismatchedOutputLengths(%serialized: tensor<32x!tf_type.string>, %names : tensor<32x!tf_type.string>, %ragged_keys : tensor<2x!tf_type.string>) -> (tensor<?xf32>) {
- %empty_str_vector = "tf.Const"() {dtype = !tf_type.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf_type.string>} : () -> tensor<0x!tf_type.string>
+ %empty_str_vector = "tf.Const"() {dtype = !tf_type.string, value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf_type.string>} : () -> tensor<0x!tf_type.string>
// expected-error @+1 {{attribute 'ragged_value_types' should have same length as attribute 'ragged_split_types'}}
%result:3 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %empty_str_vector, %ragged_keys) {dense_shapes = [], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 0, 2, 1]> : vector<6xi32>} : (tensor<32x!tf_type.string>, tensor<32x!tf_type.string>, tensor<0x!tf_type.string>, tensor<0x!tf_type.string>, tensor<2x!tf_type.string>) -> (tensor<?xf32>, tensor<?x!tf_type.string>, tensor<?xi32>)
func.return %result#0 : tensor<?xf32>
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc
index 09fac6e..8136e7a 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc
@@ -32,24 +32,5 @@
namespace mlir {
namespace {
-// Since this method is passed to MLIR as decode hook it has to conform
-// to LLVM style used by MLIR.
-LogicalResult DecodeOpaqueTensorHook(const OpaqueElementsAttr input,
- ElementsAttr& output) { // NOLINT
- Builder builder(input.getType().getContext());
- auto decoded_attr_or = tensorflow::DecodeOpaqueTensor(input, builder);
- if (!decoded_attr_or.ok()) {
- VLOG(2) << decoded_attr_or.status().error_message();
- return failure();
- }
-
- output = decoded_attr_or.ValueOrDie();
- return success();
-}
-
-static bool init_hooks = ([] () {
- TF::TensorFlowDialect::RegisterDecodeConstantHook(DecodeOpaqueTensorHook);
-}(), true);
-
} // anonymous namespace
} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc
index baf420f..a187fb1 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc
@@ -29,6 +29,7 @@
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
@@ -51,7 +52,6 @@
using mlir::Builder;
using mlir::DenseStringElementsAttr;
using mlir::ElementsAttr;
-using mlir::OpaqueElementsAttr;
using mlir::RankedTensorType;
using mlir::ShapedType;
using mlir::Type;
@@ -152,11 +152,9 @@
case DT_STRING:
return ConvertStringTensor(input_tensor, type);
default:
- // TODO(shpeisman): restructure code to reuse dialect pointer across
- // calls.
- auto* dialect = builder->getContext()->getLoadedDialect("tf");
+ // TODO(hinsu): Remove mangling now that there is a special attribute.
return ElementsAttr(
- OpaqueElementsAttr::get(dialect, type, MangleTensor(input_tensor)));
+ mlir::TF::TensorProtoAttr::get(type, MangleTensor(input_tensor)));
}
#undef CONVERT_FLAT
@@ -305,15 +303,12 @@
}
}
-// Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
-Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
- TensorProto* output_tensor) {
- if (attr.isa<OpaqueElementsAttr>()) {
- auto mangled_tensor = attr.cast<OpaqueElementsAttr>().getValue();
- absl::string_view tensor_view(mangled_tensor.data(), mangled_tensor.size());
- return mangling_util::DemangleTensor(tensor_view, output_tensor);
- }
- return InvalidArgument("Unexpected elements attribute type from MLIR.");
+// Converts a TensorProto attribute to a TensorFlow tensor proto.
+Status ConvertTensorProtoAttr(const mlir::TF::TensorProtoAttr attr,
+ TensorProto* output_tensor) {
+ auto mangled_tensor = attr.getValue();
+ absl::string_view tensor_view(mangled_tensor.data(), mangled_tensor.size());
+ return mangling_util::DemangleTensor(tensor_view, output_tensor);
}
template <typename T>
@@ -404,8 +399,8 @@
output->set_dtype(output_dtype);
ConvertToTensorShapeProto(shape, output->mutable_tensor_shape());
- if (attr.isa<OpaqueElementsAttr>())
- return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(), output);
+ if (auto tensor_attr = attr.dyn_cast<mlir::TF::TensorProtoAttr>())
+ return ConvertTensorProtoAttr(tensor_attr, output);
auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr");
@@ -492,14 +487,4 @@
return OkStatus();
}
-StatusOr<mlir::ElementsAttr> DecodeOpaqueTensor(
- const mlir::OpaqueElementsAttr input_attr, mlir::Builder builder) {
- // TODO(antiagainst): The following logic, albeit simple, involves copying the
- // tensor content multiple times, which is bad. Figure out a better way to
- // achieve the purpose.
- Tensor tensor;
- TF_RETURN_IF_ERROR(ConvertToTensor(input_attr, &tensor));
- return ConvertTensor(tensor, &builder);
-}
-
} // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h
index 294453e..7b30289 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h
@@ -59,11 +59,6 @@
// Converts an MLIR elements attribute to a TensorFlow tensor.
Status ConvertToTensor(mlir::ElementsAttr attr, Tensor* output_tensor);
-// Decodes the given opaque elements attribute holding tensor content into a
-// human-readable elements attribute.
-StatusOr<mlir::ElementsAttr> DecodeOpaqueTensor(
- mlir::OpaqueElementsAttr input_attr, mlir::Builder builder);
-
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_TENSOR_H_
diff --git a/tensorflow/compiler/mlir/tfrt/tests/analysis/compatibility_analysis.mlir b/tensorflow/compiler/mlir/tfrt/tests/analysis/compatibility_analysis.mlir
index 1c6b7be..6be51d6 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/analysis/compatibility_analysis.mlir
+++ b/tensorflow/compiler/mlir/tfrt/tests/analysis/compatibility_analysis.mlir
@@ -56,7 +56,7 @@
%2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<f32>>>
"tf.AssignVariableOp"(%2, %1) : (tensor<*x!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
%empty_str_vector = "tf.Const"()
- {dtype = !tf_type.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf_type.string>}
+ {dtype = !tf_type.string, value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf_type.string>}
: () -> tensor<0x!tf_type.string>
%result:2 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %dense_keys, %empty_str_vector, %dense_default_0, %dense_default_1)
{dense_shapes = [#tf_type.shape<>, #tf_type.shape<>], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 2, 0, 0]> : vector<6xi32>}
diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/const_tensor.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/const_tensor.mlir
index 191663b..7e8655c 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/const_tensor.mlir
+++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/const_tensor.mlir
@@ -25,6 +25,6 @@
func.func @tensor_proto() -> tensor<!tf_type.quint8> {
// tfrt_fallback_async.const_tensor_proto accepts a serialized tensor proto.
// CHECK: tfrt_fallback_async.const_tensor_proto "\08\0C\12\00\22\01@"
- %0 = "tf.Const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F5155494E54382074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A20224022"> : tensor<!tf_type.quint8>} : () -> tensor<!tf_type.quint8>
+ %0 = "tf.Const"() {value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F5155494E54382074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A20224022"> : tensor<!tf_type.quint8>} : () -> tensor<!tf_type.quint8>
func.return %0 : tensor<!tf_type.quint8>
}
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index a753be2..d6100cf 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -1769,17 +1769,17 @@
StatusOr<xla::Literal> CreateArrayLiteralFromAttr(ElementsAttr attr,
xla::Layout layout) {
- if (attr.isa<OpaqueElementsAttr>())
+ auto dense_attr = attr.dyn_cast<DenseElementsAttr>();
+ if (!dense_attr)
return tensorflow::errors::Unimplemented(
- "Opaque elements attr not supported");
+ "Only dense elements attr are supported");
- xla::Shape shape = xla::TypeToShape(attr.getType());
+ xla::Shape shape = xla::TypeToShape(dense_attr.getType());
#define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type) \
case xla_type: { \
xla::Array<cpp_type> source_data(shape.dimensions()); \
- source_data.SetValues( \
- attr.cast<DenseElementsAttr>().getValues<cpp_type>()); \
+ source_data.SetValues(dense_attr.getValues<cpp_type>()); \
return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout); \
}
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir
index f4308d3..2e6593b 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir
@@ -1322,10 +1322,10 @@
// -----
-// CHECK-LABEL: @opaque_const
-func.func @opaque_const() -> tensor<!tf_type.variant<tensor<2xi32>>> {
+// CHECK-LABEL: @const_with_tensor_proto_attr
+func.func @const_with_tensor_proto_attr() -> tensor<!tf_type.variant<tensor<2xi32>>> {
// CHECK-NOT: mhlo.constant
- %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<!tf_type.variant>} : () -> tensor<!tf_type.variant<tensor<2xi32>>>
+ %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<!tf_type.variant>} : () -> tensor<!tf_type.variant<tensor<2xi32>>>
func.return %0 : tensor<!tf_type.variant<tensor<2xi32>>>
}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/opaque_elements_attr.mlir b/tensorflow/compiler/mlir/xla/tests/translate/opaque_elements_attr.mlir
index dc79fdc..b3ef4c1 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/opaque_elements_attr.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/opaque_elements_attr.mlir
@@ -1,6 +1,6 @@
// RUN: not tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s 2>&1 | FileCheck %s
-// CHECK: Opaque elements attr not supported
+// CHECK: Only dense elements attr are supported
func.func @main() {
%0 = "mhlo.constant"() {value = opaque<"mhlo", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32>
func.return
diff --git a/tensorflow/core/ir/types/BUILD b/tensorflow/core/ir/types/BUILD
index b8959fc..d9fb158 100644
--- a/tensorflow/core/ir/types/BUILD
+++ b/tensorflow/core/ir/types/BUILD
@@ -18,6 +18,7 @@
],
includes = ["include"],
deps = [
+ "@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:CallInterfacesTdFiles",
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
@@ -104,6 +105,7 @@
":AttributesIncGen",
":DialectIncGen",
":TypesIncGen",
+ "@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
diff --git a/tensorflow/core/ir/types/attributes.td b/tensorflow/core/ir/types/attributes.td
index 50c66aa..ba3b314 100644
--- a/tensorflow/core/ir/types/attributes.td
+++ b/tensorflow/core/ir/types/attributes.td
@@ -19,6 +19,7 @@
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/SubElementInterfaces.td"
include "tensorflow/core/ir/types/dialect.td"
+include "mlir/IR/BuiltinAttributeInterfaces.td"
// Base class for TFType dialect attributes.
class TFType_Attr<string name, list<Trait> traits = []>
@@ -447,6 +448,7 @@
//===----------------------------------------------------------------------===//
// Tensorflow devices metadata
+//===----------------------------------------------------------------------===//
// Tensorflow GPU device metadata.
def TFType_GpuDeviceMetadata : TFType_Attr<"GpuDeviceMetadata"> {
@@ -456,5 +458,29 @@
let assemblyFormat = "`<` struct(params) `>`";
}
+//===----------------------------------------------------------------------===//
+// TensorProtoAttr
+//===----------------------------------------------------------------------===//
+
+def TF_TensorProtoAttr : TFType_Attr<"TensorProto", [ElementsAttrInterface]> {
+ let mnemonic = "tensor_proto";
+
+ let summary = "Attribute that stores TensorFlow TensorProto debug string";
+
+ let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type,
+ StringRefParameter<"">:$value);
+ let builders = [
+ AttrBuilderWithInferredContext<(ins "ShapedType":$type,
+ "StringRef":$value), [{
+ return $_get(type.getContext(), type, value);
+ }]>,
+ ];
+ let extraClassDeclaration = [{
+ using ValueType = StringRef;
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+}
+
#endif
diff --git a/tensorflow/core/ir/types/dialect.cc b/tensorflow/core/ir/types/dialect.cc
index 342d3d7..a8eb3b5 100644
--- a/tensorflow/core/ir/types/dialect.cc
+++ b/tensorflow/core/ir/types/dialect.cc
@@ -18,13 +18,16 @@
#include <cstdint>
#include <string>
+#include "absl/strings/escaping.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Traits.h" // from @llvm-project
+#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
@@ -34,6 +37,7 @@
#include "mlir/IR/FunctionImplementation.h" // from @llvm-project
#include "mlir/IR/FunctionInterfaces.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
+#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
@@ -854,5 +858,28 @@
Type DropRefAndSubTypes(Type ty) { return DropRefType(DropSubTypes(ty)); }
+Attribute TensorProtoAttr::parse(AsmParser &parser, Type type) {
+ if (parser.parseColon()) {
+ return nullptr;
+ }
+
+ std::string data;
+ if (parser.parseString(&data)) {
+ return nullptr;
+ }
+ if (data.size() < 2 || data.substr(0, 2) != "0x") {
+ parser.emitError(parser.getNameLoc(), "Hex string doesn't start with `0x`");
+ return nullptr;
+ }
+
+ std::string bytes_data = absl::HexStringToBytes(data.substr(2));
+ return TensorProtoAttr::get(type, bytes_data);
+}
+
+void TensorProtoAttr::print(mlir::AsmPrinter &printer) const {
+ StringRef bytes_str = getValue();
+ printer << " : \"0x" << llvm::toHex(bytes_str) << "\"";
+}
+
} // namespace tf_type
} // namespace mlir