[tfg] Allow ShapeAttr to print/parse negative dimensions
ShapeAttr can contain dimensions less than -1, but when printed, are printed as `?` and parsed back as -1. This causes a loss of information on MLIR textual roundtrip and makes it difficult to test certain cases.
PiperOrigin-RevId: 442243340
diff --git a/tensorflow/core/ir/importexport/tests/graphdef_to_mlir/negative_shape.pbtxt b/tensorflow/core/ir/importexport/tests/graphdef_to_mlir/negative_shape.pbtxt
index 9db4b82..62a5cc9 100644
--- a/tensorflow/core/ir/importexport/tests/graphdef_to_mlir/negative_shape.pbtxt
+++ b/tensorflow/core/ir/importexport/tests/graphdef_to_mlir/negative_shape.pbtxt
@@ -5,7 +5,7 @@
# Such shapes don't really make sense in TensorFlow, but grappler is using this
# during some "symbolic shape analysis".
# CHECK: _Arg name("test_model")
-# CHECK-SAME: #tf_type.shape<?>
+# CHECK-SAME: #tf_type.shape<-5>
node {
name: "test_model"
op: "_Arg"
diff --git a/tensorflow/core/ir/types/dialect.cc b/tensorflow/core/ir/types/dialect.cc
index c222312..3c13273 100644
--- a/tensorflow/core/ir/types/dialect.cc
+++ b/tensorflow/core/ir/types/dialect.cc
@@ -388,7 +388,7 @@
os << "<";
if (hasRank()) {
auto print_dim = [&](int64_t dim) {
- if (dim > -1)
+ if (dim != -1)
os << dim;
else
os << "?";
@@ -420,10 +420,9 @@
llvm::SMLoc loc = parser.getCurrentLocation();
if (succeeded(parser.parseOptionalQuestion())) {
shape.back() = ShapedType::kDynamicSize;
- } else if (failed(parser.parseInteger(shape.back())) ||
- shape.back() < 0) {
- parser.emitError(loc) << "expected a positive integer or `?` when "
- "parsing a tf.shape attribute";
+ } else if (failed(parser.parseInteger(shape.back()))) {
+ parser.emitError(loc)
+ << "expected an integer or `?` when parsing a tf.shape attribute";
return failure();
}
return success();
diff --git a/tensorflow/core/transforms/consolidate_attrs/tests/negative_shape.mlir b/tensorflow/core/transforms/consolidate_attrs/tests/negative_shape.mlir
new file mode 100644
index 0000000..b7a5488
--- /dev/null
+++ b/tensorflow/core/transforms/consolidate_attrs/tests/negative_shape.mlir
@@ -0,0 +1,10 @@
+// RUN: tfg-transforms-opt %s --tfg-consolidate-attrs | FileCheck %s
+
+// CHECK-LABEL: tfg.graph
+tfg.graph #tf_type.version<producer = 1, min_consumer = 1> {
+ // CHECK: A {tfg.regenerate_output_shapes}
+ // CHECK-SAME: tensor<?xi32>, tensor<4xi32>, tensor<?xi32>
+ %A:3, %ctl = A {
+ _output_shapes = [#tf_type.shape<-4>, #tf_type.shape<4>, #tf_type.shape<?>]
+ } : () -> (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>)
+}