// blob: f0c9672b5c28cc92f3cbd657c423676cd297602a [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the operation definition file for TensorFlow.
//
// This file contains TensorFlow ops whose definitions are amended to fix
// issues or provide more information. In this file you have full control
// of the op definition; all changes will be retained with subsequent
// refreshes.
//
// This file includes another file, `tf_generated_ops.td`, which contains
// all ops whose definitions are generated from TensorFlow codebase.
// Changes made there are not respected.
#ifndef TF_OPS
#define TF_OPS
include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/SymbolInterfaces.td"
// Shared base for ops that create a TensorList. Such an op produces a single
// variant tensor result, `handle`, and gets C++ helpers for inspecting the
// element type carried by that handle.
class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
  let results = (outs TF_VariantTensor:$handle);

  let hasVerifier = 1;

  TF_DerivedOperandTypeAttr shape_type = TF_DerivedOperandTypeAttr<0>;

  // Element data type of the list, computed from `element_type()` below.
  DerivedTypeAttr element_dtype = DerivedTypeAttr<
      "return getElementTypeOrSelf(element_type());">;

  let extraClassDeclaration = [{
    // Data type of the result handle. The TensorList element type is carried
    // as a subtype of the returned variant type.
    VariantType handle_dtype() {
      auto handle_elem_ty = getElementTypeOrSelf(handle().getType());
      return handle_elem_ty.cast<TF::VariantType>();
    }

    // Type of the TensorList element produced by this op, i.e. the first
    // subtype of the handle's variant type.
    TensorType element_type() {
      return handle_dtype().getSubtypes()[0];
    }
  }];
}
// Functional n-way switch: `branches` holds symbol references to the branch
// functions and `branch_index` selects which one is called.
def TF_CaseOp : TF_Op<"Case", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
  let summary = [{
An n-way switch statement which calls a single branch function.
}];

  let description = [{
An n-way switch statement, implementing the following:
```
switch (branch_index) {
case 0:
output = branches[0](input);
break;
case 1:
output = branches[1](input);
break;
...
case [[nbranches-1]]:
default:
output = branches[nbranches-1](input);
break;
}
```
}];

  let arguments = (ins
    I32Tensor:$branch_index,
    Variadic<TF_Tensor>:$input,
    // At least one branch must be provided.
    ConfinedAttr<SymbolRefArrayAttr, [ArrayMinCount<1>]>:$branches,

    // Lets the StatelessCase and Case ops defined in TensorFlow share this one
    // op definition.
    BoolAttr:$is_stateless
  );

  let results = (outs Variadic<TF_Tensor>:$output);

  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
  TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;

  let hasVerifier = 1;
  let hasCanonicalizer = 1;

  let extraClassDeclaration = [{
    // Number of entries in the `branches` attribute.
    int num_branches() { return branches().size(); }

    // Looks up the function referenced by branch number `index`.
    // Passing a SymbolTableCollection is preferred: it reduces lookup costs by
    // reusing cached symbol table lookups. With a null `table` the lookup
    // falls back to SymbolTable directly.
    func::FuncOp ResolveBranchFunction(::mlir::SymbolTableCollection* table, int index) {
      auto branch_sym = branches()[index].cast<FlatSymbolRefAttr>();
      return table
          ? table->lookupNearestSymbolFrom<func::FuncOp>(*this, branch_sym)
          : SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, branch_sym);
    }

    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp branch_function(int index) {
      return ResolveBranchFunction(/*table=*/nullptr, index);
    }

    // Appends the resolved function of every branch, in branch order, to
    // `functions`.
    // Passing a SymbolTableCollection is preferred: it reduces lookup costs by
    // reusing cached symbol table lookups.
    void ResolveBranchFunctions(::mlir::SymbolTableCollection* table,
                                SmallVectorImpl<func::FuncOp> &functions) {
      const int branch_count = num_branches();
      functions.reserve(branch_count);
      for (int branch = 0; branch < branch_count; ++branch)
        functions.push_back(ResolveBranchFunction(table, branch));
    }

    // TODO(b/204997177): Deprecate and remove.
    void get_branch_functions(SmallVectorImpl<func::FuncOp> &functions) {
      ResolveBranchFunctions(/*table=*/nullptr, functions);
    }
  }];
}
// Region-based n-way switch. Unlike the functional `Case` op, the branches are
// regions attached to the op (one block each, no region arguments) rather than
// symbol references, and the op takes no `input` operands.
def TF_CaseRegionOp : TF_Op<"CaseRegion",
    [SingleBlockImplicitTerminator<"YieldOp">, NoRegionArguments]> {
  let summary = [{
    An n-way switch statement which executes a single branch region.
  }];

  let description = [{
    An n-way switch statement, implementing the following:
    ```
    switch (branch_index) {
      case 0:
        output = branches[0]();
        break;
      case 1:
        output = branches[1]();
        break;
      ...
      case [[nbranches-1]]:
      default:
        output = branches[nbranches-1]();
        break;
    }
    ```
    Each branch is a single-block region that takes no arguments and ends with
    a `tf.Yield` terminator.
  }];

  let arguments = (ins
    I32Tensor:$branch_index,

    // Used to map StatelessCase and Case op defined in TensorFlow to a common
    // op.
    BoolAttr:$is_stateless
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  // One region per branch; each region holds exactly one block.
  let regions = (region VariadicRegion<SizedRegion<1>>:$branches);

  let hasVerifier = 1;
  let hasCanonicalizer = 1;
}
// A TensorFlow tensor value is modeled in MLIR as an ElementsAttr whose type
// encodes both the tensor's shape and its data type.
def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect,
    DeclareOpInterfaceMethods<InferTypeOpInterface>,
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
  let summary = "Constant tensor op";

  let arguments = (ins ElementsAttr:$value);
  let results = (outs TF_Tensor:$output);

  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;

  let hasFolder = 1;

  // Custom builders: from the value alone, or from an explicit result type
  // plus the value.
  let builders = [
    OpBuilder<(ins "Attribute":$value)>,
    OpBuilder<(ins "Type":$type, "Attribute":$value)>,
  ];

  let extraClassDeclaration = [{
    // Two return type lists are considered compatible when
    // BroadcastCompatible accepts them.
    static bool isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
      return BroadcastCompatible(lhs, rhs);
    }
  }];
}
// Only the operands are declared here; the `handle` result comes from the
// TF_TensorListInitOp base class.
def TF_EmptyTensorListOp : TF_TensorListInitOp<"EmptyTensorList"> {
  let summary = "Creates and returns an empty tensor list.";

  let description = [{
All list elements must be tensors of dtype element_dtype and shape compatible
with element_shape.
handle: an empty tensor list.
element_dtype: the type of elements in the list.
element_shape: a shape compatible with that of elements in the list.
}];

  let arguments = (ins
    TF_I32OrI64Tensor:$element_shape,
    TF_Int32Tensor:$max_num_elements
  );
}
// Functional conditional: `then_branch` and `else_branch` are symbol
// references to the two branch functions.
def TF_IfOp : TF_Op<"If", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
  let summary = "output = cond ? then_branch(input) : else_branch(input)";

  let description = [{
output = cond ? then_branch(input) : else_branch(input)
cond: A Tensor. If the tensor is a scalar of non-boolean type, the
scalar is converted to a boolean according to the
following rule: if the scalar is a numerical value, non-zero means
True and zero means False; if the scalar is a string, non-empty
means True and empty means False. If the tensor is not a scalar,
being empty means False and being non-empty means True.
input: A list of input tensors.
then_branch: A function that takes 'inputs' and returns a list of
tensors, whose types are the same as what else_branch returns.
else_branch: A function that takes 'inputs' and returns a list of
tensors. whose types are the same as what then_branch returns.
}];

  let arguments = (ins
    TF_Tensor:$cond,
    Variadic<TF_Tensor>:$input,
    FlatSymbolRefAttr:$then_branch,
    FlatSymbolRefAttr:$else_branch,

    // Lets the StatelessIf and If ops defined in TensorFlow share this one op
    // definition.
    BoolAttr:$is_stateless
  );

  let results = (outs Variadic<TF_Tensor>:$output);

  TF_DerivedOperandTypeAttr Tcond = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
  TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;

  let hasCanonicalizer = 1;

  let extraClassDeclaration = [{
    // Looks up the function referenced by `then_branch`.
    // Passing a SymbolTableCollection is preferred: it reduces lookup costs by
    // reusing cached symbol table lookups. With a null `table` the lookup
    // falls back to SymbolTable directly.
    func::FuncOp ResolveThenFunction(::mlir::SymbolTableCollection* table) {
      auto then_sym = then_branchAttr();
      return table
          ? table->lookupNearestSymbolFrom<func::FuncOp>(*this, then_sym)
          : SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, then_sym);
    }

    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp then_function(::mlir::SymbolTableCollection* table = nullptr) {
      return ResolveThenFunction(table);
    }

    // Looks up the function referenced by `else_branch`.
    // Passing a SymbolTableCollection is preferred: it reduces lookup costs by
    // reusing cached symbol table lookups. With a null `table` the lookup
    // falls back to SymbolTable directly.
    func::FuncOp ResolveElseFunction(::mlir::SymbolTableCollection* table) {
      auto else_sym = else_branchAttr();
      return table
          ? table->lookupNearestSymbolFrom<func::FuncOp>(*this, else_sym)
          : SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, else_sym);
    }

    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp else_function(::mlir::SymbolTableCollection* table = nullptr) {
      return ResolveElseFunction(table);
    }
  }];
}
// Terminator that may only appear directly inside CaseRegionOp, IfRegionOp or
// WhileRegionOp.
def TF_YieldOp : TF_Op<"Yield",
    [NoSideEffect, ReturnLike, Terminator,
     ParentOneOf<["CaseRegionOp", "IfRegionOp", "WhileRegionOp"]>]> {
  let summary = "Yield operation";

  let description = [{
The "yield" operation represents a return operation within the conditional
and body of structured control flow (e.g., if and while). The operation
takes a variable number of operands and produces no results. The number and
types of inputs must match the signature of the operation that contains the
region.
}];

  // Any number of operands of any type; no results are declared.
  let arguments = (ins Variadic<AnyType>:$operands);
}
// Region-based conditional: the two branches are single-block regions without
// region arguments, attached directly to the op.
def TF_IfRegionOp : TF_Op<"IfRegion",
    [SingleBlockImplicitTerminator<"YieldOp">, NoRegionArguments]> {
  let summary = "output = cond ? then_branch output : else_branch output";

  let description = [{
"output = cond ? then_branch output : else_branch output"
cond: A Tensor. If the tensor is a scalar of non-boolean type, the
scalar is converted to a boolean according to the
following rule: if the scalar is a numerical value, non-zero means
True and zero means False; if the scalar is a string, non-empty
means True and empty means False. If the tensor is not a scalar,
being empty means False and being non-empty means True.
then_branch: A region that computes the outputs of the op if cond = true.
It returns a list of tensors using tf.yield (as the terminator). The
types of these returned tensors is same as that of the else_branch
else_branch: A region that computes the outputs of the op if cond = false.
It returns a list of tensors using tf.yield (as the terminator). The
types of these returned tensors is same as that of the then_branch
}];

  let arguments = (ins
    0DTensorOf<[I1]>:$cond,

    // Lets the StatelessIf and If ops defined in TensorFlow share this one op
    // definition.
    BoolAttr:$is_stateless,

    // Keep the branch function names stable when round-tripping between
    // functional and regional control flow. These can be removed if the
    // runtime does not require globally unique then/else branch function
    // names.
    OptionalAttr<StrAttr>:$_then_func_name,
    OptionalAttr<StrAttr>:$_else_func_name
  );

  let results = (outs Variadic<TF_Tensor>:$output);

  let regions = (region SizedRegion<1>:$then_branch, SizedRegion<1>:$else_branch);

  // Extra builder that also accepts a region count; the count is only checked
  // against the two regions declared above and then the call is forwarded.
  let builders = [
    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$operands,
      "llvm::ArrayRef<::mlir::NamedAttribute>":$attributes,
      "unsigned":$numRegions),
    [{
      assert(numRegions == 2u && "mismatched number of regions");
      build($_builder, $_state, resultTypes, operands, attributes);
    }]>];

  let hasRegionVerifier = 1;
  let hasCanonicalizer = 1;
}
// Direct call to the function named by the `f` symbol attribute.
def TF_LegacyCallOp : TF_Op<"LegacyCall",
    [CallOpInterface, NoSideEffect]> {
  let summary =
    "returns `f(inputs)`, where `f` is a function.";

  let description = [{
The LegacyCall operation represents a direct call to a function that is
within the same symbol scope as the call and is mapped to a GraphDef node
with the function name as the op name. Unlike a PartitionedCall which
represents asynchronously executing a function across multiple devices, a
LegacyCall ignores specification for ops in the attached function and
instead executes it on the device assigned to this op.
}];

  let arguments = (ins
    Variadic<TF_Tensor>:$args,
    FlatSymbolRefAttr:$f,
    DefaultValuedAttr<BoolAttr, "false">:$_disable_call_shape_inference
  );

  let results = (outs Variadic<TF_Tensor>:$output);

  let extraClassDeclaration = [{
    // Operands passed as arguments to the called function.
    operand_range getArgOperands() { return args(); }

    // The callee of this operation, given as the `f` symbol attribute.
    CallInterfaceCallable getCallableForCallee() { return fAttr(); }

    // Looks up the function referenced by `f`.
    // Passing a SymbolTableCollection is preferred: it reduces lookup costs by
    // reusing cached symbol table lookups. With a null `table` the lookup
    // falls back to SymbolTable directly.
    func::FuncOp ResolveFunc(::mlir::SymbolTableCollection* table) {
      auto callee_sym = fAttr();
      return table
          ? table->lookupNearestSymbolFrom<func::FuncOp>(*this, callee_sym)
          : SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, callee_sym);
    }

    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp func() { return ResolveFunc(/*table=*/nullptr); }
  }];
}
// Has several variadic operand groups and several variadic result groups, so
// the segment sizes of both are carried as explicit attributes.
def TF_ParseExampleOp : TF_Op<"ParseExample",
    [NoSideEffect,
     AttrSizedResultSegments,
     AttrSizedOperandSegments]> {
  let summary =
    "Transforms a vector of tf.Example protos (as strings) into typed tensors.";

  let arguments = (ins
    TF_StrTensor:$serialized,
    TF_StrTensor:$names,
    Variadic<TF_StrTensor>:$sparse_keys,
    Variadic<TF_StrTensor>:$dense_keys,
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_defaults,

    TF_ShapeAttrArray:$dense_shapes,
    DenseI32ArrayAttr:$result_segment_sizes,
    DenseI32ArrayAttr:$operand_segment_sizes
  );

  let results = (outs
    Variadic<TF_Int64Tensor>:$sparse_indices,                            // len(sparse_types)
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$sparse_values,   // len(sparse_types)
    Variadic<TF_Int64Tensor>:$sparse_shapes,                             // len(sparse_types)
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_values     // len(Tdense)
  );

  TF_DerivedOperandSizeAttr Nsparse = TF_DerivedOperandSizeAttr<2>;
  TF_DerivedOperandSizeAttr Ndense = TF_DerivedOperandSizeAttr<3>;
  TF_DerivedOperandTypeListAttr Tdense = TF_DerivedOperandTypeListAttr<4>;
  TF_DerivedResultTypeListAttr sparse_types = TF_DerivedResultTypeListAttr<1>;

  // No custom verifier is requested for this op.
  let hasVerifier = 0;
}
// Only `dense_defaults` is variadic among the operands, so just the result
// segment sizes are carried as an explicit attribute.
def TF_ParseExampleV2Op : TF_Op<"ParseExampleV2",
    [NoSideEffect,
     AttrSizedResultSegments]> {
  let summary =
    "Transforms a vector of tf.Example protos (as strings) into typed tensors.";

  let arguments = (ins
    TF_StrTensor:$serialized,
    TF_StrTensor:$names,
    TF_StrTensor:$sparse_keys,
    TF_StrTensor:$dense_keys,
    TF_StrTensor:$ragged_keys,
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_defaults,

    // Must be non-negative.
    ConfinedAttr<I64Attr, [IntMinValue<0>]>:$num_sparse,
    TF_ShapeAttrArray:$dense_shapes,
    DenseI32ArrayAttr:$result_segment_sizes
  );

  let results = (outs
    Variadic<TF_Int64Tensor>:$sparse_indices,                            // len(sparse_types)
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$sparse_values,   // len(sparse_types)
    Variadic<TF_Int64Tensor>:$sparse_shapes,                             // len(sparse_types)
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_values,    // len(Tdense)
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$ragged_values,   // len(ragged_value_types)
                                                                         //   = len(ragged_split_types)
    Variadic<TensorOf<[TF_Int32, TF_Int64]>>:$ragged_row_splits          // len(ragged_split_types)
                                                                         //   = len(ragged_value_types)
  );

  // The verifier of this op (enabled below) validates that the lengths and
  // types of these attrs are compatible.
  TF_DerivedOperandTypeListAttr Tdense = TF_DerivedOperandTypeListAttr<5>;
  TF_DerivedResultTypeListAttr sparse_types = TF_DerivedResultTypeListAttr<1>;
  TF_DerivedResultTypeListAttr ragged_value_types =
    TF_DerivedResultTypeListAttr<4>;
  TF_DerivedResultTypeListAttr ragged_split_types =
    TF_DerivedResultTypeListAttr<5>;

  let hasVerifier = 1;
}
// Takes no operands or attributes; declares a single tensor result.
def TF_PlaceholderOp : TF_Op<"Placeholder", [NoSideEffect]> {
  let summary = "Placeholder op";

  let description = [{
Inserts a placeholder for a tensor that will be always fed.
}];

  let arguments = (ins);
  let results = (outs TF_Tensor:$output);

  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
// One tensor operand and one tensor result; `dtype` and `shape` are derived
// attributes rather than stored ones.
def TF_PlaceholderWithDefaultOp : TF_Op<"PlaceholderWithDefault", [NoSideEffect]> {
  let summary = "Placeholder op";

  let description = [{
A placeholder op that passes through input when its output is not fed.
}];

  let arguments = (ins TF_Tensor:$input);
  let results = (outs TF_Tensor:$output);

  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
  DerivedAttr shape = TF_DerivedResultShapeAttr;
}
// Call op whose callee is the `f` symbol attribute. The SymbolUserOpInterface
// hook is declared by hand in the extra class declaration below.
def TF_StatefulPartitionedCallOp : TF_Op<"StatefulPartitionedCall",
    [CallOpInterface, SymbolUserOpInterface]> {
  let summary =
    "returns `f(inputs)`, where `f`'s body is placed and partitioned.";

  let description = [{
Asynchronously executes a function, potentially across multiple devices but
within a single process. The kernel places and partitions a given function's
underlying graph, and executes each of the partitioned subgraphs as a function.
}];

  let arguments = (ins
    Variadic<TF_Tensor>:$args,

    FlatSymbolRefAttr:$f,
    StrAttr:$config,
    StrAttr:$config_proto,
    StrAttr:$executor_type
  );

  let results = (outs Variadic<TF_Tensor>:$output);

  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;

  let extraClassDeclaration = [{
    // Operands passed as arguments to the called function.
    operand_range getArgOperands() { return args(); }

    // The callee of this operation, given as the `f` symbol attribute.
    CallInterfaceCallable getCallableForCallee() { return fAttr(); }

    // Looks up the function referenced by `f`.
    // Passing a SymbolTableCollection is preferred: it reduces lookup costs by
    // reusing cached symbol table lookups. With a null `table` the lookup
    // falls back to SymbolTable directly.
    func::FuncOp ResolveFunc(::mlir::SymbolTableCollection* table) {
      auto callee_sym = fAttr();
      return table
          ? table->lookupNearestSymbolFrom<func::FuncOp>(*this, callee_sym)
          : SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, callee_sym);
    }

    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp func() { return ResolveFunc(/*table=*/nullptr); }

    // SymbolUserOpInterface verifier (declared here, defined out of line).
    LogicalResult verifySymbolUses(SymbolTableCollection &symbolTable);
  }];
}
// Functional while loop: `cond` and `body` are symbol references to the
// condition and body functions.
def TF_WhileOp : TF_Op<"While", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
  let summary = [{
output = input; While (Cond(output)) { output = Body(output) }
}];

  let description = [{
output = input; While (Cond(output)) { output = Body(output) }
input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A function that takes 'input' and returns a tensor. If the tensor is
a scalar of non-boolean, the scalar is converted to a boolean
according to the following rule: if the scalar is a numerical
value, non-zero means True and zero means False; if the scalar is
a string, non-empty means True and empty means False. If the
tensor is not a scalar, non-emptiness means True and False
otherwise.
body: A function that takes a list of tensors and returns another
list of tensors. Both lists have the same types as specified
by T.
}];

  let arguments = (ins
    Variadic<TF_Tensor>:$input,

    FlatSymbolRefAttr:$cond,
    FlatSymbolRefAttr:$body,
    // Defaults to 10 and must be at least 1.
    ConfinedAttr<DefaultValuedAttr<I64Attr, "10">, [IntMinValue<1>]>:$parallel_iterations,

    // Lets the StatelessWhile and While ops defined in TensorFlow share this
    // one op definition.
    BoolAttr:$is_stateless,

    // TensorFlow's While treats a non-empty `output_shapes` attribute
    // specially: its shape function then uses those shapes as the result
    // shapes instead of propagating the operand shapes, which allows results
    // whose shapes differ from the operands'. These shapes are imported as
    // part of the result type, but the type alone cannot distinguish "no
    // output shapes" from "all shapes unranked". This attribute therefore
    // selects the shape function behavior: when it is not set, operand shapes
    // are propagated as result shapes; when it is set, the result shapes are
    // preserved as is.
    UnitAttr:$shape_invariant
  );

  let results = (outs Variadic<TF_Tensor>:$output);

  TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;

  let extraClassDeclaration = [{
    // Looks up the function referenced by `cond`.
    // Passing a SymbolTableCollection is preferred: it reduces lookup costs by
    // reusing cached symbol table lookups. With a null `table` the lookup
    // falls back to SymbolTable directly.
    func::FuncOp ResolveCondFunction(::mlir::SymbolTableCollection* table) {
      auto cond_sym = condAttr();
      return table
          ? table->lookupNearestSymbolFrom<func::FuncOp>(*this, cond_sym)
          : SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, cond_sym);
    }

    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp cond_function() { return ResolveCondFunction(/*table=*/nullptr); }

    // Looks up the function referenced by `body`.
    // Passing a SymbolTableCollection is preferred: it reduces lookup costs by
    // reusing cached symbol table lookups. With a null `table` the lookup
    // falls back to SymbolTable directly.
    func::FuncOp ResolveBodyFunction(::mlir::SymbolTableCollection* table) {
      auto body_sym = bodyAttr();
      return table
          ? table->lookupNearestSymbolFrom<func::FuncOp>(*this, body_sym)
          : SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, body_sym);
    }

    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp body_function() { return ResolveBodyFunction(/*table=*/nullptr); }
  }];
}
// Region-based while loop with a `cond` region and a `body` region, each
// holding exactly one block.
def TF_WhileRegionOp : TF_Op<"WhileRegion",
    [DeclareOpInterfaceMethods<LoopLikeOpInterface>,
     SingleBlockImplicitTerminator<"YieldOp">]> {
  let summary = "while operation";

  let description = [{
The tf.WhileRegion op represents a while loop using 2 regions and a set of
iteration variables. The iteration variables maintained by this Op have the
same types as the inputs. The Op executes a while loop described by the
following pseudo code:
```
func WhileRegionOp(inputs) {
iteration_vars = inputs;
while (cond(iteration_vars)) {
iteration_vars = body(iteration_vars);
}
return iteration_vars;
}
```
`cond` is the condition region and `body` is the body region. Both these
regions accept the current value of the iteration variables as inputs. The
condition region returns a tensor<i1> which, if false, will exit the loop.
The body region computes new values of the iteration variables. The iteration
variables are initialized to the Op input, and the results of the
tf.WhileRegion op are the final values of the iteration variables.
This implies that the operand and result types for tf.WhileRegion should be
the same. Note that the condition and body regions can implicitly capture
loop invariant values directly. In canonical form, iteration variables that
pass through the loop body unmodified are converted to implicitly captured
references to their values outside the loop.
}];

  let arguments = (ins
    Variadic<AnyTensor>:$input,

    // Defaults to 10 and must be at least 1.
    ConfinedAttr<DefaultValuedAttr<I64Attr, "10">, [IntMinValue<1>]>:$parallel_iterations,

    // Lets the StatelessWhile and While ops defined in TensorFlow share this
    // one op definition.
    BoolAttr:$is_stateless,

    // TensorFlow's While treats a non-empty `output_shapes` attribute
    // specially: its shape function then uses those shapes as the result
    // shapes instead of propagating the operand shapes, which allows results
    // whose shapes differ from the operands'. These shapes are imported as
    // part of the result type, but the type alone cannot distinguish "no
    // output shapes" from "all shapes unranked". This attribute therefore
    // selects the shape function behavior: when it is not set, operand shapes
    // are propagated as result shapes; when it is set, the result shapes are
    // preserved as is.
    UnitAttr:$shape_invariant
  );

  let results = (outs Variadic<AnyTensor>:$output);

  let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);

  let hasVerifier = 1;
  let hasCanonicalizer = 1;
}
// Only the operands are declared here; the `handle` result comes from the
// TF_TensorListInitOp base class.
def TF_TensorListReserveOp : TF_TensorListInitOp<"TensorListReserve"> {
  let summary = "List of the given size with empty elements.";

  let description = [{
element_shape: the shape of the future elements of the list
num_elements: the number of elements to reserve
handle: the output list
element_dtype: the desired type of elements in the list.
}];

  let arguments = (ins
    TF_I32OrI64Tensor:$element_shape,
    TF_Int32Tensor:$num_elements
  );
}
// `dtype` and `shape` are not stored on the op: both are derived from the
// first subtype of the result's resource type.
def TF_VarHandleOp : TF_Op<"VarHandleOp", [DeclareOpInterfaceMethods<TF_ResourceHandleAllocatorInterface>]> {
  let summary = "Creates a handle to a Variable resource from its name.";

  let description = [{
container: the container this variable is placed in.
shared_name: the name by which this variable is referred to.
dtype and shape: attributes representing the data type and shape held in the
variable.
Example:
resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[8, 16], container="foo", shared_name="bar")
returns a handle for a variable with name "bar" in container "foo", and the
variable holds a tensor of shape [8, 16] and dtype int32.
}];

  // Both string attributes default to the empty string.
  let arguments = (ins
    DefaultValuedStrAttr<StrAttr, "">:$container,
    DefaultValuedStrAttr<StrAttr, "">:$shared_name
  );

  let results = (outs
    Res<TF_ResourceTensor, "", [TF_VariableAlloc]>:$resource
  );

  let hasVerifier = 1;

  DerivedTypeAttr dtype = DerivedTypeAttr<
      "return getElementTypeOrSelf(resource_subtype());">;

  DerivedAttr shape = DerivedAttr<
      "ShapedType",
      "return resource_subtype().cast<ShapedType>();",
      [{ mlir::TF::ShapeAttr::get($_ctxt, $_self) }]>;

  let extraClassDeclaration = [{
    // Resource type of the `resource` result (its element type, when the
    // result is a shaped type).
    ResourceType resource_type() {
      auto resource_elem_ty = getElementTypeOrSelf(resource());
      return resource_elem_ty.cast<TF::ResourceType>();
    }

    // First subtype recorded on the resource type.
    TensorType resource_subtype() {
      return resource_type().getSubtypes()[0];
    }
  }];
}
// Marked TF_NoConstantFold in addition to NoSideEffect.
def TF_XlaShardingOp : TF_Op<"XlaSharding", [NoSideEffect, TF_NoConstantFold]> {
  let summary = [{
An op which shards the input based on the given sharding attribute.
}];

  let arguments = (ins
    TF_Tensor:$input,

    // `sharding` defaults to the empty string; `_XlaSharding` is optional.
    DefaultValuedStrAttr<StrAttr, "">:$sharding,
    OptionalAttr<StrAttr>:$_XlaSharding
  );

  let results = (outs TF_Tensor:$output);

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
// Takes no operands; `shapes` and `dtypes` are derived from the results.
def TF_InfeedDequeueTupleOp : TF_Op<"InfeedDequeueTuple", []> {
  let summary = "Fetches multiple values from infeed as an XLA tuple.";

  // Both attributes are optional.
  let arguments = (ins
    OptionalAttr<StrAttr>:$_XlaSharding,
    OptionalAttr<ArrayAttr>:$layouts
  );

  let results = (outs Variadic<TF_Tensor>:$outputs);

  TF_DerivedResultShapeListAttr shapes = TF_DerivedResultShapeListAttr<0>;
  TF_DerivedResultTypeListAttr dtypes = TF_DerivedResultTypeListAttr<0>;
}
// TODO(b/177675373): Make dtypes and shapes derived attributes,
// use more general solution.
// Until then `dtypes` and `shapes` are stored as regular attributes here.
def TF_InfeedEnqueueTupleOp : TF_Op<"InfeedEnqueueTuple", []> {
  let summary = [{
Feeds multiple Tensor values into the computation as an XLA tuple.
}];

  let arguments = (ins
    Arg<Variadic<TF_Tensor>, [{A list of tensors that will be provided using the infeed mechanism.}]>:$inputs,

    // At least one dtype must be given.
    ConfinedAttr<TypeArrayAttr, [ArrayMinCount<1>]>:$dtypes,
    TF_ShapeAttrArray:$shapes,
    // Defaults: an empty `layouts` array and a `device_ordinal` of -1.
    DefaultValuedAttr<I64ArrayAttr, "{}">:$layouts,
    DefaultValuedAttr<I64Attr, "-1">:$device_ordinal
  );

  // This op declares no results.
  let results = (outs);
}
// Defined by hand because the attribute name `template` is a keyword; it is
// exposed here as `strtemplate` instead.
def TF_StringFormatOp : TF_Op<"StringFormat", [NoSideEffect]> {
  let summary = "Formats a string template using a list of tensors.";

  let description = [{
Formats a string template using a list of tensors, pretty-printing tensor summaries.
}];

  let arguments = (ins
    Variadic<TF_Tensor>:$inputs,

    // Defaults: "%s" for both `strtemplate` and `placeholder`, 3 for
    // `summarize`.
    DefaultValuedStrAttr<StrAttr, "%s">:$strtemplate,
    DefaultValuedStrAttr<StrAttr, "%s">:$placeholder,
    DefaultValuedAttr<I64Attr, "3">:$summarize
  );

  let results = (outs TF_StrTensor:$output);

  TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
}
//===----------------------------------------------------------------------===//
// tf.data ops
//===----------------------------------------------------------------------===//
// Has two variadic operand groups (`initial_state`, `other_arguments`) and
// carries the SameVariadicOperandSize trait.
def TF_ReduceDatasetOp : TF_Op<"ReduceDataset", [SameVariadicOperandSize]> {
  let summary = [{
Reduces the input dataset to a singleton using a reduce function.
}];

  let arguments = (ins
    TF_VariantTensor:$input_dataset,
    Variadic<TF_Tensor>:$initial_state,
    Variadic<TF_Tensor>:$other_arguments,

    SymbolRefAttr:$f,
    // `Targuments` may be empty; the other array attributes need at least one
    // entry.
    ConfinedAttr<TypeArrayAttr, [ArrayMinCount<1>]>:$Tstate,
    ConfinedAttr<TypeArrayAttr, [ArrayMinCount<0>]>:$Targuments,
    ConfinedAttr<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
    ConfinedAttr<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
    DefaultValuedAttr<BoolAttr, "true">:$use_inter_op_parallelism
  );

  let results = (outs Variadic<TF_Tensor>:$components);
}
// Defined by hand so that the result type can be restricted to `I1Tensor`.
def TF_ToBoolOp : TF_Op<"ToBool", [DeclareOpInterfaceMethods<InferTypeOpInterface>, NoSideEffect]> {
  let summary = "Converts a tensor to a scalar predicate.";

  let description = [{
Converts a tensor to a scalar predicate with the following rules:
- For 0D tensors, truthiness is determined by comparing against a "zero"
value. For numerical types it is the obvious zero. For strings it is the
empty string.
- For >0D tensors, truthiness is determined by looking at the number of
elements. If has zero elements, then the result is false. Otherwise the
result is true.
This matches the behavior of If and While for determining if a tensor counts
as true/false for a branch condition.
}];

  let arguments = (ins TF_Tensor:$input);
  let results = (outs I1Tensor:$output);

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;

  let hasCanonicalizer = 1;

  let extraClassDeclaration = [{
    // InferTypeOpInterface: two return type lists are considered compatible
    // when ArraysAreCastCompatible accepts them.
    static bool isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
      return ArraysAreCastCompatible(lhs, rhs);
    }
  }];
}
// Element-wise float op; operand and result are constrained to the same type.
def TF_BesselI0eOp : TF_Op<"BesselI0e", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Computes the Bessel i0e function of `x` element-wise.";

  let description = [{
Exponentially scaled modified Bessel function of order 0 defined as
`bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`.
This function is faster and numerically stabler than `bessel_i0(x)`.
}];

  let arguments = (ins TF_FloatTensor:$x);
  let results = (outs TF_FloatTensor:$y);

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
// Element-wise float op; operand and result are constrained to the same type.
def TF_BesselI1eOp : TF_Op<"BesselI1e", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Computes the Bessel i1e function of `x` element-wise.";

  // The description previously said "order 0" (copied from BesselI0e); the
  // formula is written in terms of `bessel_i1`, the order-1 function.
  let description = [{
    Exponentially scaled modified Bessel function of order 1 defined as
    `bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`.
    This function is faster and numerically stabler than `bessel_i1(x)`.
  }];

  let arguments = (ins
    TF_FloatTensor:$x
  );

  let results = (outs
    TF_FloatTensor:$y
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
// Call op whose callee is the `f` symbol attribute. The SymbolUserOpInterface
// hook is declared by hand in the extra class declaration below.
def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface, SymbolUserOpInterface]> {
  let summary = "Calls a function placed on a specified TPU device.";

  let arguments = (ins
    Variadic<TF_Tensor>:$args,
    TF_Int32Tensor:$device_ordinal,

    SymbolRefAttr:$f,
    DefaultValuedAttr<I64Attr, "0">:$autotuner_thresh
  );

  let results = (outs Variadic<TF_Tensor>:$output);

  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;

  let extraClassDeclaration = [{
    // Operands passed as arguments to the called function.
    operand_range getArgOperands() { return args(); }

    // The callee of this operation, given as the `f` symbol attribute.
    CallInterfaceCallable getCallableForCallee() { return fAttr(); }

    // Looks up the function referenced by `f`.
    // Passing a SymbolTableCollection is preferred: it reduces lookup costs by
    // reusing cached symbol table lookups. With a null `table` the lookup
    // falls back to SymbolTable directly.
    func::FuncOp ResolveFunc(::mlir::SymbolTableCollection* table) {
      auto callee_sym = fAttr();
      return table
          ? table->lookupNearestSymbolFrom<func::FuncOp>(*this, callee_sym)
          : SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, callee_sym);
    }

    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp func() { return ResolveFunc(/*table=*/nullptr); }

    // SymbolUserOpInterface verifier (declared here, defined out of line).
    LogicalResult verifySymbolUses(SymbolTableCollection &symbolTable);
  }];
}
// Random-integer op whose `resource` operand is annotated with both
// TF_VariableRead and TF_VariableWrite effects. Declared with an empty trait
// list.
def TF_StatefulUniformFullIntOp : TF_Op<"StatefulUniformFullInt", []> {
let summary = "Outputs random integers from a uniform distribution.";
let description = [{
The generated values are uniform integers covering the whole range of `dtype`.
}];
// Operands: resource handle, int64 `algorithm`, and an i32/i64 `shape`.
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
TF_Int64Tensor:$algorithm,
TF_I32OrI64Tensor:$shape
);
// Result element type is restricted to 32/64-bit signed or unsigned integers.
let results = (outs
TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output
);
// Derived attributes: `shape_dtype` from operand 2 (`shape`) and `dtype` from
// result 0 (`output`).
TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
// TODO(lyandy): Investigate supported dtypes (`minval`, `maxval`, `output`) for
// `tf.StatefulUniformInt`. tf2xla kernels support i32, i64, ui32, and ui64
// while TensorFlow CPU/GPU kernels only support i32 and i64.
//
// Random-integer op bounded by `minval`/`maxval`. The `resource` operand is
// annotated with both TF_VariableRead and TF_VariableWrite effects.
def TF_StatefulUniformIntOp : TF_Op<"StatefulUniformInt", []> {
let summary = "Outputs random integers from a uniform distribution.";
let description = [{
The generated values are uniform integers in the range `[minval, maxval)`.
The lower bound `minval` is included in the range, while the upper bound
`maxval` is excluded.
The random integers are slightly biased unless `maxval - minval` is an exact
power of two. The bias is small for values of `maxval - minval` significantly
smaller than the range of the output (either `2^32` or `2^64`).
}];
// Operands: resource handle, int64 `algorithm`, i32/i64 `shape`, then the
// `minval`/`maxval` bounds, which accept the same integer types as the result.
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
TF_Int64Tensor:$algorithm,
TF_I32OrI64Tensor:$shape,
TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$minval,
TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$maxval
);
let results = (outs
TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output
);
// Derived attributes: `shape_dtype` from operand 2 (`shape`) and `dtype` from
// operand 3 (`minval`).
TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<3>;
}
// Summary-writer teardown op. Its single resource operand is annotated with
// the TF_SummaryFree effect and the op produces no results.
def TF_CloseSummaryWriterOp : TF_Op<"CloseSummaryWriter", []> {
let summary = "Flushes and closes the summary writer.";
let description = [{
Also removes it from the resource manager. To reopen, use another
CreateSummaryFileWriter op.
writer: A handle to the summary writer resource.
}];
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_SummaryFree]>:$writer
);
let results = (outs);
}
// TODO(b/168035831): Model db_uri read/write.
//
// Creates a database-backed summary writer. The `writer` resource operand is
// annotated with the TF_SummaryWrite effect; the remaining operands are string
// tensors. The op produces no results.
def TF_CreateSummaryDbWriterOp : TF_Op<"CreateSummaryDbWriter", []> {
let summary = "Creates summary database writer accessible by given resource handle.";
let description = [{
This can be used to write tensors from the execution graph directly
to a database. Only SQLite is supported right now. This function
will create the schema if it doesn't exist. Entries in the Users,
Experiments, and Runs tables will be created automatically if they
don't already exist.
writer: Handle to SummaryWriter resource to overwrite.
db_uri: For example "file:/tmp/foo.sqlite".
experiment_name: Can't contain ASCII control characters or <>. Case
sensitive. If empty, then the Run will not be associated with any
Experiment.
run_name: Can't contain ASCII control characters or <>. Case sensitive.
If empty, then each Tag will not be associated with any Run.
user_name: Must be valid as both a DNS label and Linux username. If
empty, then the Experiment will not be associated with any User.
}];
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
TF_StrTensor:$db_uri,
TF_StrTensor:$experiment_name,
TF_StrTensor:$run_name,
TF_StrTensor:$user_name
);
let results = (outs);
}
// TODO(b/168035831): Model logdir read/write.
//
// Creates a file-backed summary writer. The `writer` resource operand is
// annotated with the TF_SummaryWrite effect; `logdir` and `filename_suffix`
// are string tensors, `max_queue` and `flush_millis` are int32 tensors. The op
// produces no results.
def TF_CreateSummaryFileWriterOp : TF_Op<"CreateSummaryFileWriter", []> {
let summary = "Creates a summary file writer accessible by the given resource handle.";
let description = [{
writer: A handle to the summary writer resource
logdir: Directory where the event file will be written.
max_queue: Size of the queue of pending events and summaries.
flush_millis: How often, in milliseconds, to flush the pending events and
summaries to disk.
filename_suffix: Every event file's name is suffixed with this suffix.
}];
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
TF_StrTensor:$logdir,
TF_Int32Tensor:$max_queue,
TF_Int32Tensor:$flush_millis,
TF_StrTensor:$filename_suffix
);
let results = (outs);
}
// Flush op for a summary writer. Its single resource operand is annotated with
// the TF_SummaryWrite effect and the op produces no results.
def TF_FlushSummaryWriterOp : TF_Op<"FlushSummaryWriter", []> {
let summary = "Flushes the writer's unwritten events.";
let description = [{
writer: A handle to the summary writer resource.
}];
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer
);
let results = (outs);
}
// Takes a summary-writer resource (annotated with the TF_SummaryWrite effect)
// and a string tensor `event`; produces no results.
def TF_ImportEventOp : TF_Op<"ImportEvent", []> {
let summary = "Outputs a `tf.Event` protocol buffer.";
let description = [{
When CreateSummaryDbWriter is being used, this op can be useful for
importing data from event logs.
writer: A handle to a summary writer.
event: A string containing a binary-encoded tf.Event proto.
}];
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
TF_StrTensor:$event
);
let results = (outs);
}
// Produces a summary-writer resource handle. Takes no operands, only the
// string attributes `shared_name` and `container`; the result is annotated
// with the TF_SummaryAlloc effect. Declares the methods of
// TF_ResourceHandleAllocatorInterface (implemented outside this definition).
def TF_SummaryWriterOp : TF_Op<"SummaryWriter", [DeclareOpInterfaceMethods<TF_ResourceHandleAllocatorInterface>]> {
let summary = "Returns a handle to be used to access a summary writer.";
let description = [{
The summary writer is an in-graph resource which can be used by ops to write
summaries to event files.
writer: the summary writer resource. Scalar handle.
}];
let arguments = (ins
StrAttr:$shared_name,
StrAttr:$container
);
let results = (outs
Res<TF_ResourceTensor, "", [TF_SummaryAlloc]>:$writer
);
}
// Writes an audio summary through a summary-writer resource (annotated with
// the TF_SummaryWrite effect). Produces no results.
// NOTE(review): the description allows `tensor` to be 3-D or 2-D, while the
// per-argument line below only mentions the 2-D shape — confirm which is
// intended.
def TF_WriteAudioSummaryOp : TF_Op<"WriteAudioSummary", []> {
let summary = "Writes a `Summary` protocol buffer with audio.";
let description = [{
The summary has up to `max_outputs` summary values containing audio. The
audio is built from `tensor` which must be 3-D with shape `[batch_size,
frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are
assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`.
The `tag` argument is a scalar `Tensor` of type `string`. It is used to
build the `tag` of the summary values:
* If `max_outputs` is 1, the summary value tag is '*tag*/audio'.
* If `max_outputs` is greater than 1, the summary value tags are
generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc.
writer: A handle to a summary writer.
step: The step to write the summary for.
tag: Scalar. Used to build the `tag` attribute of the summary values.
tensor: 2-D of shape `[batch_size, frames]`.
sample_rate: The sample rate of the signal in hertz.
max_outputs: Max number of batch elements to generate audio for.
}];
// `max_outputs` defaults to 3 and is constrained to be at least 1.
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
TF_Int64Tensor:$step,
TF_StrTensor:$tag,
TF_Float32Tensor:$tensor,
TF_Float32Tensor:$sample_rate,
ConfinedAttr<DefaultValuedAttr<I64Attr, "3">, [IntMinValue<1>]>:$max_outputs
);
let results = (outs);
}
// Writes a graph summary through a summary-writer resource (annotated with the
// TF_SummaryWrite effect). Operands are the writer, an int64 `step`, and a
// string `tensor`; the op produces no results.
def TF_WriteGraphSummaryOp : TF_Op<"WriteGraphSummary", []> {
let summary = "Writes a `GraphDef` protocol buffer to a `SummaryWriter`.";
let description = [{
writer: Handle of `SummaryWriter`.
step: The step to write the summary for.
tensor: A scalar string of the serialized tf.GraphDef proto.
}];
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
TF_Int64Tensor:$step,
TF_StrTensor:$tensor
);
let results = (outs);
}
// Writes a histogram summary through a summary-writer resource (annotated with
// the TF_SummaryWrite effect). `values` accepts integer or floating-point
// tensors; the op produces no results.
def TF_WriteHistogramSummaryOp : TF_Op<"WriteHistogramSummary", []> {
let summary = "Writes a histogram summary.";
let description = [{
The generated
[`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
has one summary value containing a histogram for `values`.
This op reports an `InvalidArgument` error if any value is not finite.
writer: A handle to a summary writer.
step: The step to write the summary for.
tag: Scalar. Tag to use for the `Summary.Value`.
values: Any shape. Values to use to build the histogram.
}];
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
TF_Int64Tensor:$step,
TF_StrTensor:$tag,
TF_IntOrFpTensor:$values
);
let results = (outs);
// Derived attribute `T`, computed from the type of operand 3 (`values`).
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>;
}
// Writes an image summary through a summary-writer resource (annotated with
// the TF_SummaryWrite effect). `tensor` accepts f16, f32, or uint8 elements and
// `bad_color` is a uint8 tensor; the op produces no results.
def TF_WriteImageSummaryOp : TF_Op<"WriteImageSummary", []> {
let summary = "Writes a `Summary` protocol buffer with images.";
let description = [{
The summary has up to `max_images` summary values containing images. The
images are built from `tensor` which must be 4-D with shape `[batch_size,
height, width, channels]` and where `channels` can be:
* 1: `tensor` is interpreted as Grayscale.
* 3: `tensor` is interpreted as RGB.
* 4: `tensor` is interpreted as RGBA.
The images have the same number of channels as the input tensor. For float
input, the values are normalized one image at a time to fit in the range
`[0, 255]`. `uint8` values are unchanged. The op uses two different
normalization algorithms:
* If the input values are all positive, they are rescaled so the largest one
is 255.
* If any input value is negative, the values are shifted so input value 0.0
is at 127. They are then rescaled so that either the smallest value is 0,
or the largest one is 255.
The `tag` argument is a scalar `Tensor` of type `string`. It is used to
build the `tag` of the summary values:
* If `max_images` is 1, the summary value tag is '*tag*/image'.
* If `max_images` is greater than 1, the summary value tags are
generated sequentially as '*tag*/image/0', '*tag*/image/1', etc.
The `bad_color` argument is the color to use in the generated images for
non-finite input values. It is a `uint8` 1-D tensor of length `channels`.
Each element must be in the range `[0, 255]` (It represents the value of a
pixel in the output image). Non-finite values in the input tensor are
replaced by this tensor in the output image. The default value is the color
red.
writer: A handle to a summary writer.
step: The step to write the summary for.
tag: Scalar. Used to build the `tag` attribute of the summary values.
tensor: 4-D of shape `[batch_size, height, width, channels]` where
`channels` is 1, 3, or 4.
max_images: Max number of batch elements to generate images for.
bad_color: Color to use for pixels with non-finite values.
}];
// `max_images` defaults to 3 and is constrained to be at least 1.
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
TF_Int64Tensor:$step,
TF_StrTensor:$tag,
TensorOf<[TF_Float16, TF_Float32, TF_Uint8]>:$tensor,
TF_Uint8Tensor:$bad_color,
ConfinedAttr<DefaultValuedAttr<I64Attr, "3">, [IntMinValue<1>]>:$max_images
);
let results = (outs);
// Derived attribute `T`, computed from the type of operand 3 (`tensor`).
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>;
}
// Writes pre-serialized summaries through a summary-writer resource (annotated
// with the TF_SummaryWrite effect). Operands are the writer, an int64 `step`,
// and a string `tensor`; the op produces no results.
def TF_WriteRawProtoSummaryOp : TF_Op<"WriteRawProtoSummary", []> {
let summary = "Writes a `Summary` protocol buffer with serialized string `Summary` protocol buffers.";
let description = [{
writer: A handle to a summary writer.
step: The step to write the summary for.
tensor: A tensor holding one or more serialized `Summary` protobufs to write.
}];
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
TF_Int64Tensor:$step,
TF_StrTensor:$tensor
);
let results = (outs);
}
// Writes a scalar summary through a summary-writer resource (annotated with
// the TF_SummaryWrite effect). `value` accepts integer or floating-point
// tensors; the op produces no results.
def TF_WriteScalarSummaryOp : TF_Op<"WriteScalarSummary", []> {
let summary = "Writes a `Summary` protocol buffer with scalar values.";
let description = [{
The inputs `tag` and `value` must be scalars.
writer: A handle to a summary writer.
step: The step to write the summary for.
tag: Tag for the summary.
value: Value for the summary.
}];
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
TF_Int64Tensor:$step,
TF_StrTensor:$tag,
TF_IntOrFpTensor:$value
);
let results = (outs);
// Derived attribute `T`, computed from the type of operand 3 (`value`).
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>;
}
// Writes a tensor summary through a summary-writer resource (annotated with
// the TF_SummaryWrite effect). `tensor` accepts any TF tensor type; the op
// produces no results.
def TF_WriteSummaryOp : TF_Op<"WriteSummary", []> {
let summary = "Outputs a `Summary` protocol buffer with a tensor.";
let description = [{
writer: A handle to a summary writer.
step: The step to write the summary for.
tensor: A tensor to serialize.
tag: The summary's tag.
summary_metadata: Serialized SummaryMetadata protocol buffer containing
plugin-related metadata for this summary.
}];
// Note the operand order: `tensor` precedes `tag` here.
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
TF_Int64Tensor:$step,
TF_Tensor:$tensor,
TF_StrTensor:$tag,
TF_StrTensor:$summary_metadata
);
let results = (outs);
// Derived attribute `T`, computed from the type of operand 2 (`tensor`).
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>;
}
// Placeholder op with no operands that yields a single int64 tensor
// `device_ordinal`. Marked NoSideEffect.
def TF__TPUDeviceOrdinalPlaceholderOp : TF_Op<"_TPUDeviceOrdinalPlaceholder", [NoSideEffect]> {
let summary = [{
Placeholder device ordinal that represents device ordinal of a replicated op.
}];
let description = [{
This op can be used when certain rewrite passes materialize ops that require a
device ordinal of a replicated op but replication logic has been abstracted away
using tf_device.replicate op. Subsequent rewrite passes must replace this op with
a constant output that represents the correct device ordinal of the replicated
operations inside a TPU host.
}];
let arguments = (ins);
let results = (outs
TF_Int64Tensor:$device_ordinal
);
}
// Groups a variadic list of partitioned `inputs` into a single `output`
// tensor. Marked NoSideEffect.
def TF_TPUPartitionedInputOp : TF_Op<"TPUPartitionedInput", [NoSideEffect]> {
let summary = [{
An op that groups a list of partitioned inputs together.
}];
// `partition_dim` defaults to 0; `_XlaSharding` is an optional string
// attribute.
let arguments = (ins
Variadic<TF_Tensor>:$inputs,
DefaultValuedAttr<I64Attr, "0">:$partition_dim,
OptionalAttr<StrAttr>:$_XlaSharding
);
let results = (outs
TF_Tensor:$output
);
// Derived attributes, both computed from operand group 0 (`inputs`): `T` is
// its type and `N` is its size.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
}
// Splits a single `inputs` tensor into a variadic list of partitioned
// `output` tensors. Marked NoSideEffect.
def TF_TPUPartitionedOutputOp : TF_Op<"TPUPartitionedOutput", [NoSideEffect]> {
let summary = [{
An op that demultiplexes a tensor to be sharded by XLA to a list of partitioned outputs.
}];
let description = [{
Demultiplexes a tensor to be sharded by XLA to a list of partitioned
outputs outside the XLA computation.
}];
// `partition_dim` defaults to 0; `_XlaSharding` is an optional string
// attribute.
let arguments = (ins
TF_Tensor:$inputs,
DefaultValuedAttr<I64Attr, "0">:$partition_dim,
OptionalAttr<StrAttr>:$_XlaSharding
);
let results = (outs
Variadic<TF_Tensor>:$output
);
// Derived attributes: `T` from the type of operand 0 (`inputs`) and
// `num_splits` from the size of result group 0 (`output`).
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedResultSizeAttr num_splits = TF_DerivedResultSizeAttr<0>;
}
// Declares symbol reference attribute `shape_inference_graph` to be optional
// unlike the TensorFlow definition. This is required to support ops that use
// empty string value for the attribute to signify missing.
//
// The op carries the TF_SendSideEffect, TF_RecvSideEffect, and
// TF_XlaHostComputeSideEffect traits.
def TF_XlaHostComputeOp : TF_Op<"XlaHostCompute", [TF_SendSideEffect, TF_RecvSideEffect, TF_XlaHostComputeSideEffect]> {
let summary = [{
A pseudo-op to represent host-side computation in an XLA program.
}];
// Defaults: `send_key` and `recv_key` are empty strings, `cost_estimate_ns`
// is 1000000, and `tpu_core` is 0.
let arguments = (ins
Arg<Variadic<TF_Tensor>, [{A list of tensors that will be sent to the host.}]>:$inputs,
StrArrayAttr:$ancestors,
TF_ShapeAttrArray:$shapes,
OptionalAttr<SymbolRefAttr>:$shape_inference_graph,
StrAttr:$key,
DefaultValuedStrAttr<StrAttr, "">:$send_key,
DefaultValuedStrAttr<StrAttr, "">:$recv_key,
DefaultValuedAttr<I64Attr, "1000000">:$cost_estimate_ns,
DefaultValuedAttr<I64Attr, "0">:$tpu_core
);
let results = (outs
Res<Variadic<TF_Tensor>, [{A list of tensors that will be returned to the device.}]>:$outputs
);
// Derived type-list attributes: `Tinputs` from operand group 0 (`inputs`) and
// `Toutputs` from result group 0 (`outputs`).
TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>;
}
// TPU system initialization op with no operands and a single int32 tensor
// result. Declared with an empty trait list.
def TF_ConfigureAndInitializeGlobalTPUOp : TF_Op<"ConfigureAndInitializeGlobalTPU", []> {
let summary = [{
An op that initializes the TPU system in a multi-client setup.
}];
let description = [{
Initializes global TPU system for multi-client execution.
This op does the work of both ConfigureDistributedTpuOp and
InitializeHostForDistributedTpuOp, and outputs the latter's result.
}];
let arguments = (ins);
let results = (outs
Res<TF_Int32Tensor, [{A vector containing the global TPU id of each TPU on the host.}]>:$output
);
}
// TPU system shutdown op with no operands and a single boolean tensor result
// named `success`. Declared with an empty trait list.
def TF_ShutdownTPUSystemOp : TF_Op<"ShutdownTPUSystem", []> {
let summary = [{
An op that shuts down the TPU system.
}];
let arguments = (ins);
let results = (outs
TF_BoolTensor:$success
);
}
// Internal op for testing value-based side-effects for non-resource values.
// TODO(mgester) We should have an extension of TF dialect only for testing so
// TF dialect is not polluted with test ops.
//
// The single operand is a string tensor (not a resource tensor) annotated with
// TF_DatasetIteratorRead and TF_DatasetIteratorWrite effects; the op produces
// no results.
def TF__InternalTestNonResourceValueSideEffects_ : TF_Op<"_InternalTestNonResourceValueSideEffects_", []> {
let summary = "Internal op for testing only";
let arguments = (ins
Arg<TF_StrTensor,"", [TF_DatasetIteratorRead, TF_DatasetIteratorWrite]>:$key
);
let results = (outs);
}
// Test-only op with no operands and no results whose sole trait is
// TF_MustExecute.
def TF__InternalTestMustExecuteTrait_ : TF_Op<"_InternalTestMustExecuteTrait_", [TF_MustExecute]> {
let summary = "Internal op for testing only";
let arguments = (ins);
let results = (outs);
}
// Annotates `input` with static bounds given by the i32/i64 tensor
// `static_shape` and yields `output`. A custom verifier is enabled via
// `hasVerifier`; it is implemented outside this definition.
def TF_SetStaticDimensionBoundsOp : TF_Op<"SetStaticDimensionBounds", []> {
let summary = "Op used to indicate to the compiler and runtime the static bounds of a tensor.";
let description = [{
The information passed through this op can possibly be used by the compiler and
runtime to perform certain optimizations such as more efficient DMAs. The
bounds passed via this op should be considered advisory only, and depending on
the implementation, might do nothing and simply be an identity.
`input`: The tensor that has dynamic dimensions.
`static_shape`: The static shape of the tensor, corresponds to the maximum bounds of each dimension.
`output` is the input tensor with no changes done to it.
Example usage:
def tpu_call(args):
def model_fn(args):
# do something with dynamic tensor
@function.Defun(capture_resource_var_by_value=False)
def tpu_subgraph():
return tf.tpu.rewrite(model_fn, args)
return tf.raw_ops.TPUPartitionedCall(
args=tpu_subgraph.captured_inputs,
Tout=[o.type for o in tpu_subgraph.definition.signature.output_arg],
f=tpu_subgraph,
device_ordinal=[0])
static_shape = tf.placeholder(tf.int32, shape=([3]), name='static_size')
w = tf.Variable(tf.constant([[1.0], [2.0], [3.0]]), name='w')
w_dyn = tf.SetStaticDimensionBounds(w, static_shape)
tpu_call([w_dyn])
}];
let arguments = (ins
TF_Tensor:$input,
TF_I32OrI64Tensor:$static_shape
);
let hasVerifier = 1;
let results = (outs
TF_Tensor:$output
);
}
// Combined compile-and-execute op. It has two variadic operand groups (`args`
// and `static_shapes`), hence the AttrSizedOperandSegments trait.
def TF_TPUCompileMlirAndExecuteOp : TF_Op<"TPUCompileMlirAndExecute", [AttrSizedOperandSegments]> {
let summary = "Op that compiles a computation in MLIR into a TPU program, and loads and executes it on a TPU device.";
let description = [{
For the internal use of the TPU compiler.
'static_shapes' are tensors specifying the maximum dimension sizes for the operands listed in `operands_with_static_shape`.
'args' are inputs to the TPU computation.
'operands_with_static_shape' are the indices of the operands that have a maximal static shape specified.
'mlir_module' is a serialized MLIR module with a `main` function that contains
target computation.
'metadata' is a serialized TPUCompileMetadataProto describing the shapes and
types of the inputs to the computation, as well as a mapping onto the TPU pod
topology.
'producer_name' is a string describing the name of the framework that added support for running this portion of the model on TPUs.
}];
// `operands_with_static_shape` is optional and `mlir_module` defaults to the
// empty string; `metadata` and `producer_name` are required string attributes.
let arguments = (ins
Variadic<TF_Tensor>:$args,
Variadic<TF_Int64Tensor>:$static_shapes,
OptionalAttr<I32ArrayAttr>:$operands_with_static_shape,
DefaultValuedStrAttr<StrAttr, "">:$mlir_module,
StrAttr:$metadata,
StrAttr:$producer_name
);
let results = (outs
TF_Tensor:$rendezvous_key_base,
Variadic<TF_Tensor>:$results
);
// Derived type-list attributes: `Targs` from operand group 0 (`args`).
// NOTE(review): `Tresults` is derived from result group 0, which is
// `rendezvous_key_base` rather than the variadic `results` (group 1) —
// confirm this index is intended.
TF_DerivedOperandTypeListAttr Targs = TF_DerivedOperandTypeListAttr<0>;
TF_DerivedResultTypeListAttr Tresults = TF_DerivedResultTypeListAttr<0>;
}
#endif // TF_OPS