blob: 709e95fb2a792615fb5a1ea6ee4a4accf74345e8 [file] [log] [blame]
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef RT_PASSES
#define RT_PASSES
include "mlir/Pass/PassBase.td"
// Module-level pass that rewrites functions carrying the `xla.entrypoint`
// unit attribute into Xla runtime entrypoints. It emits operations from the
// Xla runtime (`rt`) dialect, which is why that dialect is listed in
// `dependentDialects` below.
def ConvertToEntrypoint : Pass<"rt-convert-to-entrypoint", "mlir::ModuleOp"> {
  let summary = "Converts function(s) to Xla runtime entrypoint(s)";
  let description = [{
Converts functions with an `xla.entrypoint` unit attribute to Xla
entrypoints, i.e.:
- the first argument is an `!rt.execution_context`
- all results are returned via the `rt.set_output` operation
- failed asserts set the result error via the `rt.set_error` operation
- function calls marked with the `rt.custom_call` attribute (on the callee)
are converted to `rt.custom_call` operations (or get the
`rt.direct_custom_call` attribute for direct custom calls)
See `ir/runtime/rt_ops.td` to find out how an Xla executable returns
results and errors using the runtime APIs.
When converting a function call to a custom call operation, the custom call
attributes will be the union of the custom call function declaration
attributes and the call operation attributes. Function call attributes
override any attributes defined by the custom call function declaration.
Example:
```mlir
func @custom_call() -> memref<?xf32>
attributes { rt.custom_call = "custom_call", attr = <value> }
func @compute(...) -> memref<?xf32> attributes { xla.entrypoint } {
%0 = ... : i1
assert %0, "Oops"
%1 = call @custom_call() { attr = <new_value> }: () -> memref<?xf32>
return %1
}
```
converted to:
```mlir
func @compute(%ctx: !rt.execution_context, ...) {
%0 = ... : i1
cond_br %0, ^ok0, ^err0
^ok0:
%status, %1 = rt.custom_call %ctx, "custom_call"()
{ attr = <new_value> } : () -> memref<?xf32>
%success = rt.is_ok %status : !rt.status
cond_br %success, ^ok1, ^err1
^ok1:
rt.set_output %ctx, 0, %1 : memref<?xf32>
return
^err0:
rt.set_error %ctx, "Oops"
return
^err1:
rt.set_error %ctx, "Custom call failed"
return
}
```
  }];
  // C++ factory used by the generated pass registration code.
  let constructor = "xla::runtime::CreateConvertToEntrypoint()";
  // The pass creates `rt` dialect operations, so the dialect must be loaded
  // before the pass runs.
  let dependentDialects = ["xla::runtime::RuntimeDialect"];
}
#endif // RT_PASSES