// blob: 4d8545604a235b0298de634fc904986f8d12f734 [file] [log] [blame]
// RUN: tf-opt %s -split-input-file -tf-tpu-variable-runtime-reformatting| FileCheck %s
// Tests that the pass can correctly transform a training loop with 2 replicas.
//
// Input: @main takes two resource variables, each with one handle per replica
// (on /device:TPU:0 and /device:TPU:1). The body of a tf.WhileRegion compiles
// the TPU program in a CPU launch, passes the first compile result to
// tf.TPUCompileSucceededAssert, and runs tf.TPUExecuteAndUpdateVariables inside
// a 2-way tf_device.replicate.
//
// Expected output (per the FileCheck directives below): two tf.VarHandleOp
// ops, one per TPU device, appear ahead of the loop and are passed to the
// replicate op as an additional replicated operand; a launch containing
// tf.TPUReshardVariables precedes the execute op; and after the loop a second
// replicate op holds a tf.TPUReshardVariables that takes a tf.Const operand.
//
// NOTE(review): in this copy the `{` opened by the module, the function and
// the replicate op are never closed within the visible text of this test
// case - confirm against the upstream test before running.
!tf_res_f32 = tensor<*x!tf_type.resource<tensor<f32>>>
!tf_res_md_f32 = tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>> // Multi-dim f32
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: tensor<*x!tf_type.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
// CHECK-SAME: %[[ARG1:.*]]: tensor<*x!tf_type.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
// CHECK-SAME: %[[ARG2:.*]]: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
// CHECK-SAME: %[[ARG3:.*]]: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
func.func @main(%arg0: !tf_res_f32 {tf.device = "/device:TPU:0"},
%arg1: !tf_res_f32 {tf.device = "/device:TPU:1"},
%arg2: !tf_res_md_f32 {tf.device = "/device:TPU:0"},
%arg3: !tf_res_md_f32 {tf.device = "/device:TPU:1"}) {
// %0 is the initial loop operand and is also captured by both regions below.
%0 = "tf.Const"() {value = dense<100> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[STATE0:.*]] = "tf.VarHandleOp"()
// CHECK-SAME: device = "/device:TPU:0"
// CHECK: %[[STATE1:.*]] = "tf.VarHandleOp"()
// CHECK-SAME: device = "/device:TPU:1"
// CHECK: %[[WHILE:.*]] = "tf.WhileRegion"(
%1 = "tf.WhileRegion"(%0) ({
// Condition region: yields the result of tf.GreaterEqual(%carg0, %0).
// CHECK: ^bb
// CHECK: "tf.Yield"
^bb0(%carg0: tensor<i32>):
// %c0 is not referenced by any other op in this region.
%c0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%c1 = "tf.GreaterEqual"(%carg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tf.Yield"(%c1) : (tensor<i1>) -> ()
}, {
// Body region: yields tf.AddV2(%barg0, %0) after running the replicated step.
// CHECK: ^bb0
^bb0(%barg0: tensor<i32>):
// %b0 is not referenced by any other op in this region.
%b0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%b1 = "tf.AddV2"(%barg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"()
// Compilation runs in a CPU launch; result #0 feeds the assert below and
// result #1 feeds the execute op.
%compile:2 = "tf_device.launch"() ({
%b2:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes 2 parameters and 2 return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
tf_device.return %b2#0, %b2#1 : tensor<!tf_type.string>, tensor<2x!tf_type.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
"tf_device.launch"() ({
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf_type.string>) -> ()
}) {device = "/device:CPU:0"} : () -> ()
// The expected replicate op keeps both original replicated operands and gains
// a third one built from the two tf.VarHandleOp results matched above.
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG0]], %[[ARG1]]] as %[[R0:.*]]: tensor<*x!tf_type.resource<tensor<f32>>>,
// CHECK-SAME: [%[[ARG2]], %[[ARG3]]] as %[[R1:.*]]: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>,
// CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[R_STATE:.*]]: tensor<!tf_type.resource<tensor<2x!tf_type.string>>>
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
%rep:2 = tf_device.replicate([%arg0, %arg1] as %arg30: tensor<*x!tf_type.resource<tensor<f32>>>,
[%arg2, %arg3] as %arg31: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>)
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
// The f32 variable reaches the execute op through a tf.Identity; the expected
// reshard op consumes that same identity result.
// CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]])
%id = "tf.Identity"(%arg30) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<*x!tf_type.resource<tensor<f32>>>
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]])
// CHECK-NEXT: tf_device.return
// CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1)
"tf_device.launch"() ({
"tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
: (tensor<*x!tf_type.resource<tensor<f32>>>, tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf_type.string>) -> ()
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
tf_device.return %ret : tensor<i32>
// CHECK: "tf.Yield"
"tf.Yield"(%b1) : (tensor<i32>) -> ()
}) {device = "", is_stateless = false} : (tensor<i32>) -> (tensor<i32>)
// After the loop, the expected output has a second replicate op over the same
// state handles whose reshard op takes a tf.Const in place of the compile
// result.
// CHECK: %[[DEFAULT:.*]] = "tf.Const"()
// CHECK: tf_device.replicate
// CHECK-SAME: as %[[V0:.*]]: tensor<*x!tf_type.resource<tensor<f32>>>,
// CHECK-SAME: as %[[V1:.*]]: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>,
// CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[STATE:.*]]: tensor<!tf_type.resource<tensor<2x!tf_type.string>>>
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUReshardVariables"(%[[V0]], %[[V1]], %[[DEFAULT]], %[[STATE]])
// CHECK-NEXT: tf_device.return
// -----
// Tests that the pass does not format variables with other uses.
//
// Each of the three replicated operands of the replicate op below is
// disqualified for a different reason (see the comment inside the replicate
// op), so the output is expected to contain no tf.TPUReshardVariables.
//
// NOTE(review): in this copy the `{` opened by the module, the function and
// the replicate op are never closed within the visible text of this test
// case - confirm against the upstream test before running.
!tf_res_f32 = tensor<*x!tf_type.resource<tensor<f32>>>
!tf_res_md_f32 = tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>> // Multi-dim f32
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
// CHECK-LABEL: func @main
// CHECK-NOT: TPUReshardVariables
func.func @main(%arg0: !tf_res_f32 {tf.device = "/device:TPU:0"},
%arg1: !tf_res_f32 {tf.device = "/device:TPU:1"},
%arg2: !tf_res_md_f32 {tf.device = "/device:TPU:0"},
%arg3: !tf_res_md_f32 {tf.device = "/device:TPU:1"},
%arg4: !tf_res_f32 {tf.device = "/device:TPU:1"}) {
// %0 is the initial loop operand and is also captured by both regions below.
%0 = "tf.Const"() {value = dense<100> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.WhileRegion"(%0) ({
// Condition region
^bb0(%carg0: tensor<i32>):
// %c0 is not referenced by any other op in this region.
%c0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%c1 = "tf.GreaterEqual"(%carg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
// Extra use of %arg1 (one of the replica inputs bound to %arg30) inside the
// condition region.
"tf._UnknownOp1_"(%arg1) : (!tf_res_f32) -> ()
"tf.Yield"(%c1) : (tensor<i1>) -> ()
}, {
// Body region
^bb0(%barg0: tensor<i32>):
// %b0 is not referenced by any other op in this region.
%b0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%b1 = "tf.AddV2"(%barg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%compile:2 = "tf_device.launch"() ({
%b2:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes 3 parameters and 3 return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
tf_device.return %b2#0, %b2#1 : tensor<!tf_type.string>, tensor<2x!tf_type.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
"tf_device.launch"() ({
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf_type.string>) -> ()
}) {device = "/device:CPU:0"} : () -> ()
// Extra use of %arg3 (one of the replica inputs bound to %arg31) outside the
// replicate op, via a tf.Identity.
%id0 = "tf.Identity"(%arg3) : (!tf_res_md_f32) -> !tf_res_md_f32
"tf._Unknown_"(%id0) : (!tf_res_md_f32) -> ()
// %newvar is produced by an op in the loop body instead of being a function
// argument; it is the first replica input bound to %arg32.
%newvar = "tf._SomeOp"() : () -> !tf_res_f32
%rep:2 = tf_device.replicate([%arg0, %arg1] as %arg30: !tf_res_f32,
[%arg2, %arg3] as %arg31: !tf_res_md_f32,
[%newvar, %arg4] as %arg32 : !tf_res_f32)
{_mirrored_variable_indices = [0, 1, 2], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
// %arg30 is used in the condition region (through %arg1), %arg31 has other
// uses (%id0), and %arg32 is not a pass-through (%newvar).
"tf_device.launch"() ({
"tf.TPUExecuteAndUpdateVariables"(%arg30, %arg31, %arg32, %compile#1)
{device_var_reads_indices = [0, 1, 2], device_var_updates_indices = [0, 1, 2]}
: (!tf_res_f32, !tf_res_md_f32, !tf_res_f32, tensor<2x!tf_type.string>) -> ()
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
tf_device.return %ret : tensor<i32>
"tf.Yield"(%b1) : (tensor<i32>) -> ()
}) {device = "", is_stateless = false} : (tensor<i32>) -> (tensor<i32>)
// -----
// Tests that the pass does not format variables when model parallelism is
// present.
//
// The input matches the 2-replica training loop except that the launch holding
// tf.TPUExecuteAndUpdateVariables is nested in a tf_device.parallel_execute
// op; the output is expected to contain no tf.TPUReshardVariables.
//
// NOTE(review): in this copy the `{` opened by the module, the function and
// the replicate op are never closed within the visible text of this test
// case - confirm against the upstream test before running.
!tf_res_f32 = tensor<*x!tf_type.resource<tensor<f32>>>
!tf_res_md_f32 = tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>> // Multi-dim f32
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
// CHECK-LABEL: func @main
// CHECK-NOT: TPUReshardVariables
func.func @main(%arg0: !tf_res_f32 {tf.device = "/device:TPU:0"},
%arg1: !tf_res_f32 {tf.device = "/device:TPU:1"},
%arg2: !tf_res_md_f32 {tf.device = "/device:TPU:0"},
%arg3: !tf_res_md_f32 {tf.device = "/device:TPU:1"}) {
// %0 is the initial loop operand and is also captured by both regions below.
%0 = "tf.Const"() {value = dense<100> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.WhileRegion"(%0) ({
// Condition region
^bb0(%carg0: tensor<i32>):
// %c0 is not referenced by any other op in this region.
%c0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%c1 = "tf.GreaterEqual"(%carg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tf.Yield"(%c1) : (tensor<i1>) -> ()
}, {
// Body region
^bb0(%barg0: tensor<i32>):
// %b0 is not referenced by any other op in this region.
%b0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%b1 = "tf.AddV2"(%barg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%compile:2 = "tf_device.launch"() ({
%b2:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes 2 parameters and 2 return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
tf_device.return %b2#0, %b2#1 : tensor<!tf_type.string>, tensor<2x!tf_type.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
"tf_device.launch"() ({
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf_type.string>) -> ()
}) {device = "/device:CPU:0"} : () -> ()
%rep:2 = tf_device.replicate([%arg0, %arg1] as %arg30: tensor<*x!tf_type.resource<tensor<f32>>>,
[%arg2, %arg3] as %arg31: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>)
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
%id = "tf.Identity"(%arg30) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<*x!tf_type.resource<tensor<f32>>>
// The execute launch is the first region of a tf_device.parallel_execute op;
// the second region is empty.
"tf_device.parallel_execute"() ({
"tf_device.launch"() ({
"tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
: (tensor<*x!tf_type.resource<tensor<f32>>>, tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf_type.string>) -> ()
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
}, {
}) {} : () -> ()
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
tf_device.return %ret : tensor<i32>
"tf.Yield"(%b1) : (tensor<i32>) -> ()
}) {device = "", is_stateless = false} : (tensor<i32>) -> (tensor<i32>)
// -----
// Tests that the pass can correctly transform a training loop with a packed
// variable.
//
// Input: the f32 variable has one handle per replica (%arg0, %arg1), while the
// multi-dim variable is a single handle (%arg2) on /device:COMPOSITE:0 that is
// passed to the replicate op unbracketed and listed in _packed_input_indices.
//
// Expected output (per the FileCheck directives below): two tf.VarHandleOp
// ops, one per TPU device, appear ahead of the loop; their results are passed
// to the replicate op as a replicated operand listed after the f32 variable
// and before the packed operand; a launch containing tf.TPUReshardVariables
// precedes the execute op; and after the loop a second replicate op holds a
// tf.TPUReshardVariables that takes a tf.Const operand.
//
// NOTE(review): in this copy the `{` opened by the module, the function and
// the replicate op are never closed within the visible text of this test
// case - confirm against the upstream test before running.
!tf_res_f32 = tensor<*x!tf_type.resource<tensor<f32>>>
!tf_res_md_f32 = tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>> // Multi-dim f32
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: tensor<*x!tf_type.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
// CHECK-SAME: %[[ARG1:.*]]: tensor<*x!tf_type.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
// CHECK-SAME: %[[ARG2:.*]]: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:COMPOSITE:0"})
func.func @main(%arg0: !tf_res_f32 {tf.device = "/device:TPU:0"},
%arg1: !tf_res_f32 {tf.device = "/device:TPU:1"},
%arg2: !tf_res_md_f32 {tf.device = "/device:COMPOSITE:0"}) {
// %0 is the initial loop operand and is also captured by both regions below.
%0 = "tf.Const"() {value = dense<100> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[STATE0:.*]] = "tf.VarHandleOp"()
// CHECK-SAME: device = "/device:TPU:0"
// CHECK: %[[STATE1:.*]] = "tf.VarHandleOp"()
// CHECK-SAME: device = "/device:TPU:1"
// CHECK: %[[WHILE:.*]] = "tf.WhileRegion"(
%1 = "tf.WhileRegion"(%0) ({
// Condition region: yields the result of tf.GreaterEqual(%carg0, %0).
// CHECK: ^bb
// CHECK: "tf.Yield"
^bb0(%carg0: tensor<i32>):
// %c0 is not referenced by any other op in this region.
%c0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%c1 = "tf.GreaterEqual"(%carg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tf.Yield"(%c1) : (tensor<i1>) -> ()
}, {
// Body region: yields tf.AddV2(%barg0, %0) after running the replicated step.
// CHECK: ^bb0
^bb0(%barg0: tensor<i32>):
// %b0 is not referenced by any other op in this region.
%b0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%b1 = "tf.AddV2"(%barg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"()
// Compilation runs in a CPU launch; result #0 feeds the assert below and
// result #1 feeds the execute op.
%compile:2 = "tf_device.launch"() ({
%b2:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes 2 parameters and 2 return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
tf_device.return %b2#0, %b2#1 : tensor<!tf_type.string>, tensor<2x!tf_type.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
"tf_device.launch"() ({
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf_type.string>) -> ()
}) {device = "/device:CPU:0"} : () -> ()
// In the expected replicate op the state handles form a replicated operand
// placed between the replicated f32 variable and the packed operand.
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG0]], %[[ARG1]]] as %[[R0:.*]]: tensor<*x!tf_type.resource<tensor<f32>>>,
// CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[R_STATE:.*]]: tensor<!tf_type.resource<tensor<2x!tf_type.string>>>
// CHECK-SAME: %[[ARG2]] as %[[R1:.*]]: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
%rep:2 = tf_device.replicate([%arg0, %arg1] as %arg30: tensor<*x!tf_type.resource<tensor<f32>>>,
%arg2 as %arg31: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>)
{_mirrored_variable_indices = [0, 1], _packed_input_indices = [1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
// The f32 variable reaches the execute op through a tf.Identity; the expected
// reshard op consumes that same identity result.
// CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]])
%id = "tf.Identity"(%arg30) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<*x!tf_type.resource<tensor<f32>>>
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]])
// CHECK-NEXT: tf_device.return
// CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1)
"tf_device.launch"() ({
"tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
: (tensor<*x!tf_type.resource<tensor<f32>>>, tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf_type.string>) -> ()
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
tf_device.return %ret : tensor<i32>
// CHECK: "tf.Yield"
"tf.Yield"(%b1) : (tensor<i32>) -> ()
}) {device = "", is_stateless = false} : (tensor<i32>) -> (tensor<i32>)
// After the loop, the expected output has a second replicate op over the same
// operands whose reshard op takes a tf.Const in place of the compile result.
// CHECK: %[[DEFAULT:.*]] = "tf.Const"()
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG0]], %[[ARG1]]] as %[[V0:.*]]: tensor<*x!tf_type.resource<tensor<f32>>>,
// CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[STATE:.*]]: tensor<!tf_type.resource<tensor<2x!tf_type.string>>>
// CHECK-SAME: %[[ARG2]] as %[[V1:.*]]: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUReshardVariables"(%[[V0]], %[[V1]], %[[DEFAULT]], %[[STATE]])
// CHECK-NEXT: tf_device.return