Bug fix when the functions with Assign/Read doesn't have bound_input arguments.
PiperOrigin-RevId: 363100822
Change-Id: Ibc35bb355af3120f654855d2cf808e3aa72fe4d4
diff --git a/tensorflow/compiler/mlir/lite/tests/initialize_variables.mlir b/tensorflow/compiler/mlir/lite/tests/initialize_variables.mlir
index f6e36f1..b84b00d 100644
--- a/tensorflow/compiler/mlir/lite/tests/initialize_variables.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/initialize_variables.mlir
@@ -55,3 +55,40 @@
// CHECK: "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32>
}
}
+
+// -----
+
+// Test for func with no bound_input.
+module attributes {tf_saved_model.semantics} {
+ "tf_saved_model.global_tensor"() {is_mutable, sym_name = "Variable", type = tensor<1x10xf32>, value = dense<0.000000e+00> : tensor<1x10xf32>} : () -> ()
+ func @serving_default(%arg0: tensor<1x10xf32> {tf_saved_model.index_path = ["x"]}, %arg1: tensor<!tf.resource<tensor<1x10xf32>>>{tf_saved_model.bound_input = @Variable}) ->
+ (tensor<1x10xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_x:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+ %0 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32>
+ %1 = tfl.add %0, %arg0 {fused_activation_function = "NONE"} : tensor<1x10xf32>
+ "tf.AssignVariableOp"(%arg1, %1) : (tensor<!tf.resource<tensor<1x10xf32>>>, tensor<1x10xf32>) -> ()
+ %2 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32>
+ return %2 : tensor<1x10xf32>
+ }
+
+ func private @"FuncWithNoBoundInput"(%arg0: tensor<1x10xf32>, %arg1: tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32> {
+ "tf.AssignVariableOp"(%arg1, %arg0) {device = ""} : (tensor<!tf.resource<tensor<1x10xf32>>>, tensor<1x10xf32>) -> ()
+ %0 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32>
+ return %0 : tensor<1x10xf32>
+ }
+
+ // CHECK: func @SessionInitializerFunction() attributes {tf_saved_model.exported_names = ["SessionInitializerFunction"]} {
+ // CHECK: %[[RESOURCE:.*]] = "tfl.pseudo_const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ // CHECK: %[[VAL:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
+ // CHECK: "tfl.assign_variable"(%[[RESOURCE]], %[[VAL]]) : (tensor<1xi32>, tensor<1x10xf32>) -> ()
+ // CHECK: return
+ // CHECK: }
+ // CHECK: "tf_saved_model.session_initializer"() {initializers = [@SessionInitializerFunction]} : () -> ()
+ // CHECK: func @serving_default(%arg0: tensor<1x10xf32> {tf_saved_model.index_path = ["x"]}
+ // CHECK: "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32>
+ //
+ // CHECK: func private @FuncWithNoBoundInput(%arg0: tensor<1x10xf32>, %arg1: tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32> {
+ // CHECK: "tf.AssignVariableOp"(%arg1, %arg0) {device = ""} : (tensor<!tf.resource<tensor<1x10xf32>>>, tensor<1x10xf32>) -> ()
+ // CHECK: %0 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32>
+ // CHECK: return %0 : tensor<1x10xf32>
+ // CHECK: }
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/initialize_variables.cc b/tensorflow/compiler/mlir/lite/transforms/initialize_variables.cc
index 76f4308..271193b 100644
--- a/tensorflow/compiler/mlir/lite/transforms/initialize_variables.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/initialize_variables.cc
@@ -123,7 +123,11 @@
// with ops that accepts resource as input.
if (!llvm::isa<TF::ReadVariableOp, TF::AssignVariableOp>(op))
return WalkResult::advance();
- tensors_to_initialize.insert(GetGlobalTensorOp(op, symbol_table, func));
+ auto global_tensor = GetGlobalTensorOp(op, symbol_table, func);
+ // In case the function doesn't have bound_input to a resource
+ // then we return nullptr.
+ // We need only to initialize the variables that are bounded.
+ if (global_tensor) tensors_to_initialize.insert(global_tensor);
return WalkResult::advance();
});
}
@@ -154,9 +158,9 @@
void runOnOperation() override {
auto module = getOperation();
- // Use ordered container to make sure ids are deterministic if we got tensor
- // ids from different part, since we have different passes that touches
- // variables.
+ // Use ordered container to make sure ids are deterministic if we got
+ // tensor ids from different part, since we have different passes that
+ // touches variables.
// TODO(b/149099381): Remove integer IDs after adding the new variable
// handle type.
std::map<std::string, int> global_tensor_id;