[TFG][ConstantFolding] Fix SimplifyReshapeOp asserting on i64 shapes
Assert occurs on call to ElementsAttr::getValues<int32_t>.
PiperOrigin-RevId: 469153186
diff --git a/tensorflow/core/transforms/constant_folding/pass.cc b/tensorflow/core/transforms/constant_folding/pass.cc
index edf271c..3a78deb 100644
--- a/tensorflow/core/transforms/constant_folding/pass.cc
+++ b/tensorflow/core/transforms/constant_folding/pass.cc
@@ -2070,12 +2070,14 @@
if (!shape_op || !dialect_->IsConstant(shape_op)) return failure();
auto shape_attr = shape_op->getAttrOfType<ElementsAttr>("value");
- SmallVector<int32_t> new_shape(shape_attr.getValues<int32_t>());
+ // TODO(tlongeri): only reason for SmallVector instead of range directly is
+ // that llvm::zip implementation requires copy assignment (it shouldn't)
+ SmallVector<APInt> new_shape(shape_attr.getValues<APInt>());
if (input_shape.getRank() != new_shape.size()) return failure();
for (const auto &it : llvm::zip(input_shape.getShape(), new_shape)) {
- int32_t dim_0 = std::get<0>(it);
- int32_t dim_1 = std::get<1>(it);
+ int64_t dim_0 = std::get<0>(it);
+ int64_t dim_1 = std::get<1>(it).getSExtValue();
if (dim_0 >= 0 && dim_1 >= 0 && dim_0 != dim_1) return failure();
}
diff --git a/tensorflow/core/transforms/constant_folding/tests/no_op_reshape.mlir b/tensorflow/core/transforms/constant_folding/tests/no_op_reshape.mlir
index 3f82510..6cf0ad5 100644
--- a/tensorflow/core/transforms/constant_folding/tests/no_op_reshape.mlir
+++ b/tensorflow/core/transforms/constant_folding/tests/no_op_reshape.mlir
@@ -34,6 +34,12 @@
// CHECK: Reshape{{.*}} name("r2")
%Reshape_30, %ctl_31 = Reshape(%VariableV2_26, %Const_28) name("r2") {T = f32, Tshape = i32} : (tensor<17x1xf32>, tensor<1xi32>) -> (tensor<17xf32>)
%Square_32, %ctl_33 = Square(%Reshape_30) name("s2") {T = f32} : (tensor<17xf32>) -> (tensor<17xf32>)
+ // CHECK: %[[VAR_5:.*]], %[[CTRL_6:.*]] = VariableV2 name("v5")
+ %VariableV2_27, %ctl_34 = VariableV2 name("v5") {container = "", dtype = f32, shape = #tf_type.shape<4294967296>, shared_name = ""} : () -> (tensor<4294967296xf32>)
+ // CHECK: %[[CONST_5:.*]], %[[CTRL_7:.*]] = Const {{.*}} name("c5")
+ %Const_29, %ctl_35 = Const [%ctl_34] name("c5") {dtype = f32, value = dense<4294967296> : tensor<1xi64>} : () -> (tensor<1xi64>)
+ // CHECK: Identity(%[[VAR_5]]) [%[[CTRL_7]]] name("r5")
+ %Reshape_31, %ctl_36 = Reshape(%VariableV2_27, %Const_29) name("r5") {T = f32, Tshape = i64} : (tensor<4294967296xf32>, tensor<1xi64>) -> (tensor<4294967296xf32>)
return
}
}