| /* Copyright 2019 Google Inc. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" |
| #include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h" |
| #include "mlir/IR/Dialect.h" // from @llvm-project |
| #include "mlir/InitAllDialects.h" // from @llvm-project |
| #include "mlir/InitAllPasses.h" // from @llvm-project |
| #include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/init_mlir.h" |
| #include "tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_gpu.h" |
| #include "tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_gpu_binary.h" |
| #include "tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.h" |
| #include "tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_tfrt_gpu.h" |
| #include "tfrt/gpu/kernels/gpu_ops.h" // from @tf_runtime |
| #include "tfrt/gpu/passes/passes.h" // from @tf_runtime |
| #include "tfrt/init_tfrt_dialects.h" // from @tf_runtime |
| |
| // Entry point of the MHLO/TFRT opt-style pass driver. Registers the dialects |
| // and passes listed below, then forwards the command line to |
| // mlir::MlirOptMain. Returns 0 if MlirOptMain succeeds and 1 if it fails. |
| int main(int argc, char **argv) { |
|   // InitMlir is handed pointers to argc/argv and stays alive until main exits. |
|   tensorflow::InitMlir init_guard(&argc, &argv); |
| |
|   // Dialects: all upstream MLIR dialects first, then the (L)MHLO and TFRT GPU |
|   // dialects, and finally whatever tfrt::RegisterTFRTDialects adds. |
|   mlir::DialectRegistry dialects; |
|   mlir::registerAllDialects(dialects); |
|   dialects.insert<mlir::lmhlo::LmhloDialect>(); |
|   dialects.insert<mlir::lmhlo_gpu::LmhloGpuDialect>(); |
|   dialects.insert<mlir::mhlo::MhloDialect>(); |
|   dialects.insert<tfrt::gpu::GpuDialect>(); |
|   tfrt::RegisterTFRTDialects(dialects); |
| |
|   // Passes: upstream MLIR passes followed by the LMHLO lowering passes and the |
|   // TFRT GPU passes, in the same order as before. |
|   mlir::registerAllPasses(); |
|   tensorflow::registerConvertLmhloToGpuBranchPass(); |
|   tensorflow::registerConvertLmhloToGpuBinaryPass(); |
|   tensorflow::registerConvertLmhloToGpuPass(); |
|   tensorflow::registerLmhloToTfrtGpuPass(); |
|   tfrt::gpu::RegisterPasses(); |
|   tensorflow::registerLmhloToJitRtPasses(); |
| |
|   // Run the generic opt driver and turn its result into a process exit code. |
|   const auto status = |
|       mlir::MlirOptMain(argc, argv, "MHLO TFRT pass driver\n", dialects); |
|   return mlir::failed(status) ? 1 : 0; |
| } |