/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/kernels/xla_ops.h"

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/stream_executor_util.h"

// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that, in the
// error case, it returns RET instead of void.
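//
// A minimal usage sketch (the attr name here is hypothetical; see
// ConstantsVector below for a real call site). In a helper that must return a
// value, a failed status marks the context as failed and returns the supplied
// fallback instead of void:
//
//   std::vector<int32> IntListAttr(OpKernelConstruction* ctx) {
//     std::vector<int32> result;
//     OP_REQUIRES_OK_RETURN(ctx, std::vector<int32>(),
//                           ctx->GetAttr("some_int_list", &result));
//     return result;
//   }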
#define OP_REQUIRES_OK_RETURN(CTX, RET, ...)                \
  do {                                                      \
    ::tensorflow::Status _s(__VA_ARGS__);                   \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      return RET;                                           \
    }                                                       \
  } while (0)

namespace tensorflow {

namespace {

XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
  DeviceType device_type = ctx->device_type();
  se::Platform::Id platform_id = nullptr;
  const XlaDevice::Metadata* xla_device_metadata = nullptr;
  se::DeviceMemoryAllocator* custom_allocator = nullptr;

  if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
    platform_id = se::host::kHostPlatformId;
  } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
    platform_id = ctx->device()
                      ->tensorflow_gpu_device_info()
                      ->stream->parent()
                      ->platform()
                      ->id();
  } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
    // If we are on an XlaDevice, use the underlying XLA platform's allocator
    // directly. We could use the StreamExecutor's allocator, which may
    // theoretically be more correct, but XLA returns a nice OOM message in a
    // Status whereas StreamExecutor does not.
    //
    // Importantly, we can't use ctx->device()->GetAllocator() here: on an
    // XlaDevice that is a dummy allocator which returns XlaTensor objects,
    // and the XlaCompiler needs a real allocator to allocate real buffers.
    platform_id = xla_device_metadata->platform()->id();
    custom_allocator =
        xla_device_metadata->client()->backend().memory_allocator();
  }

  return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
                         custom_allocator);
}

// A closure describing how to run a compiled version of a TensorFlow function.
//
// It may seem unusual to stick the resource variable snapshots in this class.
// This is necessary: we need to use the snapshots observed by the compiler as
// the initial values for the resource variables (and cannot snapshot them
// again during execution) because otherwise we risk observing a different
// snapshot with shapes different from what we compiled for.
class XlaExecutableClosure {
 public:
  explicit XlaExecutableClosure(
      xla::LocalClient* client, xla::LocalExecutable* executable,
      const XlaCompiler::CompilationResult* compilation_result,
      std::map<int, OptionalTensor> resource_var_snapshots,
      int num_constant_args)
      : client_(client),
        executable_(executable),
        compilation_result_(compilation_result),
        resource_var_snapshots_(std::move(resource_var_snapshots)),
        num_constant_args_(num_constant_args) {}

  XlaExecutableClosure(XlaExecutableClosure&&) = default;
  XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default;

  xla::LocalClient* client() const { return client_; }
  xla::LocalExecutable* executable() const { return executable_; }
  const XlaCompiler::CompilationResult* compilation_result() const {
    return compilation_result_;
  }
  const std::map<int, OptionalTensor>& resource_var_snapshots() const {
    return resource_var_snapshots_;
  }
  int num_constant_args() const { return num_constant_args_; }

 private:
  xla::LocalClient* client_;
  xla::LocalExecutable* executable_;
  const XlaCompiler::CompilationResult* compilation_result_;
  std::map<int, OptionalTensor> resource_var_snapshots_;
  int num_constant_args_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
};

// This maintains a mapping from a globally unique ID to XlaExecutableClosure
// instances.
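//
// A sketch of the intended lifecycle, mirroring how XlaCompileOp and XlaRunOp
// below use the store: the producer hands the string key to the consumer,
// which claims the closure exactly once.
//
//   XlaExecutableClosureStore::KeyT key =
//       XlaExecutableClosureStore::Global()->Produce(std::move(closure));
//   // ... the key travels between ops as a scalar DT_STRING tensor ...
//   XlaExecutableClosure reclaimed =
//       XlaExecutableClosureStore::Global()->Consume(key);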
class XlaExecutableClosureStore {
 public:
  XlaExecutableClosureStore() : key_counter_(0) {}

  using KeyT = string;

  KeyT Produce(XlaExecutableClosure result) {
    mutex_lock l(mutex_);
    KeyT key = absl::StrCat(key_counter_++);
    bool insert_successful = closures_.emplace(key, std::move(result)).second;
    DCHECK(insert_successful);
    (void)insert_successful;
    return key;
  }

  XlaExecutableClosure Consume(const KeyT& key) {
    mutex_lock l(mutex_);
    auto it = closures_.find(key);
    DCHECK(it != closures_.end());
    XlaExecutableClosure value = std::move(it->second);
    closures_.erase(it);
    return value;
  }

  static XlaExecutableClosureStore* Global() {
    static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore;
    return instance;
  }

 private:
  mutex mutex_;
  int64 key_counter_ GUARDED_BY(mutex_);
  absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
};

// Returns the custom allocator from `platform_info` if one is set; otherwise
// populates `tf_allocator_adapter` with an adapter around the allocator from
// the context's device and returns a pointer to it.
//
// The indirection is necessary because on XLA devices the underlying TF
// allocator returns dummy tensors rather than real device memory.
se::DeviceMemoryAllocator* GetAllocator(
    absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
    OpKernelContext* ctx, const XlaPlatformInfo& platform_info) {
  if (platform_info.custom_allocator()) {
    return platform_info.custom_allocator();
  }
  if (!ctx->op_device_context()) {
    // Stream is not set for the host platform.
    se::Platform* platform =
        se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
            .ValueOrDie();
    tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform);
    return &tf_allocator_adapter->value();
  }
  // A stream is available, so adapt the device's allocator to it.
  tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}),
                                ctx->op_device_context()->stream());
  return &tf_allocator_adapter->value();
}

}  // namespace

XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
                                       const std::vector<int>& constants,
                                       const std::vector<int>& resources,
                                       const NameAttrList& function)
    : OpKernel(ctx),
      constants_(constants),
      resources_(resources),
      function_(function),
      platform_info_(PlatformInfoFromContext(ctx)) {}

static Status BuildCompilationCache(OpKernelContext* ctx,
                                    const XlaPlatformInfo& platform_info,
                                    XlaCompilationCache** cache) {
  if (platform_info.xla_device_metadata()) {
    *cache = new XlaCompilationCache(
        platform_info.xla_device_metadata()->client(),
        platform_info.xla_device_metadata()->jit_device_type());
    return Status::OK();
  }

  auto platform =
      se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
  if (!platform.ok()) {
    return platform.status();
  }

  xla::StatusOr<xla::Compiler*> compiler_for_platform =
      xla::Compiler::GetForPlatform(platform.ValueOrDie());
  if (!compiler_for_platform.ok()) {
    // In some rare cases (usually in unit tests with very small clusters) we
    // may end up transforming an XLA cluster with at least one GPU operation
    // (which would normally force the cluster to be compiled using XLA:GPU)
    // into an XLA cluster with no GPU operations (i.e. containing only CPU
    // operations). Such a cluster can fail compilation (in a way that
    // MarkForCompilation could not have detected) if the CPU JIT is not
    // linked in.
    //
    // So bail out of _XlaCompile in this case, and let the executor handle
    // the situation for us.
    const Status& status = compiler_for_platform.status();
    if (status.code() == error::NOT_FOUND) {
      return errors::Unimplemented("Could not find compiler for platform ",
                                   platform.ValueOrDie()->Name(), ": ",
                                   status.ToString());
    }
  }

  xla::LocalClientOptions client_options;
  client_options.set_platform(platform.ValueOrDie());
  client_options.set_intra_op_parallelism_threads(
      ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
  auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
  if (!client.ok()) {
    return client.status();
  }
  const XlaOpRegistry::DeviceRegistration* registration;
  if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
                                           &registration)) {
    return errors::InvalidArgument("No JIT device registered for ",
                                   platform_info.device_type().type());
  }
  *cache = new XlaCompilationCache(
      client.ValueOrDie(), DeviceType(registration->compilation_device_name));
  return Status::OK();
}

static Status CompileToLocalExecutable(
    OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
    const XlaPlatformInfo& platform_info, absl::Span<const int> resources,
    absl::Span<const int> constants, bool lazy, xla::LocalClient** client,
    std::map<int, OptionalTensor>* variables,
    const XlaCompiler::CompilationResult** kernel,
    xla::LocalExecutable** executable) {
  // We store information about the JIT-compiled XLA computation
  // in the ResourceMgr.
  ResourceMgr* rm = ctx->resource_manager();
  if (!rm) {
    return errors::Internal("No resource manager.");
  }

  XlaCompilationCache* cache;
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
      rm->default_container(), "xla_cache", &cache,
      [&](XlaCompilationCache** cache) {
        return BuildCompilationCache(ctx, platform_info, cache);
      }));
  // Hold the reference to the JIT during evaluation. (We could probably
  // free it sooner because the ResourceMgr will retain a reference, but
  // this is more obviously correct.)
  core::ScopedUnref cache_ref(cache);

  TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, resources, variables));
  *client = static_cast<xla::LocalClient*>(cache->client());

  absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
  XlaCompiler::Options options;
  options.client = *client;
  if (ctx->op_device_context() != nullptr) {
    options.device_ordinal =
        ctx->op_device_context()->stream()->parent()->device_ordinal();
  }
  options.device_type = cache->device_type();
  options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
  options.graph_def_version = ctx->function_library()->graph_def_version();
  options.allow_cpu_custom_calls =
      (platform_info.platform_id() == se::host::kHostPlatformId);
  options.device_allocator =
      GetAllocator(&tf_allocator_adapter, ctx, platform_info);
  if (platform_info.xla_device_metadata()) {
    options.shape_representation_fn =
        platform_info.xla_device_metadata()->shape_representation_fn();
  }
  // If reference variables are not present in the graph, we can safely alias
  // passthrough parameters without performing a copy.
  options.alias_passthrough_params =
      !has_ref_vars && !platform_info.is_on_xla_device();

  std::map<int, Tensor> constant_args;
  for (int i : constants) {
    constant_args.insert({i, ctx->input(i)});
  }
  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = true;
  compile_options.resolve_compile_time_constants =
      !GetXlaOpsCommonFlags().tf_xla_noresolve_compile_time_constants;
  // Optimization: where possible, have the computation return a naked array
  // rather than a one-element tuple.
  compile_options.always_return_tuple = false;

  std::vector<XlaCompiler::Argument> args;
  TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
      constant_args, *variables, ctx, &args));
  return cache->Compile(options, function, args, compile_options,
                        lazy ? XlaCompilationCache::CompileMode::kLazy
                             : XlaCompilationCache::CompileMode::kStrict,
                        kernel, executable);
}

void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
  VLOG(1) << "XlaLocalLaunchBase::Compute "
          << Canonicalize(function_.name(), AttrSlice(&function_.attr()));

  xla::LocalClient* client;
  const XlaCompiler::CompilationResult* kernel;
  xla::LocalExecutable* executable;
  std::map<int, OptionalTensor> variables;

  {
    Status s = CompileToLocalExecutable(
        ctx, function_, /*has_ref_vars=*/true, platform_info_, resources_,
        constants_, /*lazy=*/false, &client, &variables, &kernel, &executable);
    if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU ||
                    platform_info_.device_type().type_string() == DEVICE_GPU)) {
      // Suggest auto jit if the failure was with GPU or CPU.
      errors::AppendToMessage(&s,
                              xla::status_macros::kPossibleAutoJitAlternative);
    }

    OP_REQUIRES_OK(ctx, s);
  }

  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;

  VLOG(1) << "Executing XLA Computation...";

  absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
  se::DeviceMemoryAllocator* allocator =
      GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
  XlaComputationLaunchContext launch_context(
      client, allocator,
      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
      platform_info_.UseMultipleStreams());
  launch_context.PopulateInputs(ctx, kernel, variables,
                                /*missing_ctx_input_prefix=*/0);

  // Execute the computation.
  VLOG(2) << "Executing computation.";
  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(allocator);
  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
  run_options.set_rng_seed(GetXLARandomSeed());
  Env* env = Env::Default();
  auto start_time = env->NowMicros();

  xla::StatusOr<xla::ScopedShapedBuffer> run_result;
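  // Run synchronously when there is no stream or when on the host platform;
  // otherwise enqueue the computation asynchronously on the device stream.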
  if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
    run_result = executable->Run(launch_context.arguments(), run_options);
  } else {
    run_result = executable->RunAsync(launch_context.arguments(), run_options);
  }
  OP_REQUIRES(ctx, run_result.ok(), run_result.status());

  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time: " << elapsed << "us";

  const xla::HloInputOutputAliasConfig& input_output_alias =
      executable->executable()->module().input_output_alias_config();
  OP_REQUIRES_OK(
      ctx, launch_context.PopulateOutputs(
               ctx, kernel, run_result.ConsumeValueOrDie(),
               /*missing_ctx_input_prefix=*/0, input_output_alias, variables));
  VLOG(1) << "Done";
}

namespace {
// Helper functions to construct the parameters for the XlaLocalLaunchBase
// constructor from an OpKernelConstruction.
std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
  DataTypeVector constant_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Tconstants", &constant_types));
  std::vector<int> constants(constant_types.size());
  std::iota(constants.begin(), constants.end(), 0);
  return constants;
}

std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
  DataTypeVector constant_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Tconstants", &constant_types));

  DataTypeVector arg_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Targs", &arg_types));

  int num_resources;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Nresources", &num_resources));

  std::vector<int> resources(num_resources);
  std::iota(resources.begin(), resources.end(),
            constant_types.size() + arg_types.size());
  return resources;
}

NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
  const NameAttrList* func;
  OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
  return *func;
}

bool MustCompileAttr(OpKernelConstruction* ctx) {
  bool must_compile;
  OP_REQUIRES_OK_RETURN(ctx, false,
                        ctx->GetAttr("must_compile", &must_compile));
  return must_compile;
}

bool HasRefVars(OpKernelConstruction* ctx) {
  bool has_ref_vars;
  OP_REQUIRES_OK_RETURN(ctx, false,
                        ctx->GetAttr(kXlaHasReferenceVarsAttr, &has_ref_vars));
  return has_ref_vars;
}

}  // namespace

XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
    : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
                         FunctionAttr(ctx)) {}

XlaLocalLaunchOp::~XlaLocalLaunchOp() {
  VLOG(1) << "XlaLocalLaunchOp destroyed";
}

XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
    : OpKernel(ctx),
      constants_(ConstantsVector(ctx)),
      resources_(ResourcesVector(ctx)),
      function_(FunctionAttr(ctx)),
      platform_info_(PlatformInfoFromContext(ctx)),
      must_compile_(MustCompileAttr(ctx)),
      has_ref_vars_(HasRefVars(ctx)) {}

void XlaCompileOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaCompileOp " << def().name()
          << (must_compile_ ? "(must-compile)" : "");
  xla::LocalClient* client;
  const XlaCompiler::CompilationResult* kernel;
  xla::LocalExecutable* executable;
  std::map<int, OptionalTensor> variables;

  bool cannot_compile_cluster;
  {
    mutex_lock guard(cannot_compile_cluster_mu_);
    cannot_compile_cluster = cannot_compile_cluster_;
  }

  if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
      cannot_compile_cluster) {
    executable = nullptr;
  } else {
    Status status = CompileToLocalExecutable(
        ctx, function_, has_ref_vars_, platform_info_, resources_, constants_,
        /*lazy=*/!must_compile_, &client, &variables, &kernel, &executable);
    if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
      OP_REQUIRES_OK(ctx, status);
    }

    if (status.code() == error::UNIMPLEMENTED) {
      LOG(WARNING) << "Compilation failed: " << status.ToString()
                   << ". Falling back to TF function call.";

      BroadcastOptimizationRemark(
          XlaOptimizationRemark::UNIMPLEMENTED_OPERATION, status.ToString())
          .IgnoreError();
      executable = nullptr;
      mutex_lock guard(cannot_compile_cluster_mu_);
      cannot_compile_cluster_ = true;
    }
  }

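  // The op emits two host-side scalar outputs: output 0 carries the closure
  // key for a subsequent XlaRun, and output 1 reports whether compilation
  // produced an executable.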
  AllocatorAttributes host_alloc_attrs;
  host_alloc_attrs.set_gpu_compatible(true);
  host_alloc_attrs.set_on_host(true);
  Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs);

  if (!executable) {
    DCHECK(!must_compile_);
    Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));

    Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
    compilation_successful.scalar<bool>()() = false;
    ctx->set_output(0, compilation_key);
    ctx->set_output(1, compilation_successful);
    return;
  }

  // Each execution of an XlaCompile op creates a new XlaExecutableClosure,
  // even if it didn't have to compile the cluster because of a
  // compilation-cache hit. This is because we at least need new snapshots of
  // the resource variables.
  XlaExecutableClosureStore::KeyT key =
      XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
          client, executable, kernel, std::move(variables),
          constants_.size()));

  Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
  compilation_key.flat<tstring>()(0) = key;

  Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
  compilation_successful.flat<bool>()(0) = true;

  ctx->set_output(0, compilation_key);
  ctx->set_output(1, compilation_successful);
}

XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
    : OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {}

void XlaRunOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaRunOp " << def().name();
  Tensor key_tensor = ctx->input(ctx->num_inputs() - 1);
  const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<tstring>()(0);

  XlaExecutableClosure closure =
      XlaExecutableClosureStore::Global()->Consume(key);

  absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
  se::DeviceMemoryAllocator* allocator =
      GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
  XlaComputationLaunchContext launch_context(
      closure.client(), allocator,
      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
      /*use_multiple_streams=*/platform_info_.UseMultipleStreams());

  // We're missing the must-be-constant inputs, so tell `PopulateInputs`
  // about this. We don't actually need these inputs because they've already
  // been baked into the compiled kernel.
  {
    tensorflow::profiler::TraceMe hlo_module_activity(
        [&] {
          return absl::StrCat(
              "Populate Inputs (",
              closure.compilation_result()->xla_input_shapes.size(), ")");
        },
        tensorflow::profiler::TraceMeLevel::kInfo);

    launch_context.PopulateInputs(
        ctx, closure.compilation_result(), closure.resource_var_snapshots(),
        /*missing_ctx_input_prefix=*/closure.num_constant_args());
  }

  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(allocator);
  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
  run_options.set_rng_seed(GetXLARandomSeed());
  Env* env = Env::Default();
  auto start_time = env->NowMicros();

  xla::StatusOr<xla::ScopedShapedBuffer> run_result;
  if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
    run_result =
        closure.executable()->Run(launch_context.arguments(), run_options);
  } else {
    run_result = closure.executable()->RunAsync(launch_context.arguments(),
                                                run_options);
  }
  OP_REQUIRES(ctx, run_result.ok(), run_result.status());

  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time in computation: " << elapsed << "us";

  const xla::HloInputOutputAliasConfig& input_output_alias =
      closure.executable()->executable()->module().input_output_alias_config();

  tensorflow::profiler::TraceMe hlo_module_activity(
      [&] {
        return absl::StrCat("Populate Outputs (", ctx->num_outputs(), ")");
      },
      tensorflow::profiler::TraceMeLevel::kInfo);

  OP_REQUIRES_OK(
      ctx,
      launch_context.PopulateOutputs(
          ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(),
          /*missing_ctx_input_prefix=*/closure.num_constant_args(),
          input_output_alias, closure.resource_var_snapshots()));
}

XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

void XlaMergeOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaMergeOp " << def().name();
  int i = 0;
  if (ctx->has_input(i) || ctx->has_input(++i)) {
    ctx->set_output(0, ctx->input(i));
  }
}

REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);

REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
                            .Device(DEVICE_GPU)
                            .HostMemory("constants")
                            .HostMemory("resources"),
                        XlaLocalLaunchOp);

REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp);
REGISTER_KERNEL_BUILDER(Name("_XlaCompile")
                            .Device(DEVICE_GPU)
                            .HostMemory("constants")
                            .HostMemory("key")
                            .HostMemory("compilation_successful")
                            .HostMemory("resources"),
                        XlaCompileOp);

REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp);
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU).HostMemory("key"),
                        XlaRunOp);

REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_CPU), XlaMergeOp);
REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_GPU), XlaMergeOp);

}  // namespace tensorflow