blob: 0156d7e7e9dbb7eca011d6e4188643ca030f691d [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the operations for the "Control Flow" dialect of TensorFlow
// graphs. The TensorFlow control flow dialect represents control flow with
// Switch/Merge and a few related control flow nodes, along with control
// dependencies. This dialect can be raised to the standard TensorFlow dialect
// by transforming Switch/Merge and other control flow ops into functional
// control flow ops and removing control dependencies.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_CONTROL_FLOW_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_CONTROL_FLOW_OPS_H_
#include "mlir/IR/Dialect.h" // TF:llvm-project
#include "mlir/IR/OpDefinition.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/Interfaces/SideEffects.h" // TF:llvm-project
namespace mlir {
namespace TFControlFlow {
// Dialect object for the TensorFlow control-flow ops and types declared in
// this header. The constructor and the type hooks below are only declared
// here; their definitions live out of line.
class TFControlFlowDialect : public Dialect {
 public:
  explicit TFControlFlowDialect(MLIRContext *context);

  // Writes the textual form of `type`, a type owned by this dialect, to
  // `printer`.
  void printType(Type type, DialectAsmPrinter &printer) const override;

  // Reads a type owned by this dialect from `parser` and returns it.
  Type parseType(DialectAsmParser &parser) const override;
};
namespace TensorFlowControlTypes {
// Type-kind values used by this dialect. They start at
// Type::FIRST_TENSORFLOW_CONTROL_TYPE.
enum Kind {
  // Kind reported by TFControlType.
  Control = Type::FIRST_TENSORFLOW_CONTROL_TYPE
};
}  // namespace TensorFlowControlTypes
// Type of the extra control result carried by the ops in this header (shown as
// `!_tf.control` in the examples). Base::get is called with only the kind, so
// the type has no parameters.
class TFControlType : public Type::TypeBase<TFControlType, Type> {
 public:
  using Base::Base;

  // LLVM-style casting hook: a type is a TFControlType exactly when its kind
  // equals TensorFlowControlTypes::Control.
  static bool kindof(unsigned kind) {
    return TensorFlowControlTypes::Control == kind;
  }

  // Returns the control type belonging to `context`.
  static TFControlType get(MLIRContext *context) {
    return Base::get(context, TensorFlowControlTypes::Control);
  }
};
// The "_tf.Enter" operation forwards its input into a TensorFlow while loop.
// Each tensor needs its own _tf.Enter to be made available inside the while
// loop.
//
// More details can be found in the TensorFlow control flow white paper:
// https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
//
// This is defined in TensorFlow as:
//
// REGISTER_OP("Enter")
// .Input("data: T")
// .Output("output: T")
// .Attr("T: type")
// .Attr("frame_name: string")
// .Attr("is_constant: bool = false")
// .Attr("parallel_iterations: int = 10")
//
// For example:
//   %1 = "_tf.Enter"(%0#0) {T: "tfdtype$DT_INT32", frame_name:
//   "while/while_context"} : (tensor<i32>) -> (tensor<*xi32>, !_tf.control)
//
// Note: Additional result corresponds to the control output.
class EnterOp
    : public Op<EnterOp, OpTrait::AtLeastNOperands<1>::Impl,
                OpTrait::NResults<2>::Impl, MemoryEffectOpInterface::Trait> {
 public:
  using Op::Op;

  static StringRef getOperationName() { return "_tf.Enter"; }

  // Checks the op's invariants; defined out of line.
  LogicalResult verify();

  // Operand #0 is the tensor forwarded into the loop. AtLeastNOperands<1>
  // permits extra trailing operands (presumably control inputs - confirm).
  Value getData() { return getOperand(0); }
  void setData(Value newData) { setOperand(0, newData); }

  // MemoryEffectOpInterface hook. It reports no effects, so EnterOp is
  // treated as side-effect free.
  void getEffects(
      SmallVectorImpl<MemoryEffects::EffectInstance> & /*effects*/) {}
};
// The "_tf.Merge" operation takes a list of input operands and returns the
// value of the first input that becomes available, along with that input's
// index in the operand list.
//
// More details can be found in the TensorFlow control flow white paper:
// https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
//
// This is defined in TensorFlow as:
//
// REGISTER_OP("Merge")
// .Input("inputs: N * T")
// .Output("output: T")
// .Output("value_index: int32")
//
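// For example:
//   %4 = "_tf.Merge"(%0, %1, %2, %3) : (tensor<?x?xf32>, tensor<?x?xf32>,
//   tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<i32>,
//   !_tf.control)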
//
// Note: Additional result corresponds to the control output.
// Results: #0 is the forwarded value ("output"), #1 is its index among the
// operands ("value_index"), #2 is the control output.
class MergeOp
    : public Op<MergeOp, OpTrait::VariadicOperands,
                OpTrait::NResults<3>::Impl> {
 public:
  using Op::Op;

  // Checks the op's invariants; defined out of line.
  LogicalResult verify();

  static StringRef getOperationName() { return "_tf.Merge"; }
};
// The "_tf.NextIteration.source" and "_tf.NextIteration.sink" operations form
// a logical pair. Together, they represent the NextIteration op in TensorFlow.
//
// The TensorFlow NextIteration operation forwards its input to the next
// iteration of a while loop. Each loop variable needs its own NextIteration op.
//
// More details can be found in the TensorFlow control flow white paper:
// https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
//
// The NextIteration op is split into _tf.NextIteration.sink and
// _tf.NextIteration.source because NextIteration is a back-edge in the
// TensorFlow graph, which would form a data flow cycle if expressed naively in
// a basic block. _tf.NextIteration.source takes no input but returns results,
// while _tf.NextIteration.sink takes an input but produces no data result (its
// single result is a control output). When optimizing these ops, they are
// paired by op names and considered as a single op.
//
// This is defined in TensorFlow as:
//
// REGISTER_OP("NextIteration")
// .Input("data: T")
// .Output("output: T")
// .Attr("T: type")
//
// For example:
//   %11 = "_tf.NextIteration.source"() {name: "while/NextIteration", T:
//   "tfdtype$DT_INT32", id: 0} : () -> (tensor<*xi32>, !_tf.control)
//   %12 = "_tf.NextIteration.sink"(%10#0) {name: "while/NextIteration", T:
//   "tfdtype$DT_INT32", id: 0} : (tensor<*xi32>) -> !_tf.control
//
// Note: Additional result corresponds to the control output.
// Source half of a NextIteration pair. It declares two results: #0 is the
// value carried into the next iteration, #1 is the control output. No operand
// trait is listed here; any operand check would live in verify().
class NextIterationSourceOp
    : public Op<NextIterationSourceOp, OpTrait::NResults<2>::Impl> {
 public:
  using Op::Op;

  // Checks the op's invariants; defined out of line.
  LogicalResult verify();

  static StringRef getOperationName() { return "_tf.NextIteration.source"; }
};
// Sink half of a NextIteration pair. Operand #0 is the value sent to the next
// iteration. The OneResult trait requires exactly one result, which carries no
// data (presumably the control output - confirm against the .cc verifier).
class NextIterationSinkOp
    : public Op<NextIterationSinkOp, OpTrait::AtLeastNOperands<1>::Impl,
                OpTrait::OneResult> {
 public:
  using Op::Op;

  static StringRef getOperationName() { return "_tf.NextIteration.sink"; }

  // Checks the op's invariants; defined out of line.
  LogicalResult verify();

  // Accessors for operand #0, the value fed back into the loop.
  Value getData() { return getOperand(0); }
  void setData(Value newData) { setOperand(0, newData); }
};
// The "_tf.LoopCond" operation forwards a boolean value as the loop condition
// of TensorFlow while loops.
//
// More details can be found in the TensorFlow control flow white paper:
// https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
//
// This is defined in Tensorflow as:
//
// REGISTER_OP("LoopCond")
// .Input("input: bool")
// .Output("output: bool")
//
// For example:
//   %5 = "_tf.LoopCond"(%4#0) {device: "", name: "while/LoopCond"} :
//   (tensor<*xi1>) -> (tensor<*xi1>, !_tf.control)
//
// Note: Additional result corresponds to the control output.
class LoopCondOp
    : public Op<LoopCondOp, OpTrait::AtLeastNOperands<1>::Impl,
                OpTrait::NResults<2>::Impl, MemoryEffectOpInterface::Trait> {
 public:
  using Op::Op;

  static StringRef getOperationName() { return "_tf.LoopCond"; }

  // Checks the op's invariants; defined out of line.
  LogicalResult verify();

  // Operand #0 is the boolean loop-condition value being forwarded.
  Value getData() { return getOperand(0); }
  void setData(Value newData) { setOperand(0, newData); }

  // MemoryEffectOpInterface hook. It reports no effects, so LoopCondOp is
  // treated as side-effect free.
  void getEffects(
      SmallVectorImpl<MemoryEffects::EffectInstance> & /*effects*/) {}
};
// The "_tf.Switch" operation takes a data operand and a boolean predicate
// condition, and returns two values matching the type of the data operand.
//
// More details can be found in the TensorFlow control flow white paper:
// https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
//
// This is defined in TensorFlow as:
//
// REGISTER_OP("Switch")
// .Input("data: T")
// .Input("pred: bool")
// .Output("output_false: T")
// .Output("output_true: T")
//
// For example:
//   %2 = "_tf.Switch"(%0, %1) : (tensor<?x?xf32>, tensor<i1>) ->
//   (tensor<?x?xf32>, tensor<?x?xf32>, !_tf.control)
//
// Note: Additional result corresponds to the control output.
// Operands: #0 is the data, #1 is the boolean predicate. Results: #0 is
// "output_false", #1 is "output_true", #2 is the control output.
class SwitchOp
    : public Op<SwitchOp, OpTrait::AtLeastNOperands<2>::Impl,
                OpTrait::NResults<3>::Impl> {
 public:
  using Op::Op;

  static StringRef getOperationName() { return "_tf.Switch"; }

  // Checks the op's invariants; defined out of line.
  LogicalResult verify();

  // Accessors for operand #1, the predicate that selects the live output.
  Value getPredicate() { return getOperand(1); }
  void setPredicate(Value newPredicate) { setOperand(1, newPredicate); }

  // Accessors for operand #0, the value routed to one of the two outputs.
  Value getData() { return getOperand(0); }
  void setData(Value newData) { setOperand(0, newData); }
};
// The "_tf.Exit" operation forwards a value from a while loop to its consumer
// outside of the loop. Each returned tensor needs its own _tf.Exit.
//
// More details can be found in the TensorFlow control flow white paper:
// https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
//
// This is defined in TensorFlow as:
//
// REGISTER_OP("Exit")
// .Input("data: T")
// .Output("output: T")
// .Attr("T: type")
//
// For example:
//   %1 = "_tf.Exit"(%0#0) {T: "tfdtype$DT_INT32"} : (tensor<*xi32>) ->
//   (tensor<*xi32>, !_tf.control)
//
// Note: Additional result corresponds to the control output.
class ExitOp
    : public Op<ExitOp, OpTrait::AtLeastNOperands<1>::Impl,
                OpTrait::NResults<2>::Impl, MemoryEffectOpInterface::Trait> {
 public:
  using Op::Op;

  static StringRef getOperationName() { return "_tf.Exit"; }

  // Checks the op's invariants; defined out of line.
  LogicalResult verify();

  // Operand #0 is the value leaving the loop.
  Value getData() { return getOperand(0); }
  void setData(Value newData) { setOperand(0, newData); }

  // MemoryEffectOpInterface hook. It reports no effects, so ExitOp is treated
  // as side-effect free.
  void getEffects(
      SmallVectorImpl<MemoryEffects::EffectInstance> & /*effects*/) {}
};
} // namespace TFControlFlow
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_CONTROL_FLOW_OPS_H_