| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/mlir/xla/hlo_module_importer.h" |
| |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project |
| #include "mlir/IR/Attributes.h" // from @llvm-project |
| #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
| #include "mlir/IR/Location.h" // from @llvm-project |
| #include "mlir/IR/OperationSupport.h" // from @llvm-project |
| #include "mlir/IR/Types.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" |
| #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" |
| #include "tensorflow/compiler/xla/service/hlo_computation.h" |
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" |
| #include "tensorflow/compiler/xla/service/hlo_module.h" |
| #include "tensorflow/compiler/xla/xla.pb.h" |
| |
| namespace xla { |
| |
| HloModuleImporter::HloModuleImporter(mlir::ModuleOp module) |
| : module_(module), builder_(module.getContext()) { |
| module.getContext()->loadDialect<mlir::StandardOpsDialect>(); |
| module.getContext()->loadDialect<mlir::mhlo::MhloDialect>(); |
| } |
| |
| Status HloModuleImporter::Import(const xla::HloModule& module) { |
| // TODO(hinsu): Only import the entry computation here once all HLO ops with |
| // reference to other computation are updated to have a region instead of a |
| // function attribute. Currently the importer test doesn't refer to all the |
| // computations from the entry computation so tests may need some update. |
| for (const auto* computation : module.computations()) { |
| TF_RETURN_IF_ERROR(HloFunctionImporter::ImportAsFunc( |
| *computation, module_, &function_map_, &builder_)); |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status HloModuleImporter::Import(const xla::HloModuleProto& module_proto) { |
| xla::DebugOptions debug_options; |
| TF_ASSIGN_OR_RETURN( |
| auto module_config, |
| xla::HloModule::CreateModuleConfigFromProto(module_proto, debug_options)); |
| TF_ASSIGN_OR_RETURN(auto module, xla::HloModule::CreateFromProto( |
| module_proto, module_config)); |
| |
| return Import(*module); |
| } |
| |
| } // namespace xla |