| // RUN: tf-opt -xla-legalize-tf %s | FileCheck %s |
| |
| //===----------------------------------------------------------------------===// |
| // BatchNorm op legalizations. |
| //===----------------------------------------------------------------------===// |
| |
| // Inference-mode tf.FusedBatchNorm (is_training = false, NHWC data format) is |
| // expected to be replaced by xla_hlo.batch_norm_inference with |
| // feature_index = 3 and the same epsilon. Only result #0 of the five TF results |
| // is returned. The label is anchored on "func @" like the other tests in this |
| // file so it can only match the function header. |
| // CHECK-LABEL: func @fusedBatchNorm_notraining |
| func @fusedBatchNorm_notraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { |
| // CHECK-NEXT: "xla_hlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> |
| %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) |
| return %0#0 : tensor<8x8x8x8xf32> |
| } |
| |
| // Training-mode tf.FusedBatchNorm (is_training = true) is not legalized yet, so |
| // the op is expected to survive the pass unchanged. The label is anchored on |
| // "func @" like the other tests in this file. |
| // CHECK-LABEL: func @fusedBatchNorm_training |
| func @fusedBatchNorm_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { |
| // TODO(riverriddle) Support training. |
| // CHECK-NEXT: "tf.FusedBatchNorm" |
| %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) |
| return %0#0 : tensor<8x8x8x8xf32> |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Bias op legalizations. |
| //===----------------------------------------------------------------------===// |
| |
| // tf.BiasAdd with data_format = "NHWC": the bias (32 elements) matches the size |
| // of dimension 3 of the input, and the expected xla_hlo.add broadcasts along |
| // that dimension (broadcast_dimensions = dense<3>). |
| // The T attribute is DT_INT32 so that it agrees with the i32 tensor types. |
| // CHECK-LABEL: func @biasAdd_NHWC |
| func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { |
| // CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} |
| %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_INT32", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> |
| return %0 : tensor<1x32x10x32xi32> |
| } |
| |
| // tf.BiasAdd with data_format = "NCHW": the bias (32 elements) matches the size |
| // of dimension 1 of the input, and the expected xla_hlo.add broadcasts along |
| // that dimension (broadcast_dimensions = dense<1>). |
| // The T attribute is DT_INT32 so that it agrees with the i32 tensor types. |
| // CHECK-LABEL: func @biasAdd_NCHW |
| func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { |
| // CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} |
| %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_INT32", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> |
| return %0 : tensor<1x32x10x32xi32> |
| } |
| |
| // In the next two tests, the replacement fails because the bias dimension does |
| // not have the same size as the feature dimension. |
| |
| // Negative test for NHWC: dimension 3 of the input has size 2 while the bias has |
| // 32 elements, so no xla_hlo.add may be produced. |
| // The T attribute is DT_INT32 so that it agrees with the i32 tensor types and |
| // the size mismatch is the only irregularity in the input. |
| // CHECK-LABEL: func @biasAdd_NHWC_invalid |
| func @biasAdd_NHWC_invalid(%arg0: tensor<1x32x10x2xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x2xi32> { |
| // CHECK-NOT: xla_hlo.add |
| %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_INT32", data_format = "NHWC"} : (tensor<1x32x10x2xi32>, tensor<32xi32>) -> tensor<1x32x10x2xi32> |
| return %0 : tensor<1x32x10x2xi32> |
| } |
| |
| // Negative test for NCHW: dimension 1 of the input has size 10 while the bias |
| // has 32 elements, so no xla_hlo.add may be produced. |
| // The T attribute is DT_INT32 so that it agrees with the i32 tensor types and |
| // the size mismatch is the only irregularity in the input. |
| // CHECK-LABEL: func @biasAdd_NCHW_invalid |
| func @biasAdd_NCHW_invalid(%arg0: tensor<1x10x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x10x10x32xi32> { |
| // CHECK-NOT: xla_hlo.add |
| %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_INT32", data_format = "NCHW"} : (tensor<1x10x10x32xi32>, tensor<32xi32>) -> tensor<1x10x10x32xi32> |
| return %0 : tensor<1x10x10x32xi32> |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Binary op legalizations. |
| //===----------------------------------------------------------------------===// |
| |
| // tf.Add on two operands of the same static shape: expected to become |
| // xla_hlo.add printed in its short form (no broadcast_dimensions attribute), |
| // immediately followed by the return of that value. |
| // CHECK-LABEL: func @add |
| func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { |
| // CHECK-NEXT: %0 = xla_hlo.add %arg0, %arg0 : tensor<2xi32> |
| // CHECK-NEXT: return %0 : tensor<2xi32> |
| %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> |
| return %0: tensor<2xi32> |
| } |
| |
| // tf.Add with a rank-1 lhs and a rank-2 rhs: the expected xla_hlo.add carries |
| // broadcast_dimensions = dense<1>, i.e. it names dimension 1, the last |
| // dimension of the rank-2 operand. |
| // CHECK-LABEL: func @broadcast_add |
| func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { |
| // CHECK-NEXT: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} |
| %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> |
| return %0: tensor<1x2xi32> |
| } |
| |
| // tf.Add with a rank-3 lhs and a rank-4 rhs: the expected xla_hlo.add carries |
| // broadcast_dimensions = [1, 2, 3], i.e. one entry per lhs dimension, naming |
| // the three trailing dimensions of the rank-4 operand. |
| // CHECK-LABEL: func @broadcast_multi_dim_add |
| func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { |
| // CHECK-NEXT: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} |
| %0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> |
| return %0: tensor<4x4x4x4xi32> |
| } |
| |
| // tf.Div on two operands of the same static shape: expected to become |
| // xla_hlo.div printed in its short form (no broadcast_dimensions attribute), |
| // immediately followed by the return of that value. |
| // CHECK-LABEL: func @div |
| func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { |
| // CHECK-NEXT: %0 = xla_hlo.div %arg0, %arg0 : tensor<2xi32> |
| // CHECK-NEXT: return %0 : tensor<2xi32> |
| %0 = "tf.Div"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> |
| return %0: tensor<2xi32> |
| } |
| |
| // tf.Div with a rank-1 lhs and a rank-2 rhs: the expected xla_hlo.div carries |
| // broadcast_dimensions = dense<1>, i.e. it names dimension 1, the last |
| // dimension of the rank-2 operand. |
| // CHECK-LABEL: func @broadcast_div |
| func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { |
| // CHECK-NEXT: "xla_hlo.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} |
| %0 = "tf.Div"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> |
| return %0: tensor<1x2xi32> |
| } |
| |
| // tf.Mul on two operands of the same static shape: expected to become |
| // xla_hlo.mul printed in its short form (no broadcast_dimensions attribute), |
| // immediately followed by the return of that value. |
| // CHECK-LABEL: func @mul |
| func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { |
| // CHECK-NEXT: %0 = xla_hlo.mul %arg0, %arg0 : tensor<2xi32> |
| // CHECK-NEXT: return %0 : tensor<2xi32> |
| %0 = "tf.Mul"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> |
| return %0: tensor<2xi32> |
| } |
| |
| // tf.Mul with a rank-1 lhs and a rank-2 rhs: the expected xla_hlo.mul carries |
| // broadcast_dimensions = dense<1>, i.e. it names dimension 1, the last |
| // dimension of the rank-2 operand. |
| // CHECK-LABEL: func @broadcast_mul |
| func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { |
| // CHECK-NEXT: "xla_hlo.mul"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} |
| %0 = "tf.Mul"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> |
| return %0: tensor<1x2xi32> |
| } |
| |
| // tf.RealDiv on two operands of the same static shape: expected to become |
| // xla_hlo.div printed in its short form (no broadcast_dimensions attribute), |
| // immediately followed by the return of that value. The return is verified as |
| // well, matching the add/div/mul/sub tests in this section. |
| // CHECK-LABEL: func @real_div |
| func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { |
| // CHECK-NEXT: %0 = xla_hlo.div %arg0, %arg0 : tensor<2xi32> |
| // CHECK-NEXT: return %0 : tensor<2xi32> |
| %0 = "tf.RealDiv"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> |
| return %0: tensor<2xi32> |
| } |
| |
| // tf.RealDiv with a rank-1 lhs and a rank-2 rhs: the expected xla_hlo.div |
| // carries broadcast_dimensions = dense<1>, i.e. it names dimension 1, the last |
| // dimension of the rank-2 operand. |
| // CHECK-LABEL: func @broadcast_real_div |
| func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { |
| // CHECK-NEXT: "xla_hlo.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} |
| %0 = "tf.RealDiv"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> |
| return %0: tensor<1x2xi32> |
| } |
| |
| // tf.Sub on two operands of the same static shape: expected to become |
| // xla_hlo.sub printed in its short form (no broadcast_dimensions attribute), |
| // immediately followed by the return of that value. |
| // CHECK-LABEL: func @sub |
| func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { |
| // CHECK-NEXT: %0 = xla_hlo.sub %arg0, %arg0 : tensor<2xi32> |
| // CHECK-NEXT: return %0 : tensor<2xi32> |
| %0 = "tf.Sub"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> |
| return %0: tensor<2xi32> |
| } |
| |
| // tf.Sub with a rank-1 lhs and a rank-2 rhs: the expected xla_hlo.sub carries |
| // broadcast_dimensions = dense<1>, i.e. it names dimension 1, the last |
| // dimension of the rank-2 operand. |
| // CHECK-LABEL: func @broadcast_sub |
| func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { |
| // CHECK-NEXT: "xla_hlo.sub"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} |
| %0 = "tf.Sub"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> |
| return %0: tensor<1x2xi32> |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Identity op legalizations. |
| //===----------------------------------------------------------------------===// |
| |
| // tf.Identity is expected to disappear entirely: the first line after the |
| // function header must already be the return of the unmodified argument. |
| // CHECK-LABEL: func @identity |
| func @identity(%arg0: tensor<1xi32>) -> tensor<1xi32> { |
| // CHECK-NEXT: return %arg0 : tensor<1xi32> |
| %0 = "tf.Identity"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> |
| return %0: tensor<1xi32> |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Nullary op legalizations. |
| //===----------------------------------------------------------------------===// |
| |
| // A tf.Const holding a splat i32 value is expected to end up as a standard |
| // constant op with the same value and type. The label is anchored on "func @" |
| // like the other tests in this file. |
| // CHECK-LABEL: func @const |
| func @const() -> tensor<2xi32> { |
| // tf.Const is legalized into xla_hlo.constant, which is then folded into a |
| // standard constant op. |
| |
| // CHECK-NEXT: constant dense<0> : tensor<2xi32> |
| %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor<2xi32>} : () -> (tensor<2xi32>) |
| return %0: tensor<2xi32> |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Relu op legalizations. |
| //===----------------------------------------------------------------------===// |
| |
| // tf.Relu is expected to lower to xla_hlo.max of the operand and a zero |
| // constant of the operand's type; the constant is materialized first. |
| // CHECK-LABEL: func @relu |
| func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { |
| // CHECK-NEXT: %cst = constant dense<0> : tensor<1xi32> |
| // CHECK-NEXT: %0 = xla_hlo.max %arg0, %cst : tensor<1xi32> |
| %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> |
| return %0: tensor<1xi32> |
| } |
| |
| // tf.Relu6 is expected to lower to xla_hlo.clamp with two materialized |
| // constants, 0 and 6. In the expected operand list the zero constant comes |
| // first, then the input, then the six constant. |
| // CHECK-LABEL: func @relu6 |
| func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { |
| // CHECK-NEXT: %cst = constant dense<0> : tensor<1xi32> |
| // CHECK-NEXT: %cst_0 = constant dense<6> : tensor<1xi32> |
| // CHECK-NEXT: %0 = "xla_hlo.clamp"(%cst, %arg0, %cst_0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> |
| %0 = "tf.Relu6"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> |
| return %0: tensor<1xi32> |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Unary op legalizations. |
| //===----------------------------------------------------------------------===// |
| |
| // tf.Reshape with a statically shaped result is expected to become |
| // xla_hlo.reshape taking only the data operand; the shape operand (%arg1) does |
| // not appear in the expected op. |
| // The result type is 2x1 so that it holds the same number of elements (2) as |
| // the operand, and the label is anchored on "func @" so that it cannot match |
| // the "xla_hlo.reshape" op name. |
| // CHECK-LABEL: func @reshape |
| func @reshape(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<2x1xf32> { |
| // CHECK: %0 = "xla_hlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<2x1xf32> |
| %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<2x1xf32> |
| return %0 : tensor<2x1xf32> |
| } |
| |
| // tf.Reshape whose operand is unranked and whose result has dynamic dimensions |
| // is expected to be left untouched by the pass. The label is anchored on |
| // "func @" like the other tests in this file. |
| // CHECK-LABEL: func @reshape_dynamic |
| func @reshape_dynamic(%arg0: tensor<*xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> { |
| // CHECK: %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<*xf32>, tensor<2xi32>) -> tensor<?x?xf32> |
| %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<*xf32>, tensor<2xi32>) -> tensor<?x?xf32> |
| return %0 : tensor<?x?xf32> |
| } |
| |
| // tf.Squeeze with static operand and result types is expected to become |
| // xla_hlo.reshape between those two types. The label is anchored on "func @" |
| // like the other tests in this file, so it can only match the function header. |
| // CHECK-LABEL: func @squeeze |
| func @squeeze(%arg0: tensor<1x1x10xf32>) -> tensor<1x10xf32> { |
| // CHECK-NEXT: %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32> |
| %0 = "tf.Squeeze"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32> |
| return %0 : tensor<1x10xf32> |
| } |
| |
| // tf.Squeeze with a dynamically shaped operand and an unranked result is |
| // expected to be left untouched by the pass. The label is anchored on "func @" |
| // like the other tests in this file. |
| // CHECK-LABEL: func @squeeze_dynamic |
| func @squeeze_dynamic(%arg0: tensor<?x10xf32>) -> tensor<*xf32> { |
| // CHECK-NEXT: %0 = "tf.Squeeze"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32> |
| %0 = "tf.Squeeze"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32> |
| return %0 : tensor<*xf32> |
| } |