// blob: 46abae27878203fcc5f48af8189eb4960c0925e1 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
// Transformation that converts tf.while_loop() loops into functional While
// operators and tf.cond() conditionals into functional If operators, suitable
// for XLA compilation.
//
// Parameters:
//   graph: the graph whose loops and conditionals are functionalized.
//     NOTE(review): taken as a mutable pointer, so presumably rewritten in
//     place -- confirm against the implementation.
//   library: the function library associated with `graph`.
//     NOTE(review): taken as a mutable pointer, so presumably updated with the
//     functions generated for the While/If operators -- confirm against the
//     implementation.
//   node_filter: optional, defaults to `{}`. If `node_filter` is defined, then
//     only loops and conditionals for whose nodes `node_filter` returns true
//     are functionalized.
//   include_functions: defaults to false. If true, then loops and conditionals
//     inside of functions that are associated with nodes in `graph` (e.g., a
//     function called from a node in `graph`) are also functionalized;
//     otherwise they are not. This also handles transitive cases, e.g., a
//     function body will be functionalized when it is called in another
//     function that is called by some node in `graph` (and so on). The node
//     filter also applies here.
//
// Returns a Status describing the outcome of the transformation.
//
// Precondition:
// For any node in a loop or conditional for which `node_filter` returns true,
// all nodes inside of the same loop or conditional must also return true
// (including nodes in other nested loops and conditionals inside of that loop
// or conditional).
// This means that a "not to be functionalized" loop or conditional is not
// allowed inside a "to be functionalized" loop or conditional.
//
// The user of this function is responsible for using a node filter that
// satisfies the above conditions.
Status FunctionalizeControlFlow(Graph* graph,
FunctionLibraryDefinition* library,
const NodeFilter& node_filter = {},
bool include_functions = false);
// GraphDef counterpart of FunctionalizeControlFlow: takes a `GraphDef*` in
// place of a `Graph*`, together with the same `library`, `node_filter`
// (default `{}`) and `include_functions` (default false) parameters, and
// returns a Status.
// NOTE(review): presumably applies the same transformation and writes the
// result back into `graph_def`, with the same node-filter precondition as
// FunctionalizeControlFlow -- confirm against the implementation.
Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def,
FunctionLibraryDefinition* library,
const NodeFilter& node_filter = {},
bool include_functions = false);
// Graph optimization pass that looks at the graph and turns V1 control flow
// structure (Switch/Merge/etc.) into V2 control flow structure (If/While).
class FunctionalizeControlFlowPass : public GraphOptimizationPass {
public:
// Overrides GraphOptimizationPass::Run. Receives the pass options and returns
// a Status describing the outcome.
Status Run(const GraphOptimizationPassOptions& options) override;
};
// Same as FunctionalizeControlFlowPass (V1 Switch/Merge control flow is turned
// into V2 If/While), but only modifies functions that will be executed by XLA.
class FunctionalizeControlFlowForXlaPass : public GraphOptimizationPass {
public:
// Overrides GraphOptimizationPass::Run. Receives the pass options and returns
// a Status describing the outcome.
Status Run(const GraphOptimizationPassOptions& options) override;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_