// RUN: tf-opt -tf-rewrite-tpu-embedding-ops %s | FileCheck %s
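
// Tests that tf.RecvTPUEmbeddingActivations is rewritten into
// tf._RecvTPUEmbeddingDeduplicationData followed by a
// tf._RecvTPUEmbeddingActivations op that consumes the deduplication data.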
// CHECK-LABEL: func @recv_tpu_embedding_activations
func @recv_tpu_embedding_activations() -> (tensor<512x256xf32>) {
  // CHECK: %[[DATA:.*]] = "tf._RecvTPUEmbeddingDeduplicationData"() {config = {{.*}}} : () -> tensor<!tf.variant>
  // CHECK: %[[RESULT:.*]] = "tf._RecvTPUEmbeddingActivations"(%[[DATA]]) {config = {{.*}}} : (tensor<!tf.variant>) -> tensor<512x256xf32>
  // CHECK: return %[[RESULT]]
  // CHECK-NOT: tf.RecvTPUEmbeddingActivations
  // CHECK-NOT: tf.SendTPUEmbeddingGradients
  %0 = "tf.RecvTPUEmbeddingActivations"() {config = "\0A%\0A\0Dwatches_table\10\F5\03\18\80\02 \01*\0C\1A\00j\05\0D\00\00\80?\88\01\01\10\02\18\80\04 \01(\02"} : () -> tensor<512x256xf32>
  return %0 : tensor<512x256xf32>
}
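
// Tests that tf.SendTPUEmbeddingGradients is rewritten into
// tf._SendTPUEmbeddingGradients taking the deduplication data as an
// additional operand.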
// CHECK-LABEL: func @send_tpu_embedding_gradients
func @send_tpu_embedding_gradients(%arg0: tensor<512x256xf32>) -> () {
  // CHECK: %[[DATA:.*]] = "tf._RecvTPUEmbeddingDeduplicationData"() {config = {{.*}}} : () -> tensor<!tf.variant>
  // CHECK: "tf._SendTPUEmbeddingGradients"(%arg0, %[[DATA]]) {config = {{.*}}, operand_segment_sizes = dense<[1, 0, 1]> : vector<3xi32>} : (tensor<512x256xf32>, tensor<!tf.variant>) -> ()
  // CHECK-NOT: tf.SendTPUEmbeddingGradients
  // CHECK-NOT: tf.RecvTPUEmbeddingActivations
  "tf.SendTPUEmbeddingGradients"(%arg0) {config = "\0A%\0A\0Dwatches_table\10\F5\03\18\80\02 \01*\0C\1A\00j\05\0D\00\00\80?\88\01\01\10\02\18\80\04 \01(\02", operand_segment_sizes = dense<[1, 0]> : vector<2xi32>} : (tensor<512x256xf32>) -> ()
  return
}
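
// Tests rewriting when both receive and send embedding ops appear in the same
// function; the rewritten ops consume the same deduplication data value.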
// CHECK-LABEL: func @recv_send_ops
func @recv_send_ops() -> () {
  // CHECK: %[[DATA:.*]] = "tf._RecvTPUEmbeddingDeduplicationData"()
  // CHECK: %[[ACTIVATIONS:.*]] = "tf._RecvTPUEmbeddingActivations"(%[[DATA]])
  // CHECK: "tf._SendTPUEmbeddingGradients"(%[[ACTIVATIONS]], %[[DATA]])
  %0 = "tf.RecvTPUEmbeddingActivations"() {config = "\0A%\0A\0Dwatches_table\10\F5\03\18\80\02 \01*\0C\1A\00j\05\0D\00\00\80?\88\01\01\10\02\18\80\04 \01(\02"} : () -> tensor<512x256xf32>
  "tf.SendTPUEmbeddingGradients"(%0) {config = "\0A%\0A\0Dwatches_table\10\F5\03\18\80\02 \01*\0C\1A\00j\05\0D\00\00\80?\88\01\01\10\02\18\80\04 \01(\02", operand_segment_sizes = dense<[1, 0]> : vector<2xi32>} : (tensor<512x256xf32>) -> ()
  return
}
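
// Tests that functions without TPU embedding ops are left unchanged.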
// CHECK-LABEL: func @no_embedding_ops
func @no_embedding_ops(%arg0: tensor<2x2xf32>) -> (tensor<2x2xf32>) {
  // CHECK: tf.Add
  %0 = "tf.Add"(%arg0, %arg0) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}
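
// Tests that an embedding op nested inside a tf.IfRegion is rewritten.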
// CHECK-LABEL: func @nested_embedding_op
func @nested_embedding_op(%arg0: tensor<i1>, %arg1: tensor<512x256xf32>) -> (tensor<512x256xf32>) {
  %1 = "tf.IfRegion"(%arg0) ({
    // CHECK: "tf._RecvTPUEmbeddingDeduplicationData"
    // CHECK: "tf._RecvTPUEmbeddingActivations"
    // CHECK-NOT: tf.RecvTPUEmbeddingActivations
    %0 = "tf.RecvTPUEmbeddingActivations"() {config = "\0A%\0A\0Dwatches_table\10\F5\03\18\80\02 \01*\0C\1A\00j\05\0D\00\00\80?\88\01\01\10\02\18\80\04 \01(\02"} : () -> tensor<512x256xf32>
    "tf.Yield"(%0) : (tensor<512x256xf32>) -> ()
  }, {
    "tf.Yield"(%arg1) : (tensor<512x256xf32>) -> ()
  }) {is_stateless = true} : (tensor<i1>) -> tensor<512x256xf32>
  return %1 : tensor<512x256xf32>
}
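
// Tests that an embedding op nested inside two levels of tf.IfRegion ops is
// rewritten.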
// CHECK-LABEL: func @doubly_nested_embedding_op
func @doubly_nested_embedding_op(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<512x256xf32>) -> (tensor<512x256xf32>) {
  %2 = "tf.IfRegion"(%arg0) ({
    %1 = "tf.IfRegion"(%arg1) ({
      // CHECK: "tf._RecvTPUEmbeddingDeduplicationData"
      // CHECK: "tf._RecvTPUEmbeddingActivations"
      // CHECK-NOT: tf.RecvTPUEmbeddingActivations
      %0 = "tf.RecvTPUEmbeddingActivations"() {config = "\0A%\0A\0Dwatches_table\10\F5\03\18\80\02 \01*\0C\1A\00j\05\0D\00\00\80?\88\01\01\10\02\18\80\04 \01(\02"} : () -> tensor<512x256xf32>
      "tf.Yield"(%0) : (tensor<512x256xf32>) -> ()
    }, {
      "tf.Yield"(%arg2) : (tensor<512x256xf32>) -> ()
    }) {is_stateless = true} : (tensor<i1>) -> tensor<512x256xf32>
    "tf.Yield"(%1) : (tensor<512x256xf32>) -> ()
  }, {
    "tf.Yield"(%arg2) : (tensor<512x256xf32>) -> ()
  }) {is_stateless = true} : (tensor<i1>) -> tensor<512x256xf32>
  return %2 : tensor<512x256xf32>
}