blob: ce4747eee0da6128efad022223356e14517454de [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_JITRT_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_JITRT_PASSES_H_
#include <cstdint>
#include <memory>
#include <string>

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
namespace tensorflow {
// Pass for trivial buffer forwarding for the linalg.generic operations.
std::unique_ptr<mlir::FunctionPass> CreateLinalgTrivialBufferForwardingPass();
// Pass for trivial copy removal of memref.copy operations.
std::unique_ptr<mlir::FunctionPass> CreateLinalgTrivialCopyRemovalPass();
// Pass to optimize padding in tiled loops by peeling the final loop iteration.
std::unique_ptr<mlir::FunctionPass> CreatePeelTiledLoopsPass();
// Pass to tile and fuse linalg.generic on tensors that models reduction.
std::unique_ptr<mlir::FunctionPass> CreateTileReductionPass();
std::unique_ptr<mlir::FunctionPass> CreateTileReductionPass(
int64_t reduction_vector_size, int64_t reduction_1d_tile_size,
llvm::ArrayRef<int64_t> reduction_2d_tile_sizes);
// Pass to fuse `linalg.fill` into a tiled reduction.
std::unique_ptr<mlir::FunctionPass> CreateFuseFillIntoTiledReductionPass();
// Pass to replace 'i1' tensor types with 'i8' tensor types. This pass is a
// temporary workaround to avoid the problem of vectorizing 'i1' tensors (see
// b/205714705).
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateJitRtLegalizeI1TypesPass();
// Pass to vectorize linalg ops.
std::unique_ptr<mlir::FunctionPass> CreateVectorizeTiledOpsPass();
// Rewrite `vector.multi_reduction` into a sequence of `vector.reduction` ops.
std::unique_ptr<mlir::FunctionPass> createRewriteVectorMultiReductionPass();
// Pass to tile elementwise ops on tensors.
std::unique_ptr<mlir::FunctionPass> CreateTileCWisePass();
std::unique_ptr<mlir::FunctionPass> CreateTileCWisePass(
int64_t cwise_tile_size);
// Pass to split _Fused Tensorflow kernels into primitives.
std::unique_ptr<mlir::FunctionPass> CreateFissionPass();
// Pass to fuse Linalg generic operations on Tensors.
std::unique_ptr<mlir::FunctionPass> CreateFusionPass();
// Pass to optimize broadcasts based on the symbolic shape constraints.
std::unique_ptr<mlir::FunctionPass> CreateSymbolicShapeOptimizationPass(
bool constraints_only = false);
// Pass to replace 0-d tensor inputs to LinalgOp with extracted elements.
std::unique_ptr<mlir::FunctionPass> CreateDetensorizeLinalgPass();
// Creates `tf_device.cluster` operations according to the TF JitRt clustering
// policy.
std::unique_ptr<mlir::FunctionPass> CreateTfJitRtClusteringPass();
std::unique_ptr<mlir::FunctionPass> CreateTfJitRtClusteringPass(
llvm::ArrayRef<std::string> oplist, int min_cluster_size);
// Pass to replace math ops with approximations.
std::unique_ptr<mlir::FunctionPass> CreateMathApproximationPass(
llvm::ArrayRef<std::string> oplist = {});
// Returns true if the `value` type is a memref that is contiguous in memory.
bool IsContiguousMemref(mlir::Value value);
// Detects the combiner in the body of LinalgOp if any. Currently, only
// ops with a single combiner are supported.
mlir::FailureOr<mlir::Operation *> DetectCombiner(
mlir::linalg::LinalgOp linalg_op);
} // namespace tensorflow
#define GEN_PASS_REGISTRATION
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"
#endif // TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_JITRT_PASSES_H_