[XLA][MLIR] Add DynamicMemRefCastOp to LHLO with lowering to LLVM.

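Adds the `xla_lhlo.dynamic_memref_cast` op, which changes the sizes and
strides of a memref using values computed at runtime, together with a
verifier, lit tests, and a lowering to the LLVM dialect. For example
(taken from the new lhlo_ops.mlir test):

  %out = xla_lhlo.dynamic_memref_cast %in(%size)[%step]
           : memref<?xf32> -> memref<?xf32, offset: 0, strides: [?]>

The verifier checks that the number of `sizes` operands equals the rank
of the result memref; the SameVariadicOperandSize trait guarantees that
`sizes` and `strides` have the same length. DynamicMemRefCastOpConverter
lowers the op by building a new memref descriptor: the allocated pointer,
aligned pointer, and offset are copied from the source descriptor, and
the size and stride fields are filled from the op's operands.
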
PiperOrigin-RevId: 314207315
Change-Id: Idb80930464e9f1b87cbafccba9cde86e28afc093
diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc
index 9e5fa4f..6f9b393 100644
--- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc
@@ -69,6 +69,23 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// DynamicMemRefCastOp
+//===----------------------------------------------------------------------===//
+
+// Implements ViewLikeOpInterface: the result aliases the `operand` buffer.
+Value DynamicMemRefCastOp::getViewSource() {
+  return *getODSOperands(0).begin();
+}
+
+static LogicalResult Verify(DynamicMemRefCastOp op) {
+  // Check if `sizes` and `strides` args are compatible with the result type.
+  if (op.sizes().size() != op.getType().getRank())
+    return op.emitOpError(
+        "`sizes` args count must be equal to the rank of the output memref");
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc.inc"
 
diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
index f3b5d2d..d9f3648 100644
--- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
@@ -305,6 +305,55 @@
 }
 
 //===----------------------------------------------------------------------===//
+// DynamicMemRefCastOp
+//===----------------------------------------------------------------------===//
+
+def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
+    [SameVariadicOperandSize, NoSideEffect,
+     DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
+  let summary = "dynamic memref cast operation";
+  let description = [{
+    Changes the sizes and strides of a memref using values computed at runtime.
+
+    Example:
+    ```mlir
+    %buf_transformed =
+        xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
+        : memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
+    // The result of the op is a memref with the fully dynamic shape
+    // `[%size_X, %size_Y]` and strides `[%step_X, %step_Y]`. The offset
+    // is inherited from the input.
+    ```
+  }];
+
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", []>:$operand,
+    Variadic<Index>:$sizes,
+    Variadic<Index>:$strides
+  );
+  let results = (outs Res<LHLO_Buffer, "", []>:$result);
+
+  let builders = [OpBuilder<
+    "OpBuilder &builder, OperationState &result, MemRefType resultType, " #
+    "Value operand, ValueRange sizes, ValueRange strides", [{
+       result.addOperands(operand);
+       result.addOperands(sizes);
+       result.addOperands(strides);
+       result.types.push_back(resultType);
+     }]>];
+
+  let extraClassDeclaration = [{
+    MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+  }];
+
+  let verifier = [{ return Verify(*this); }];
+  let assemblyFormat = [{
+    $operand `(` $sizes `)` `[` $strides `]` attr-dict `:` type($operand) `->`
+    type($result)
+  }];
+}
+
+//===----------------------------------------------------------------------===//
 // XLA Other op definitions.
 //===----------------------------------------------------------------------===//
 
diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-llvm.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-llvm.mlir
index 0202f39..16aad8f 100644
--- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-llvm.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-llvm.mlir
@@ -29,3 +29,37 @@
 // CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C5_]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE_2D]]
 // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
 // CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE_2D]]
+
+// -----
+
+// CHECK-LABEL: func @dynamic_memref_cast
+func @dynamic_memref_cast(%buf : memref<?x?xf32>) {
+  %size_X = constant 10 : index
+  %size_Y = constant 50 : index
+  %stride_X = constant 1 : index
+  %stride_Y = constant 0 : index
+  %0 = xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y]
+        : memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
+  return
+}
+// CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : index) : !llvm.i64
+// CHECK: %[[C50:.*]] = llvm.mlir.constant(50 : index) : !llvm.i64
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+
+// CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE:!.*]]
+
+// CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE]]
+// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm<"float*"> to !llvm<"float*">
+// CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE]]
+
+// CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE]]
+// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm<"float*"> to !llvm<"float*">
+// CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE]]
+
+// CHECK: %[[SRC_OFFSET:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][2] : [[DESCRIPTOR_TYPE]]
+// CHECK: %[[MEMREF_BLDR_3:.*]] = llvm.insertvalue %[[SRC_OFFSET]], %[[MEMREF_BLDR_2]][2] : [[DESCRIPTOR_TYPE]]
+// CHECK: %[[MEMREF_BLDR_4:.*]] = llvm.insertvalue %[[C10]], %[[MEMREF_BLDR_3]][3, 0] : [[DESCRIPTOR_TYPE]]
+// CHECK: %[[MEMREF_BLDR_5:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_4]][4, 0] : [[DESCRIPTOR_TYPE]]
+// CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C50]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE]]
+// CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C0]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE]]
diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir
index cdae187..1a44428 100644
--- a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir
@@ -226,3 +226,25 @@
            : memref<10x1xf32> -> memref<10x?xf32, offset: 0, strides: [?, ?]>
   return
 }
+
+// -----
+
+func @dynamic_memref_cast(%in: memref<?xf32>) {
+  %size = constant 10 : index
+  %step = constant 1 : index
+  %out = xla_lhlo.dynamic_memref_cast %in(%size)[%step]
+           : memref<?xf32> -> memref<?xf32, offset: 0, strides: [?]>
+  return
+}
+// CHECK-LABEL: func @dynamic_memref_cast
+
+// -----
+
+func @dynamic_memref_cast_incompatible_result_type(%in: memref<?xf32>) {
+  // expected-error @+3 {{`sizes` args count must be equal to the rank of the output memref}}
+  %size = constant 10 : index
+  %step = constant 1 : index
+  %out = xla_lhlo.dynamic_memref_cast %in(%size)[%step]
+           : memref<?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
+  return
+}
diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm.cc
index 083365c..385e085 100644
--- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm.cc
@@ -76,11 +76,60 @@
   }
 };
 
+struct DynamicMemRefCastOpConverter
+    : public ConvertOpToLLVMPattern<DynamicMemRefCastOp> {
+  using ConvertOpToLLVMPattern<DynamicMemRefCastOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult matchAndRewrite(
+      Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto cast_op = cast<DynamicMemRefCastOp>(op);
+
+    DynamicMemRefCastOpOperandAdaptor operands_adaptor(operands);
+    MemRefDescriptor sourceMemRef(operands_adaptor.operand());
+
+    MemRefType targetMemRefType =
+        cast_op.getResult().getType().cast<MemRefType>();
+    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
+                                      .dyn_cast_or_null<LLVM::LLVMType>();
+    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
+      return failure();
+    // Create descriptor.
+    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
+    Type llvmTargetElementTy = desc.getElementType();
+    // Set allocated ptr.
+    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
+    allocated =
+        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
+    desc.setAllocatedPtr(rewriter, loc, allocated);
+    // Set aligned ptr.
+    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
+    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
+    desc.setAlignedPtr(rewriter, loc, ptr);
+    // Copy offset of `sourceMemRef`.
+    desc.setOffset(rewriter, loc, sourceMemRef.offset(rewriter, loc));
+
+    // Fill size and stride descriptors in memref.
+    if (!cast_op.sizes().empty()) {
+      auto sizes = operands_adaptor.sizes();
+      auto strides = operands_adaptor.strides();
+      for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
+        desc.setSize(rewriter, loc, i, sizes[i]);
+        desc.setStride(rewriter, loc, i, strides[i]);
+      }
+    }
+    rewriter.replaceOp(op, {desc});
+    return success();
+  }
+};
+
 }  // namespace
 
 void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter,
                                           OwningRewritePatternList *patterns) {
-  patterns->insert<StaticMemRefCastOpConverter>(*converter);
+  patterns->insert<DynamicMemRefCastOpConverter, StaticMemRefCastOpConverter>(
+      *converter);
 }
 
 }  // namespace xla_lhlo