[XLA][MLIR] Add DynamicMemRefCastOp to LHLO with lowering to LLVM.
PiperOrigin-RevId: 314207315
Change-Id: Idb80930464e9f1b87cbafccba9cde86e28afc093
diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc
index 9e5fa4f..6f9b393 100644
--- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc
@@ -69,6 +69,23 @@
return success();
}
+//===----------------------------------------------------------------------===//
+// DynamicMemRefCastOp
+//===----------------------------------------------------------------------===//
+
+Value DynamicMemRefCastOp::getViewSource() {
+ return *getODSOperands(0).begin();
+}
+
+static LogicalResult Verify(DynamicMemRefCastOp op) {
+ // `sizes` and `strides` have equal length (enforced by the
+ // SameVariadicOperandSize trait); check that count against the result rank.
+ if (op.sizes().size() != op.getType().getRank())
+ return op.emitOpError(
+ "`sizes` args count must be equal to the rank of the output memref");
+ return success();
+}
+
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc.inc"
diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
index f3b5d2d..d9f3648 100644
--- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
@@ -305,6 +305,55 @@
}
//===----------------------------------------------------------------------===//
+// DynamicMemRefCastOp
+//===----------------------------------------------------------------------===//
+
+def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
+ [SameVariadicOperandSize, NoSideEffect,
+ DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
+ let summary = "dynamic memref cast operation";
+ let description = [{
+ Changes the sizes and strides of a memref using values computed at runtime.
+
+ Example:
+ ```mlir
+ %buf_transformed =
+ xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
+ : memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
+ // The result of the op is a memref with `[%size_X, %size_Y]` shape and
+ // `[%step_X, %step_Y]` strides. The offset will be inherited from the
+ // input.
+ ```
+ }];
+
+ let arguments = (ins
+ Arg<LHLO_Buffer, "", []>:$operand,
+ Variadic<Index>:$sizes,
+ Variadic<Index>:$strides
+ );
+ let results = (outs Res<LHLO_Buffer, "", []>:$result);
+
+ let builders = [OpBuilder<
+ "OpBuilder &builder, OperationState &result, MemRefType resultType, " #
+ "Value operand, ValueRange sizes, ValueRange strides", [{
+ result.addOperands(operand);
+ result.addOperands(sizes);
+ result.addOperands(strides);
+ result.types.push_back(resultType);
+ }]>];
+
+ let extraClassDeclaration = [{
+ MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+ }];
+
+ let verifier = [{ return Verify(*this); }];
+ let assemblyFormat = [{
+ $operand `(` $sizes `)` `[` $strides `]` attr-dict `:` type($operand) `->`
+ type($result)
+ }];
+}
+
+//===----------------------------------------------------------------------===//
// XLA Other op definitions.
//===----------------------------------------------------------------------===//
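With the custom builder declared above, constructing the op from C++ is a single `create` call. A hedged sketch (all SSA values, `b`, and `loc` are assumed to be in scope; the strided layout helper is standard MLIR of this era):

```cpp
// Build memref<?x?xf32, offset: 0, strides: [?, ?]> as the result type.
auto dynSize = ShapedType::kDynamicSize;
auto dynStride = ShapedType::kDynamicStrideOrOffset;
MemRefType resultType = MemRefType::get(
    {dynSize, dynSize}, b.getF32Type(),
    makeStridedLinearLayoutMap({dynStride, dynStride}, /*offset=*/0,
                               b.getContext()));

// Matches the builder signature: (resultType, operand, sizes, strides).
Value cast = b.create<xla_lhlo::DynamicMemRefCastOp>(
    loc, resultType, buf, ValueRange{sizeX, sizeY},
    ValueRange{strideX, strideY});
```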
diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-llvm.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-llvm.mlir
index 0202f39..16aad8f 100644
--- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-llvm.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-llvm.mlir
@@ -29,3 +29,37 @@
// CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C5_]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE_2D]]
+
+// -----
+
+// CHECK-LABEL: func @dynamic_memref_cast
+func @dynamic_memref_cast(%buf : memref<?x?xf32>) {
+ %size_X = constant 10 : index
+ %size_Y = constant 50 : index
+ %stride_X = constant 1 : index
+ %stride_Y = constant 0 : index
+ %0 = xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y]
+ : memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
+ return
+}
+// CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : index) : !llvm.i64
+// CHECK: %[[C50:.*]] = llvm.mlir.constant(50 : index) : !llvm.i64
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+
+// CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE:!.*]]
+
+// CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE]]
+// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm<"float*"> to !llvm<"float*">
+// CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE]]
+
+// CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE]]
+// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm<"float*"> to !llvm<"float*">
+// CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE]]
+
+// CHECK: %[[SRC_OFFSET:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][2] : [[DESCRIPTOR_TYPE]]
+// CHECK: %[[MEMREF_BLDR_3:.*]] = llvm.insertvalue %[[SRC_OFFSET]], %[[MEMREF_BLDR_2]][2] : [[DESCRIPTOR_TYPE]]
+// CHECK: %[[MEMREF_BLDR_4:.*]] = llvm.insertvalue %[[C10]], %[[MEMREF_BLDR_3]][3, 0] : [[DESCRIPTOR_TYPE]]
+// CHECK: %[[MEMREF_BLDR_5:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_4]][4, 0] : [[DESCRIPTOR_TYPE]]
+// CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C50]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE]]
+// CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C0]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE]]
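To make the `insertvalue` indices in the CHECK lines easier to read: a lowered 2-D memref descriptor is a struct whose fields [0] through [4] are the two pointers, the offset, and the size/stride arrays. A purely illustrative C++ mirror (the real struct exists only in LLVM IR):

```cpp
#include <cstdint>

// Field comments map to the insertvalue indices checked above.
struct MemRefDescriptor2D {
  float *allocated;    // [0] allocated pointer, bitcast from the source
  float *aligned;      // [1] aligned pointer, bitcast from the source
  int64_t offset;      // [2] offset, copied from the source descriptor
  int64_t sizes[2];    // [3] sizes, set to (%size_X, %size_Y) = (10, 50)
  int64_t strides[2];  // [4] strides, set to (%stride_X, %stride_Y) = (1, 0)
};
```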
diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir
index cdae187..1a44428 100644
--- a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir
@@ -226,3 +226,25 @@
: memref<10x1xf32> -> memref<10x?xf32, offset: 0, strides: [?, ?]>
return
}
+
+// -----
+
+func @dynamic_memref_cast(%in: memref<?xf32>) {
+ %size = constant 10 : index
+ %step = constant 1 : index
+ %out = xla_lhlo.dynamic_memref_cast %in(%size)[%step]
+ : memref<?xf32> -> memref<?xf32, offset: 0, strides: [?]>
+ return
+}
+// CHECK-LABEL: func @dynamic_memref_cast
+
+// -----
+
+func @dynamic_memref_cast_incompatible_result_type(%in: memref<?xf32>) {
+ // expected-error @+3 {{`sizes` args count must be equal to the rank of the output memref}}
+ %size = constant 10 : index
+ %step = constant 1 : index
+ %out = xla_lhlo.dynamic_memref_cast %in(%size)[%step]
+ : memref<?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
+ return
+}
diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm.cc
index 083365c..385e085 100644
--- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm.cc
@@ -76,11 +76,60 @@
}
};
+struct DynamicMemRefCastOpConverter
+ : public ConvertOpToLLVMPattern<DynamicMemRefCastOp> {
+ using ConvertOpToLLVMPattern<DynamicMemRefCastOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult matchAndRewrite(
+ Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto cast_op = cast<DynamicMemRefCastOp>(op);
+
+ DynamicMemRefCastOpOperandAdaptor operands_adaptor(operands);
+ MemRefDescriptor sourceMemRef(operands_adaptor.operand());
+
+ MemRefType targetMemRefType =
+ cast_op.getResult().getType().cast<MemRefType>();
+ auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
+ .dyn_cast_or_null<LLVM::LLVMType>();
+ if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
+ return failure();
+ // Create descriptor.
+ auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
+ Type llvmTargetElementTy = desc.getElementType();
+ // Set allocated ptr.
+ Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
+ allocated =
+ rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
+ desc.setAllocatedPtr(rewriter, loc, allocated);
+ // Set aligned ptr.
+ Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
+ ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
+ desc.setAlignedPtr(rewriter, loc, ptr);
+ // Copy offset of `sourceMemRef`.
+ desc.setOffset(rewriter, loc, sourceMemRef.offset(rewriter, loc));
+
+ // Fill in the size and stride fields of the target descriptor.
+ if (!cast_op.sizes().empty()) {
+ auto sizes = operands_adaptor.sizes();
+ auto strides = operands_adaptor.strides();
+ for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
+ desc.setSize(rewriter, loc, i, sizes[i]);
+ desc.setStride(rewriter, loc, i, strides[i]);
+ }
+ }
+ rewriter.replaceOp(op, {desc});
+ return success();
+ }
+};
+
} // namespace
void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter,
OwningRewritePatternList *patterns) {
- patterns->insert<StaticMemRefCastOpConverter>(*converter);
+ patterns->insert<DynamicMemRefCastOpConverter, StaticMemRefCastOpConverter>(
+ *converter);
}
} // namespace xla_lhlo
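For context, a rough sketch of how these patterns are typically driven. The pass body below is an assumption based on the MLIR conversion APIs of this period, not code from this change:

```cpp
// Hypothetical driver: lower an LHLO module to the LLVM dialect.
void runLhloToLLVMConversion(ModuleOp module) {
  MLIRContext *ctx = module.getContext();
  LLVMTypeConverter typeConverter(ctx);

  OwningRewritePatternList patterns;
  xla_lhlo::PopulateLhloToLLVMConversionPatterns(&typeConverter, &patterns);

  // LLVMConversionTarget marks the LLVM dialect legal; anything else must
  // be rewritten by one of the registered patterns.
  LLVMConversionTarget target(*ctx);
  if (failed(applyPartialConversion(module, target, patterns)))
    module.emitError("LHLO to LLVM lowering failed");
}
```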