| // RUN: kernel-gen-opt %s -tf-kernel-to-llvm -split-input-file | FileCheck %s |
| |
| // Lowering of tf_framework.alloc_raw for memref<?x10x?xf32>. The pass is |
| // expected to declare the C-interface allocator, compute the byte size from |
| // the two dynamic sizes and the static size 10, call the allocator with the |
| // op kernel context, and return a fully populated memref descriptor. |
| |
| // The allocator is declared as taking (context, byte count) and returning a |
| // raw byte pointer. |
| // CHECK: llvm.func @_mlir_ciface_tf_alloc_raw |
| // CHECK-SAME: (!llvm.ptr<i8>, !llvm.i64) -> !llvm.ptr<i8> |
| |
| // The op kernel context lowers to !llvm.ptr<i8> and each `index` argument to |
| // !llvm.i64; only the two dynamic dimensions (0 and 2) are passed in. The |
| // result type captured here is reused below for the descriptor. |
| // CHECK-LABEL: llvm.func @alloc_raw( |
| // CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr<i8>, |
| // CHECK-SAME: [[SIZE_0:%.*]]: !llvm.i64, |
| // CHECK-SAME: [[SIZE_2:%.*]]: !llvm.i64) -> [[DESC_TY:!.*]] { |
| func @alloc_raw(%ctx: !tf_framework.op_kernel_context, |
| %size_0 : index , %size_2 : index) -> memref<?x10x?xf32> { |
| // One size operand per dynamic dimension, in dimension order. |
| %buf = tf_framework.alloc_raw(%ctx, %size_0, %size_2) : memref<?x10x?xf32> |
| std.return %buf : memref<?x10x?xf32> |
| } |
| // Number of elements: size_0 * 10 * size_2, with the static size 10 |
| // materialized as a constant. |
| // CHECK: [[SIZE_1:%.*]] = llvm.mlir.constant(10 : index) : !llvm.i64 |
| // CHECK: [[NUM_ELEM_0:%.*]] = llvm.mul [[SIZE_0]], [[SIZE_1]] : !llvm.i64 |
| // CHECK: [[NUM_ELEM_1:%.*]] = llvm.mul [[NUM_ELEM_0]], [[SIZE_2]] : !llvm.i64 |
| |
| // Size of one element in bytes: the address of element 1 past a null |
| // !llvm.ptr<float>, converted to an integer. |
| // CHECK: [[NULL:%.*]] = llvm.mlir.null : !llvm.ptr<float> |
| // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 |
| // CHECK: [[GEP:%.*]] = llvm.getelementptr [[NULL]]{{\[}}[[C1]]] |
| // CHECK-SAME: (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float> |
| // CHECK: [[SIZE_OF_FLOAT:%.*]] = llvm.ptrtoint [[GEP]] |
| // CHECK-SAME: !llvm.ptr<float> to !llvm.i64 |
| |
| // Allocate (number of elements * element size) bytes, passing the op kernel |
| // context first. The callee name is matched loosely here; the allocator |
| // declaration itself is pinned at the top of this test case. |
| // CHECK: [[NUM_BYTES:%.*]] = llvm.mul [[NUM_ELEM_1]], [[SIZE_OF_FLOAT]] |
| // CHECK: [[BYTES_PTR:%.*]] = llvm.call @{{.*}}([[TF_CTX]], [[NUM_BYTES]]) |
| // CHECK-SAME: (!llvm.ptr<i8>, !llvm.i64) -> !llvm.ptr<i8> |
| |
| // Build the memref descriptor, starting from an undef value of the |
| // function's result type. |
| // CHECK: [[DESC_0:%.*]] = llvm.mlir.undef : [[DESC_TY]] |
| |
| // Pointer fields [0] and [1] both receive the allocated buffer bitcast to |
| // !llvm.ptr<float>; the offset field [2] is set to 0. |
| // CHECK: [[FLOAT_PTR:%.*]] = llvm.bitcast [[BYTES_PTR]] |
| // CHECK-SAME: !llvm.ptr<i8> to !llvm.ptr<float> |
| // CHECK: [[DESC_1:%.*]] = llvm.insertvalue [[FLOAT_PTR]], [[DESC_0]][0] |
| // CHECK: [[DESC_2:%.*]] = llvm.insertvalue [[FLOAT_PTR]], [[DESC_1]][1] |
| // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 |
| // CHECK: [[DESC_3:%.*]] = llvm.insertvalue [[C0]], [[DESC_2]][2] : [[DESC_TY]] |
| |
| // Sizes go into field [3, i] and strides into field [4, i], innermost |
| // dimension first. Strides describe a contiguous row-major layout: |
| // stride_2 = 1, stride_1 = stride_2 * size_2, stride_0 = stride_1 * 10. |
| // CHECK: [[STRIDE_2:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 |
| // CHECK: [[DESC_4:%.*]] = llvm.insertvalue [[SIZE_2]], [[DESC_3]][3, 2] |
| // CHECK: [[DESC_5:%.*]] = llvm.insertvalue [[STRIDE_2]], [[DESC_4]][4, 2] |
| // CHECK: [[STRIDE_1:%.*]] = llvm.mul [[STRIDE_2]], [[SIZE_2]] : !llvm.i64 |
| // CHECK: [[DESC_6:%.*]] = llvm.insertvalue [[SIZE_1]], [[DESC_5]][3, 1] |
| // CHECK: [[DESC_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[DESC_6]][4, 1] |
| // CHECK: [[STRIDE_0:%.*]] = llvm.mul [[STRIDE_1]], [[SIZE_1]] : !llvm.i64 |
| // CHECK: [[DESC_8:%.*]] = llvm.insertvalue [[SIZE_0]], [[DESC_7]][3, 0] |
| // CHECK: [[DESC_9:%.*]] = llvm.insertvalue [[STRIDE_0]], [[DESC_8]][4, 0] |
| // The fully populated descriptor is the function result. |
| // CHECK: llvm.return [[DESC_9]] : [[DESC_TY]] |
| |
| // ----- |
| |
| // Lowering of tf_framework.dealloc_raw for memref<?x10xf32>. The pass is |
| // expected to declare the C-interface deallocator, read the allocated |
| // pointer out of the memref descriptor, and pass it together with the op |
| // kernel context to the deallocator. |
| |
| // The deallocator takes (context, buffer), both as !llvm.ptr<i8>. This must |
| // agree with the callee type pinned on the call at the end of this test case, |
| // since a call's operand types have to match its callee's declaration. |
| // CHECK: llvm.func @_mlir_ciface_tf_dealloc_raw(!llvm.ptr<i8>, !llvm.ptr<i8>) |
| |
| // The op kernel context lowers to !llvm.ptr<i8>; it is captured so the |
| // deallocation call can be verified to forward it. |
| // CHECK-LABEL: llvm.func @dealloc_raw( |
| // CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr<i8>, |
| func @dealloc_raw(%ctx: !tf_framework.op_kernel_context, |
| %memref : memref<?x10xf32>) { |
| tf_framework.dealloc_raw(%ctx, %memref) : memref<?x10xf32> |
| return |
| } |
| // A descriptor value is built up starting from llvm.mlir.undef (presumably |
| // from the expanded memref arguments - not pinned here). Field [0], the |
| // allocated pointer, is then extracted and immediately bitcast to a raw |
| // byte pointer. |
| // CHECK: %{{.*}} = llvm.mlir.undef : [[DESC_TY:!.*]] |
| // CHECK: [[FLOAT_PTR:%.*]] = llvm.extractvalue %{{.*}}[0] : [[DESC_TY]] |
| // CHECK-NEXT: [[VOID_PTR:%.*]] = llvm.bitcast [[FLOAT_PTR]] |
| // CHECK-SAME: !llvm.ptr<float> to !llvm.ptr<i8> |
| |
| // Deallocate: the context comes first, then the raw buffer pointer. |
| // CHECK: llvm.call @_mlir_ciface_tf_dealloc_raw( |
| // CHECK-SAME: [[TF_CTX]], [[VOID_PTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> () |