Add tfrt_fallback_async.executeop.seq.allocator for using a custom allocator on side-effecting TF ops.

This CL adds both the kernel ODS definition and the kernel implementation.

PiperOrigin-RevId: 404136975
Change-Id: I91b14ce05c6fec582da4418b38f8da5180d689d7
diff --git a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.cc b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.cc
index f252833..36739dd 100644
--- a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.cc
+++ b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.cc
@@ -846,17 +846,18 @@
   tensorflow::Allocator* allocator_ = nullptr;
 };
 
-void FallbackAsyncExecuteOpWithAllocator(tfrt::AsyncKernelFrame* frame) {
-  FallbackKernelAttributeFrame attr_frame(frame);
-
-  const auto& exec_ctx = frame->GetExecutionContext();
+void KernelFallbackExecuteOpCustomAllocatorInternal(
+    llvm::ArrayRef<tfrt::AsyncValue*> args,
+    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results,
+    tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
+    const tfrt::ExecutionContext& exec_ctx,
+    const FallbackKernelAttributeFrame& attr_frame) {
   const auto* fallback_request_state =
       exec_ctx.request_ctx()
           ->GetDataIfExists<KernelFallbackCompatRequestState>();
   if (!fallback_request_state) {
     KernelFallbackEmitError(
-        exec_ctx, attr_frame.op_name().GetValue(), /*op_chain=*/nullptr,
-        frame->GetResults(),
+        exec_ctx, attr_frame.op_name().GetValue(), op_chain, results,
         tensorflow::errors::NotFound(
             "KernelFallbackCompatRequestState not found in RequestContext."));
     return;
@@ -870,10 +871,9 @@
   DCHECK_EQ(kernel_runner->op_kernel()->name(),
             StripTfPrefix(attr_frame.op_name().GetValue()));
 
-  auto all_args = frame->GetArguments();
-  DCHECK_GT(all_args.size(), 0);
-  auto* allocator = all_args[0]->get<tensorflow::Allocator*>();
-  llvm::ArrayRef<tfrt::AsyncValue*> args = all_args.drop_front();
+  DCHECK_GT(args.size(), 0);
+  auto* allocator = args.front()->get<tensorflow::Allocator*>();
+  args = args.drop_front();
 
   DeviceWithCustomAllocator device_with_custom_allocator(
       GetDeviceFromFallbackState(*fallback_request_state, *kernel_runner),
@@ -884,13 +884,40 @@
   //
   // TODO(b/200575143): Consider allowing async execution and extending the
   // lifetime of the wrapping device.
-  KernelFallbackExecuteOpInternal(args, frame->GetResults(),
-                                  /*op_chain=*/nullptr, attr_frame, exec_ctx,
+  KernelFallbackExecuteOpInternal(args, results,
+                                  /*op_chain=*/op_chain, attr_frame, exec_ctx,
                                   *fallback_request_state, *kernel_runner,
                                   /*is_async=*/false,
                                   &device_with_custom_allocator);
 }
 
+void FallbackAsyncExecuteOpWithAllocator(tfrt::AsyncKernelFrame* frame) {
+  auto args = frame->GetArguments();
+  auto results = frame->GetResults();
+  FallbackKernelAttributeFrame attr_frame(frame);
+  KernelFallbackExecuteOpCustomAllocatorInternal(
+      args, results, /*op_chain=*/nullptr, frame->GetExecutionContext(),
+      attr_frame);
+}
+
+void FallbackAsyncExecuteOpSeqWithAllocator(tfrt::AsyncKernelFrame* frame) {
+  auto args = frame->GetArguments();
+  DCHECK_GT(args.size(), 0);
+  tfrt::AsyncValueRef<tfrt::Chain> op_chain(tfrt::FormRef(args.front()));
+  args = args.drop_front();
+
+  auto results = frame->GetResults();
+  DCHECK_GT(results.size(), 0);
+  auto& out_op_chain = results.front();
+  results = results.drop_front();
+
+  FallbackKernelAttributeFrame attr_frame(frame);
+  KernelFallbackExecuteOpCustomAllocatorInternal(
+      args, results, &op_chain, frame->GetExecutionContext(), attr_frame);
+
+  out_op_chain = std::move(op_chain);
+}
+
 void FallbackCopyTensorIfSmall(
     tfrt::Argument<tensorflow::tfrt_stub::FallbackTensor> arg,
     tfrt::RemainingResults results) {
@@ -966,6 +993,8 @@
                       FallbackAsyncExecuteOpSeq);
   registry->AddKernel("tfrt_fallback_async.executeop.allocator",
                       FallbackAsyncExecuteOpWithAllocator);
+  registry->AddKernel("tfrt_fallback_async.executeop.seq.allocator",
+                      FallbackAsyncExecuteOpSeqWithAllocator);
   registry->AddKernel("tfrt_fallback_async.copy_if_small",
                       TFRT_KERNEL(FallbackCopyTensorIfSmall));
   registry->AddKernel("tfrt_fallback_async.createop",
diff --git a/tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.cc b/tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.cc
index 7023cb7..b57d31e 100644
--- a/tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.cc
+++ b/tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.cc
@@ -88,6 +88,9 @@
 static LogicalResult verify(ExecuteOpWithAllocator op) {
   return fallback_common::VerifyExecuteOpCommon(op);
 }
+static LogicalResult verify(ExecuteOpSeqWithAllocator op) {
+  return fallback_common::VerifyExecuteOpCommon(op);
+}
 static LogicalResult verify(BatchFunctionOp op) {
   return fallback_common::VerifyExecuteOpCommon(op);
 }
@@ -166,6 +169,41 @@
       parser, builder, result, builder.getType<fallback::TFTensorType>(),
       parse_options);
 }
+static ParseResult parseExecuteOpSeqWithAllocator(OpAsmParser &parser,
+                                                  OperationState &result) {
+  auto &builder = parser.getBuilder();
+  llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> chain_and_allocator;
+  if (parser.parseOperandList(chain_and_allocator,
+                              /*requiredOperandCount=*/2,
+                              mlir::OpAsmParser::Delimiter::Paren))
+    return mlir::failure();
+
+  auto &chain = chain_and_allocator[0];
+  auto &allocator = chain_and_allocator[1];
+
+  if (parser.resolveOperands(chain, builder.getType<compiler::ChainType>(),
+                             result.operands))
+    return mlir::failure();
+
+  if (parser.resolveOperands(allocator,
+                             builder.getType<fallback::TFAllocatorType>(),
+                             result.operands))
+    return mlir::failure();
+
+  // The first result is a chain.
+  result.types.push_back(builder.getType<compiler::ChainType>());
+
+  fallback_common::ParseExecuteOpOptions parse_options;
+  parse_options.has_chain = false;
+  parse_options.has_key = true;
+  parse_options.has_device = true;
+  parse_options.has_func_attr = true;
+  parse_options.has_cost = true;
+
+  return fallback_common::ParseExecuteOpCommon(
+      parser, builder, result, builder.getType<fallback::TFTensorType>(),
+      parse_options);
+}
 
 static ParseResult parseBatchFunctionOp(OpAsmParser &parser,
                                         OperationState &result) {
@@ -267,6 +305,18 @@
   if (!op.results().empty()) p << " : " << op.results().size();
 }
 
+static void print(OpAsmPrinter &p, ExecuteOpSeqWithAllocator op) {
+  p << "(" << op.in_op_chain() << ", " << op.allocator() << ") key("
+    << op->getAttrOfType<mlir::IntegerAttr>("op_key").getInt() << ") cost("
+    << op->getAttrOfType<mlir::IntegerAttr>("_tfrt_cost").getInt()
+    << ") device(" << op->getAttr("device") << ") " << op->getAttr("op_name")
+    << '(' << op.operands() << ')';
+
+  fallback_common::PrintExecuteOpCommon(p, op);
+  fallback_common::PrintExecuteOpFuncAttribute(p, op);
+  if (!op.results().empty()) p << " : " << op.results().size();
+}
+
 static void print(OpAsmPrinter &p, BatchFunctionOp op) {
   p << "(" << op.in_op_chain() << ") " << op->getAttr("f") << " ("
     << op.operands() << ") ";
diff --git a/tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.td b/tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.td
index bcad578..6b553f6 100644
--- a/tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.td
+++ b/tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.td
@@ -506,4 +506,39 @@
   let parser = [{ return tfrt::fallback_async::parse$cppClass(parser, result); }];
 }
 
+def ExecuteOpSeqWithAllocator : FallbackAsync_Op<"executeop.seq.allocator",
+    [CoreRT_TypedAttributeTrait, TFRT_CostFunctionInterface, TFRT_AttrCostTrait]> {
+  let summary = "The sequenced version of Fallback ExecuteOp with custom allocator";
+  let description = [{
+    Similar to ExecuteOpSeq but takes a custom allocator for allocating output tensors.
+
+    Example:
+      %op_ch_out, %res = tfrt_fallback_async.executeop.seq.allocator(%op_ch_in, %allocator)
+        key(0) cost(100) device("/CPU:0") "some.op"(%arg) : 1
+  }];
+
+  let arguments = (ins
+    TFRT_ChainType:$in_op_chain,
+    TFAllocatorType:$allocator,
+    Variadic<TFTensorType>:$operands,
+    StrAttr:$device,
+    ArrayAttr:$op_attrs,
+    // TODO(b/173025975): consider using DictionaryAttr after we support
+    // BEF conversion for this type.
+    ArrayAttr:$op_func_attrs,
+    I64Attr:$op_key,
+    StrAttr:$op_name,
+    I64Attr:$_tfrt_cost
+  );
+
+  let results = (outs
+    TFRT_ChainType:$out_op_chain,
+    Variadic<TFTensorType>:$results
+  );
+
+  let verifier = [{ return tfrt::fallback_async::verify(*this); }];
+  let printer = [{ return tfrt::fallback_async::print(p, *this); }];
+  let parser = [{ return tfrt::fallback_async::parse$cppClass(parser, result); }];
+}
+
 #endif