Add tfrt_fallback_async.executeop.seq.allocator for using custom allocator on side-effecting TF ops.
This CL adds both kernel ODS definition and kernel implmentation.
PiperOrigin-RevId: 404136975
Change-Id: I91b14ce05c6fec582da4418b38f8da5180d689d7
diff --git a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.cc b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.cc
index f252833..36739dd 100644
--- a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.cc
+++ b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.cc
@@ -846,17 +846,18 @@
tensorflow::Allocator* allocator_ = nullptr;
};
-void FallbackAsyncExecuteOpWithAllocator(tfrt::AsyncKernelFrame* frame) {
- FallbackKernelAttributeFrame attr_frame(frame);
-
- const auto& exec_ctx = frame->GetExecutionContext();
+void KernelFallbackExecuteOpCustomAllocatorInternal(
+ llvm::ArrayRef<tfrt::AsyncValue*> args,
+ llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results,
+ tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
+ const tfrt::ExecutionContext& exec_ctx,
+ const FallbackKernelAttributeFrame& attr_frame) {
const auto* fallback_request_state =
exec_ctx.request_ctx()
->GetDataIfExists<KernelFallbackCompatRequestState>();
if (!fallback_request_state) {
KernelFallbackEmitError(
- exec_ctx, attr_frame.op_name().GetValue(), /*op_chain=*/nullptr,
- frame->GetResults(),
+ exec_ctx, attr_frame.op_name().GetValue(), op_chain, results,
tensorflow::errors::NotFound(
"KernelFallbackCompatRequestState not found in RequestContext."));
return;
@@ -870,10 +871,9 @@
DCHECK_EQ(kernel_runner->op_kernel()->name(),
StripTfPrefix(attr_frame.op_name().GetValue()));
- auto all_args = frame->GetArguments();
- DCHECK_GT(all_args.size(), 0);
- auto* allocator = all_args[0]->get<tensorflow::Allocator*>();
- llvm::ArrayRef<tfrt::AsyncValue*> args = all_args.drop_front();
+ DCHECK_GT(args.size(), 0);
+ auto* allocator = args.front()->get<tensorflow::Allocator*>();
+ args = args.drop_front();
DeviceWithCustomAllocator device_with_custom_allocator(
GetDeviceFromFallbackState(*fallback_request_state, *kernel_runner),
@@ -884,13 +884,40 @@
//
// TODO(b/200575143): Consider allowing async execution and extending the
// lifetime of the wrapping device.
- KernelFallbackExecuteOpInternal(args, frame->GetResults(),
- /*op_chain=*/nullptr, attr_frame, exec_ctx,
+ KernelFallbackExecuteOpInternal(args, results,
+ /*op_chain=*/op_chain, attr_frame, exec_ctx,
*fallback_request_state, *kernel_runner,
/*is_async=*/false,
&device_with_custom_allocator);
}
+void FallbackAsyncExecuteOpWithAllocator(tfrt::AsyncKernelFrame* frame) {
+ auto args = frame->GetArguments();
+ auto results = frame->GetResults();
+ FallbackKernelAttributeFrame attr_frame(frame);
+ KernelFallbackExecuteOpCustomAllocatorInternal(
+ args, results, /*op_chain=*/nullptr, frame->GetExecutionContext(),
+ attr_frame);
+}
+
+void FallbackAsyncExecuteOpSeqWithAllocator(tfrt::AsyncKernelFrame* frame) {
+ auto args = frame->GetArguments();
+ DCHECK_GT(args.size(), 0);
+ tfrt::AsyncValueRef<tfrt::Chain> op_chain(tfrt::FormRef(args.front()));
+ args = args.drop_front();
+
+ auto results = frame->GetResults();
+ DCHECK_GT(results.size(), 0);
+ auto& out_op_chain = results.front();
+ results = results.drop_front();
+
+ FallbackKernelAttributeFrame attr_frame(frame);
+ KernelFallbackExecuteOpCustomAllocatorInternal(
+ args, results, &op_chain, frame->GetExecutionContext(), attr_frame);
+
+ out_op_chain = std::move(op_chain);
+}
+
void FallbackCopyTensorIfSmall(
tfrt::Argument<tensorflow::tfrt_stub::FallbackTensor> arg,
tfrt::RemainingResults results) {
@@ -966,6 +993,8 @@
FallbackAsyncExecuteOpSeq);
registry->AddKernel("tfrt_fallback_async.executeop.allocator",
FallbackAsyncExecuteOpWithAllocator);
+ registry->AddKernel("tfrt_fallback_async.executeop.seq.allocator",
+ FallbackAsyncExecuteOpSeqWithAllocator);
registry->AddKernel("tfrt_fallback_async.copy_if_small",
TFRT_KERNEL(FallbackCopyTensorIfSmall));
registry->AddKernel("tfrt_fallback_async.createop",
diff --git a/tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.cc b/tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.cc
index 7023cb7..b57d31e 100644
--- a/tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.cc
+++ b/tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.cc
@@ -88,6 +88,9 @@
static LogicalResult verify(ExecuteOpWithAllocator op) {
return fallback_common::VerifyExecuteOpCommon(op);
}
+static LogicalResult verify(ExecuteOpSeqWithAllocator op) {
+ return fallback_common::VerifyExecuteOpCommon(op);
+}
static LogicalResult verify(BatchFunctionOp op) {
return fallback_common::VerifyExecuteOpCommon(op);
}
@@ -166,6 +169,41 @@
parser, builder, result, builder.getType<fallback::TFTensorType>(),
parse_options);
}
+static ParseResult parseExecuteOpSeqWithAllocator(OpAsmParser &parser,
+ OperationState &result) {
+ auto &builder = parser.getBuilder();
+ llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> chain_and_allocator;
+ if (parser.parseOperandList(chain_and_allocator,
+ /*requiredOperandCount=*/2,
+ mlir::OpAsmParser::Delimiter::Paren))
+ return mlir::failure();
+
+ auto &chain = chain_and_allocator[0];
+ auto &allocator = chain_and_allocator[1];
+
+ if (parser.resolveOperands(chain, builder.getType<compiler::ChainType>(),
+ result.operands))
+ return mlir::failure();
+
+ if (parser.resolveOperands(allocator,
+ builder.getType<fallback::TFAllocatorType>(),
+ result.operands))
+ return mlir::failure();
+
+ // The first result is a chain.
+ result.types.push_back(builder.getType<compiler::ChainType>());
+
+ fallback_common::ParseExecuteOpOptions parse_options;
+ parse_options.has_chain = false;
+ parse_options.has_key = true;
+ parse_options.has_device = true;
+ parse_options.has_func_attr = true;
+ parse_options.has_cost = true;
+
+ return fallback_common::ParseExecuteOpCommon(
+ parser, builder, result, builder.getType<fallback::TFTensorType>(),
+ parse_options);
+}
static ParseResult parseBatchFunctionOp(OpAsmParser &parser,
OperationState &result) {
@@ -267,6 +305,18 @@
if (!op.results().empty()) p << " : " << op.results().size();
}
+static void print(OpAsmPrinter &p, ExecuteOpSeqWithAllocator op) {
+ p << "(" << op.in_op_chain() << ", " << op.allocator() << ") key("
+ << op->getAttrOfType<mlir::IntegerAttr>("op_key").getInt() << ") cost("
+ << op->getAttrOfType<mlir::IntegerAttr>("_tfrt_cost").getInt()
+ << ") device(" << op->getAttr("device") << ") " << op->getAttr("op_name")
+ << '(' << op.operands() << ')';
+
+ fallback_common::PrintExecuteOpCommon(p, op);
+ fallback_common::PrintExecuteOpFuncAttribute(p, op);
+ if (!op.results().empty()) p << " : " << op.results().size();
+}
+
static void print(OpAsmPrinter &p, BatchFunctionOp op) {
p << "(" << op.in_op_chain() << ") " << op->getAttr("f") << " ("
<< op.operands() << ") ";
diff --git a/tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.td b/tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.td
index bcad578..6b553f6 100644
--- a/tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.td
+++ b/tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.td
@@ -506,4 +506,39 @@
let parser = [{ return tfrt::fallback_async::parse$cppClass(parser, result); }];
}
+def ExecuteOpSeqWithAllocator : FallbackAsync_Op<"executeop.seq.allocator",
+ [CoreRT_TypedAttributeTrait, TFRT_CostFunctionInterface, TFRT_AttrCostTrait]> {
+ let summary = "The sequenced version of Fallback ExecuteOp with custom allocator";
+ let description = [{
+ Similar to ExecuteOpSeq but takes a custom allocator for allocating output tensors.
+
+ Example:
+ %op_ch_out, %res = tfrt_fallback_async.executeop.seq.allocator(%op_ch_in, %allocator)
+ key(0) cost(100) device("/CPU:0") "some.op"(%arg) : 1
+ }];
+
+ let arguments = (ins
+ TFRT_ChainType:$in_op_chain,
+ TFAllocatorType:$allocator,
+ Variadic<TFTensorType>:$operands,
+ StrAttr:$device,
+ ArrayAttr:$op_attrs,
+ // TODO(b/173025975): consider using DictionaryAttr after we support
+ // BEF conversion for this type.
+ ArrayAttr:$op_func_attrs,
+ I64Attr:$op_key,
+ StrAttr:$op_name,
+ I64Attr:$_tfrt_cost
+ );
+
+ let results = (outs
+ TFRT_ChainType:$out_op_chain,
+ Variadic<TFTensorType>:$results
+ );
+
+ let verifier = [{ return tfrt::fallback_async::verify(*this); }];
+ let printer = [{ return tfrt::fallback_async::print(p, *this); }];
+ let parser = [{ return tfrt::fallback_async::parse$cppClass(parser, result); }];
+}
+
#endif