/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
| |
| #include "tensorflow/compiler/xla/service/interpreter/compiler.h" |
| |
| #include <string> |
| #include <utility> |
| |
| #include "absl/memory/memory.h" |
| #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" |
| #include "tensorflow/compiler/xla/service/cholesky_expander.h" |
| #include "tensorflow/compiler/xla/service/comparison_expander.h" |
| #include "tensorflow/compiler/xla/service/computation_placer.h" |
| #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" |
| #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" |
| #include "tensorflow/compiler/xla/service/eigh_expander.h" |
| #include "tensorflow/compiler/xla/service/flatten_call_graph.h" |
| #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" |
| #include "tensorflow/compiler/xla/service/hlo_cse.h" |
| #include "tensorflow/compiler/xla/service/hlo_dce.h" |
| #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" |
| #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" |
| #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" |
| #include "tensorflow/compiler/xla/service/interpreter/executable.h" |
| #include "tensorflow/compiler/xla/service/layout_assignment.h" |
| #include "tensorflow/compiler/xla/service/map_inliner.h" |
| #include "tensorflow/compiler/xla/service/qr_expander.h" |
| #include "tensorflow/compiler/xla/service/reshape_mover.h" |
| #include "tensorflow/compiler/xla/service/triangular_solve_expander.h" |
| #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" |
| #include "tensorflow/compiler/xla/status_macros.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/platform/types.h" |
| |
| namespace xla { |
| namespace interpreter { |
| |
| namespace { |
| |
| // Handles custom_call ops during evaluation by routing them through the global |
| // CPU registry used by other CPU-based backends. |
| StatusOr<Literal> HandleEvaluatorCustomCall( |
| HloInstruction* custom_call, absl::Span<const Literal*> operands) { |
| // Find the target C function in the global registry. |
| auto* registry = CustomCallTargetRegistry::Global(); |
| void* target_fn = registry->Lookup(custom_call->custom_call_target(), "Host"); |
| if (!target_fn) { |
| return NotFound("Custom call target '%s' was not registered", |
| custom_call->custom_call_target()); |
| } |
| |
| // Populate pointers to operand and output literal data. |
| std::vector<const void*> operand_data; |
| operand_data.reserve(operands.size()); |
| for (const auto* literal : operands) { |
| operand_data.push_back(literal->untyped_data()); |
| } |
| auto output = Literal::CreateFromShape(custom_call->shape()); |
| void* output_data = output.untyped_data(); |
| |
| // Call the target function matching the C ABI used by the CPU backends. |
| auto* typed_fn = reinterpret_cast<void (*)(void*, const void**)>(target_fn); |
| (*typed_fn)(output_data, operand_data.data()); |
| |
| return std::move(output); |
| } |
| |
| } // namespace |
| |
// Runs the minimal HLO pass pipeline the interpreter needs: the expander
// passes rewrite ops (Cholesky, QR, Eigh, variadic comparisons, triangular
// solve, dynamic indices) into primitives the evaluator can execute, and
// LayoutAssignment fixes the entry computation layout. Mutates `hlo_module`
// in place; returns the pipeline's status. Pass order is significant.
Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
  HloPassPipeline pipeline("Interpreter");

  pipeline.AddPass<DynamicIndexSplitter>();
  pipeline.AddPass<CholeskyExpander>();
  pipeline.AddPass<QrExpander>();
  pipeline.AddPass<EighExpander>();
  pipeline.AddPass<ComparisonExpander>();
  pipeline.AddPass<TriangularSolveExpander>();
  pipeline.AddPass<LayoutAssignment>(
      hlo_module->mutable_entry_computation_layout());

  return pipeline.Run(hlo_module).status();
}
| |
| StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses( |
| std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* /*stream_exec*/, |
| const CompileOptions& /*options*/) { |
| VLOG(1) << "Run hlo passes on graph " << hlo_module->name(); |
| TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get())); |
| return std::move(hlo_module); |
| } |
| |
| StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend( |
| std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec, |
| const CompileOptions& /*options*/) { |
| TF_RET_CHECK(stream_exec != nullptr); |
| |
| VLOG(1) << "Run backend " << hlo_module->name(); |
| |
| TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference, |
| DynamicDimensionInference::Run(hlo_module.get())); |
| |
| auto evaluator = absl::make_unique<HloEvaluator>(); |
| evaluator->set_use_fast_path( |
| hlo_module->config().debug_options().xla_hlo_evaluator_use_fast_path()); |
| evaluator->set_custom_call_handler(HandleEvaluatorCustomCall); |
| |
| // Create executable from only the Hlo module. |
| std::unique_ptr<Executable> executable = |
| absl::make_unique<InterpreterExecutable>( |
| std::move(hlo_module), std::move(evaluator), |
| std::move(dynamic_dimension_inference)); |
| |
| return std::move(executable); |
| } |
| |
| StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile( |
| std::unique_ptr<HloModuleGroup> module_group, |
| std::vector<std::vector<se::StreamExecutor*>> stream_exec, |
| const CompileOptions& options) { |
| if (module_group->empty()) { |
| return std::vector<std::unique_ptr<Executable>>(); |
| } |
| if (module_group->size() > 1) { |
| return tensorflow::errors::Unimplemented( |
| "Compilation of multiple HLO modules is not supported on Interpreter."); |
| } |
| if (stream_exec.size() != 1 || stream_exec[0].size() != 1) { |
| return tensorflow::errors::Unimplemented( |
| "Unexpected number of StreamExecutor's."); |
| } |
| auto hlo_modules = module_group->ConsumeModules(); |
| TF_ASSIGN_OR_RETURN(auto module, RunHloPasses(std::move(hlo_modules[0]), |
| stream_exec[0][0], options)); |
| TF_ASSIGN_OR_RETURN(auto executable, RunBackend(std::move(module), |
| stream_exec[0][0], options)); |
| std::vector<std::unique_ptr<Executable>> ret; |
| ret.push_back(std::move(executable)); |
| return std::move(ret); |
| } |
| |
| StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> |
| InterpreterCompiler::CompileAheadOfTime( |
| std::unique_ptr<HloModuleGroup> module_group, |
| const AotCompilationOptions& aot_options) { |
| return tensorflow::errors::InvalidArgument( |
| "AOT compilation not supported on Interpreter"); |
| } |
| |
// Identifies the platform this compiler serves: the XLA interpreter platform.
se::Platform::Id InterpreterCompiler::PlatformId() const {
  return se::interpreter::kXlaInterpreterPlatformId;
}
| |
// Returns the shape-size function used for cost analysis; delegates to the
// interpreter executable's notion of how many bytes a shape occupies.
HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction()
    const {
  return InterpreterExecutable::ShapeSizeBytes;
}
| |
| static bool InitModule() { |
| xla::Compiler::RegisterCompilerFactory( |
| se::interpreter::kXlaInterpreterPlatformId, []() { |
| return absl::make_unique<xla::interpreter::InterpreterCompiler>(); |
| }); |
| xla::ComputationPlacer::RegisterComputationPlacer( |
| se::interpreter::kXlaInterpreterPlatformId, |
| []() { return absl::make_unique<xla::ComputationPlacer>(); }); |
| return true; |
| } |
| |
| static bool module_initialized = InitModule(); |
| |
| } // namespace interpreter |
| } // namespace xla |