blob: 64f26af12a87d4dfc6515c037194e9d48941d547 [file] [log] [blame]
// Copyright 2022 The TensorFlow Runtime Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// RUN: lhlo-tfrt-opt %s -gpu-binary-to-jitrt | FileCheck %s
// Check that all gpu dialect `launch_func` operations get converted to
// function calls bound to jitrt custom calls.
module attributes {gpu.container_module} {

  // Device side: a single kernel module carrying an opaque binary string.
  // No `gpu.module` may be printed ahead of `@func` in the pass output; the
  // binary string is matched again further down as the `ptx` attribute.
  // CHECK-NOT: gpu.module
  gpu.module @gpu_module attributes {binary = "kernel binary"} {
    gpu.func @main(%buf0: memref<4x4xf32>, %buf1: memref<4x4xf32>) kernel {
      gpu.return
    }
  }

  // Host side: capture the two memref arguments so they can be matched as the
  // trailing operands of the launch call.
  // CHECK: @func(
  // CHECK:   %[[BUF0:.*]]: memref<4x4xf32>,
  // CHECK:   %[[BUF1:.*]]: memref<4x4xf32>
  // CHECK: )
  func.func @func(%buf0: memref<4x4xf32>, %buf1: memref<4x4xf32>) {

    // Six distinct index constants (1..6) so that every launch dimension can
    // be told apart in the output.
    // CHECK: %[[GRID_X:.*]] = arith.constant 1 : index
    // CHECK: %[[GRID_Y:.*]] = arith.constant 2 : index
    // CHECK: %[[GRID_Z:.*]] = arith.constant 3 : index
    // CHECK: %[[BLOCK_X:.*]] = arith.constant 4 : index
    // CHECK: %[[BLOCK_Y:.*]] = arith.constant 5 : index
    // CHECK: %[[BLOCK_Z:.*]] = arith.constant 6 : index
    %grid_x = arith.constant 1 : index
    %grid_y = arith.constant 2 : index
    %grid_z = arith.constant 3 : index
    %block_x = arith.constant 4 : index
    %block_y = arith.constant 5 : index
    %block_z = arith.constant 6 : index

    // Each dimension is expected to be cast from `index` to `i32`, in the same
    // order as the constants above.
    // CHECK: %[[GX:.*]] = arith.index_cast %[[GRID_X]] : index to i32
    // CHECK: %[[GY:.*]] = arith.index_cast %[[GRID_Y]] : index to i32
    // CHECK: %[[GZ:.*]] = arith.index_cast %[[GRID_Z]] : index to i32
    // CHECK: %[[BX:.*]] = arith.index_cast %[[BLOCK_X]] : index to i32
    // CHECK: %[[BY:.*]] = arith.index_cast %[[BLOCK_Y]] : index to i32
    // CHECK: %[[BZ:.*]] = arith.index_cast %[[BLOCK_Z]] : index to i32

    // The launch becomes a plain call: block counts first, then thread counts,
    // then the original kernel arguments. The callee symbol is captured and
    // matched against the private declaration at the end of the module. The
    // kernel name and the binary string are looked for as attributes.
    // CHECK: call @[[CALLEE:[_a-z]+]](%[[GX]], %[[GY]], %[[GZ]], %[[BX]],
    // CHECK-SAME: %[[BY]], %[[BZ]], %[[BUF0]], %[[BUF1]])
    // CHECK-DAG: kernel = "main"
    // CHECK-DAG: ptx = "kernel binary"
    gpu.launch_func @gpu_module::@main
        blocks in (%grid_x, %grid_y, %grid_z)
        threads in (%block_x, %block_y, %block_z)
        args(%buf0 : memref<4x4xf32>, %buf1 : memref<4x4xf32>)

    func.return
  }

  // The callee captured above must be declared as a private function taking
  // six i32 values plus the two memrefs, tagged as a direct custom call.
  // CHECK: func private @[[CALLEE]](i32, i32, i32, i32, i32, i32,
  // CHECK-SAME: memref<4x4xf32>, memref<4x4xf32>)
  // CHECK-SAME: attributes {rt.direct_custom_call = "xla.gpu.func.launch"}
}