blob: 6fe1f4d2090682781ee7f6e0b9cd94de9189e268 [file] [log] [blame]
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op_kernel.h"
#include <cstdlib>
#include <cstring>
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/kernel_def.pb_text.h"
#include "tensorflow/core/framework/kernel_def_util.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/platform_strings.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace {
Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
const DataTypeSlice expected_outputs,
const DataTypeSlice inputs,
const DataTypeSlice outputs) {
bool signature_mismatch = false;
if (inputs.size() != expected_inputs.size()) signature_mismatch = true;
for (size_t i = 0; !signature_mismatch && i < inputs.size(); ++i) {
if (!TypesCompatible(expected_inputs[i], inputs[i])) {
signature_mismatch = true;
}
}
if (outputs.size() != expected_outputs.size()) signature_mismatch = true;
for (size_t i = 0; !signature_mismatch && i < outputs.size(); ++i) {
if (!TypesCompatible(expected_outputs[i], outputs[i])) {
signature_mismatch = true;
}
}
if (signature_mismatch) {
return errors::InvalidArgument(
"Signature mismatch, have: ", DataTypeSliceString(inputs), "->",
DataTypeSliceString(outputs),
" expected: ", DataTypeSliceString(expected_inputs), "->",
DataTypeSliceString(expected_outputs));
}
return Status::OK();
}
} // namespace
// OpKernel ------------------------------------------------------------------
OpKernel::OpKernel(OpKernelConstruction* context)
: OpKernel(context, MakeUnique<const NodeDef>(context->def())) {}
OpKernel::OpKernel(OpKernelConstruction* context,
std::unique_ptr<const NodeDef> node_def)
: def_(std::move(node_def)),
input_types_(context->input_types().begin(),
context->input_types().end()),
input_memory_types_(context->input_memory_types().begin(),
context->input_memory_types().end()),
output_types_(context->output_types().begin(),
context->output_types().end()),
output_memory_types_(context->output_memory_types().begin(),
context->output_memory_types().end()),
input_name_map_(context->num_inputs()),
output_name_map_(context->num_outputs()),
graph_def_version_(context->graph_def_version()),
cost_estimate_(OpKernel::kInitialCostEstimateCycles) {
OP_REQUIRES_OK(context,
NameRangesForNode(*def_, *context->op_def_, &input_name_map_,
&output_name_map_));
OP_REQUIRES_OK(context, CheckOpDeprecation(*context->op_def_,
context->graph_def_version()));
// Kernels executing on GPU/SYCL tie very few resources on the CPU where the
// scheduler runs: we consider them as inexpensive.
expensive_ = context->device_type() != DeviceType(DEVICE_GPU) &&
context->device_type() != DeviceType(DEVICE_SYCL);
}
OpKernel::~OpKernel() {}
const uint64 OpKernel::kInitialCostEstimateCycles;
const uint64 OpKernel::kOpIsExpensiveThresholdCycles;
const uint64 OpKernel::kCostDecay;
const string& OpKernel::name() const { return def_->name(); }
const string& OpKernel::type_string() const { return def_->op(); }
const string& OpKernel::requested_device() const { return def_->device(); }
const string& OpKernel::requested_input(int i) const { return def_->input(i); }
// This static function exists only because device_attributes.pb.h is
// already included here, and it can't be introduced elsewhere.
/*static*/ int OpKernel::DeviceNumaNode(const DeviceBase* device) {
return device->attributes().locality().numa_node();
}
Status OpKernel::InputRange(StringPiece input_name, int* start,
int* stop) const {
const auto result = input_name_map_.find(input_name);
if (result == input_name_map_.end()) {
return errors::InvalidArgument("Unknown input name: ", input_name);
} else {
*start = result->second.first;
*stop = result->second.second;
return Status::OK();
}
}
Status OpKernel::OutputRange(StringPiece output_name, int* start,
int* stop) const {
const auto result = output_name_map_.find(output_name);
if (result == output_name_map_.end()) {
return errors::InvalidArgument("Unknown output name: ", output_name);
} else {
*start = result->second.first;
*stop = result->second.second;
return Status::OK();
}
}
Status OpKernel::MakeShape(const Tensor& shape, TensorShape* out) const {
if (!IsLegacyVector(shape.shape())) {
return errors::InvalidArgument(
"shape must be a vector of {int32,int64}, got shape ",
shape.shape().DebugString());
}
if (shape.dtype() == DataType::DT_INT32) {
auto vec = shape.flat<int32>();
return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
} else if (shape.dtype() == DataType::DT_INT64) {
auto vec = shape.flat<int64>();
return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
} else {
return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
}
}
void AsyncOpKernel::Compute(OpKernelContext* context) {
Notification n;
ComputeAsync(context, [&n]() { n.Notify(); });
n.WaitForNotification();
}
// PersistentTensor ----------------------------------------------------------
Tensor* PersistentTensor::AccessTensor(OpKernelConstruction* context) {
// the caller has to have a valid context
CHECK(context);
return &tensor_;
}
Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) {
context->NotifyUseOfPersistentTensor(tensor_);
return &tensor_;
}
// OpKernelConstruction ------------------------------------------------------
OpKernelConstruction::OpKernelConstruction(
DeviceType device_type, DeviceBase* device, Allocator* allocator,
const NodeDef* node_def, const OpDef* op_def, FunctionLibraryRuntime* flib,
const DataTypeSlice& input_types, const MemoryTypeSlice& input_memory_types,
const DataTypeSlice& output_types,
const MemoryTypeSlice& output_memory_types, int graph_def_version,
Status* status)
: device_type_(std::move(device_type)),
device_(device),
allocator_(allocator),
def_(node_def),
op_def_(op_def),
flib_(flib),
input_types_(input_types),
input_memory_types_(input_memory_types),
output_types_(output_types),
output_memory_types_(output_memory_types),
graph_def_version_(graph_def_version),
status_(status) {}
bool OpKernelConstruction::HasAttr(StringPiece attr_name) const {
return HasNodeAttr(def(), attr_name);
}
void OpKernelConstruction::SetStatus(const Status& status) {
status_->Update(status);
}
Status OpKernelConstruction::MatchSignature(
const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) {
return MatchSignatureHelper(expected_inputs, expected_outputs, input_types_,
output_types_);
}
Status OpKernelConstruction::allocate_temp(DataType type,
const TensorShape& shape,
Tensor* out_temp) {
AllocationAttributes attr;
attr.allocation_will_be_logged = true;
Tensor new_temp(allocator_, type, shape, attr);
if (!new_temp.IsInitialized()) {
return errors::ResourceExhausted(
"OOM when allocating temporary tensor with shape", shape.DebugString());
}
if (LogMemory::IsEnabled()) {
LogMemory::RecordTensorAllocation(
def_->name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
}
*out_temp = new_temp;
return Status::OK();
}
Status OpKernelConstruction::allocate_persistent(
DataType type, const TensorShape& shape, PersistentTensor* out_persistent,
Tensor** out_tensor) {
// for now just do the same thing as allocate_temp
// TODO(misard) add specific memory tracking for persistent tensors
Tensor persistent;
Status s = allocate_temp(type, shape, &persistent);
if (!s.ok()) {
return s;
}
*out_persistent = PersistentTensor(persistent);
Tensor* allocated = out_persistent->AccessTensor(this);
if (out_tensor) {
*out_tensor = allocated;
}
return s;
}
// OpKernelContext -----------------------------------------------------------
const int OpKernelContext::Params::kNeverForward;
const int OpKernelContext::Params::kNoReservation;
OpKernelContext::OpKernelContext(Params* params)
: OpKernelContext(
params, static_cast<int>(params->op_kernel->output_types().size())) {}
OpKernelContext::OpKernelContext(Params* params, int num_outputs)
: params_(params),
outputs_(num_outputs),
temp_memory_allocated_(0),
persistent_memory_allocated_(0) {
params_->ensure_eigen_gpu_device();
if (params_->eigen_gpu_device != nullptr) {
Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
Status s = params_->device->ReinitializeGpuDevice(
this, params_->eigen_gpu_device, params_->op_device_context,
eigen_gpu_allocator);
if (!s.ok()) {
SetStatus(s);
}
}
if (params_->record_tensor_accesses) {
referenced_tensors_.Init();
}
}
OpKernelContext::~OpKernelContext() {
for (TensorValue& value : outputs_) {
if (!value.is_ref()) {
delete value.tensor;
}
}
if (params_->record_tensor_accesses) referenced_tensors_.Destroy();
if (params_->track_allocations && !wrapped_allocators_.empty()) {
LOG(WARNING) << "OpKernelContext is tracking allocations but they are not "
<< "being consumed by the StepStatsCollector.";
for (auto& wrapped_alloator : wrapped_allocators_) {
wrapped_alloator.second->GetRecordsAndUnRef();
}
}
}
Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) {
Allocator* allocator = nullptr;
if (TF_PREDICT_FALSE(attr.scope_id > 0)) {
allocator = params_->device->GetScopedAllocator(attr, step_id());
CHECK(allocator);
} else {
allocator = params_->device->GetAllocator(attr);
}
if (TF_PREDICT_FALSE(track_allocations())) {
mutex_lock lock(mu_);
for (const auto& wrapped : wrapped_allocators_) {
if (wrapped.first == allocator) {
return wrapped.second;
}
}
TrackingAllocator* wrapped_allocator =
new TrackingAllocator(allocator, params_->track_allocations);
wrapped_allocators_.push_back(std::make_pair(allocator, wrapped_allocator));
return wrapped_allocator;
} else {
return allocator;
}
}
void OpKernelContext::SetStatus(const Status& status) {
status_.Update(status);
}
void OpKernelContext::really_record_tensor_reference(const Tensor& tensor) {
mutex_lock l(mu_);
// Keep a reference to the underlying memory around.
referenced_tensors_->Add(tensor);
}
Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
name,
"' when single-valued input was "
"expected");
}
if (input_is_ref(start)) {
return errors::InvalidArgument("OpKernel used ref input name '", name,
"' when non-ref input was expected");
}
*tensor = (*params_->inputs)[start].tensor;
record_tensor_reference(**tensor);
return Status::OK();
}
Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
name,
"' when single-valued input was "
"expected");
}
const TensorValue& value((*params_->inputs)[start]);
*dtype = value.dtype();
return Status::OK();
}
Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
name,
"' when single-valued input was expected");
}
*out_mutex = input_ref_mutex(start);
return Status::OK();
}
const Tensor& OpKernelContext::input(int index) {
CHECK_GE(index, 0);
CHECK_LT(index, num_inputs()) << " name: " << op_kernel().name();
CHECK(!input_is_ref(index));
const Tensor& tensor = *((*params_->inputs)[index].tensor);
record_tensor_reference(tensor);
return tensor;
}
Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
CHECK_GE(index, 0);
CHECK_LT(index, num_inputs());
CHECK(input_is_ref(index));
// return a copy of the Ref acquired while holding the mutex
if (lock_held) {
Tensor& tensor = *((*params_->inputs)[index].tensor);
record_tensor_reference(tensor);
return tensor;
} else {
tf_shared_lock l(*input_ref_mutex(index));
Tensor& tensor = *((*params_->inputs)[index].tensor);
record_tensor_reference(tensor);
return tensor;
}
}
void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
bool lock_held) {
CHECK_GE(index, 0);
CHECK_LT(index, num_inputs());
CHECK(input_is_ref(index));
// should only modify the tensor while holding the mutex
if (lock_held) {
*(*params_->inputs)[index].tensor = tensor;
} else {
mutex_lock l(*input_ref_mutex(index));
*(*params_->inputs)[index].tensor = tensor;
}
record_tensor_reference(tensor);
}
void OpKernelContext::forward_ref_input_to_ref_output(int input_index,
int output_index) {
CHECK_GE(input_index, 0);
CHECK_LT(input_index, num_inputs());
CHECK(input_is_ref(input_index));
set_output_ref(output_index, (*params_->inputs)[input_index].mutex_if_ref,
(*params_->inputs)[input_index].tensor);
}
bool OpKernelContext::forward_input_to_output_with_shape(
int input_index, int output_index, const TensorShape& output_shape,
Tensor** output) {
const auto output_attr = params_->output_attr_array == nullptr
? AllocatorAttributes()
: output_alloc_attr(output_index);
std::unique_ptr<Tensor> new_tensor = forward_input(
input_index, output_index, expected_output_dtype(output_index),
output_shape, output_memory_type(output_index), output_attr);
if (new_tensor != nullptr) {
// Transfer ownership to the output slot in OpKernelContext.
outputs_[output_index] = TensorValue(new_tensor.release());
*output = outputs_[output_index].tensor;
return true;
} else {
return false;
}
}
Status OpKernelContext::forward_input_to_output_with_shape(
StringPiece input_name, StringPiece output_name,
const TensorShape& output_shape, Tensor** output) {
int input_index, output_index, stop;
TF_RETURN_IF_ERROR(
params_->op_kernel->InputRange(input_name, &input_index, &stop));
if (stop != input_index + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
input_name,
"' when single-valued input was "
"expected");
}
TF_RETURN_IF_ERROR(
params_->op_kernel->OutputRange(output_name, &output_index, &stop));
if (stop != output_index + 1) {
return errors::InvalidArgument("OpKernel used list-valued output name '",
output_name,
"' when single-valued output was "
"expected");
}
if (!forward_input_to_output_with_shape(input_index, output_index,
output_shape, output)) {
return errors::FailedPrecondition("OpKernel could not forward input '",
input_name, "' to output '", output_name);
}
return Status::OK();
}
std::unique_ptr<Tensor> OpKernelContext::forward_input(
int input_index, int output_index, DataType output_dtype,
const TensorShape& output_shape, MemoryType output_memory_type,
const AllocatorAttributes& output_attr) {
CHECK_GE(input_index, 0);
CHECK_LT(input_index, num_inputs());
const TensorValue& input = (*params_->inputs)[input_index];
// Check whether at graph construction time this output was marked
// either for no forwarding or with a reservation for this input.
// If it's reserved for this input we'll skip the refcount and
// AllocatorAttribute checks.
// TODO(tucker): Maybe we should skip all of the checks?
bool never_forward =
(params_->forward_from_array != nullptr && output_index >= 0 &&
params_->forward_from_array[output_index] == Params::kNeverForward);
if (never_forward) return nullptr;
bool forward_expected =
(params_->forward_from_array != nullptr && output_index >= 0 &&
params_->forward_from_array[output_index] == input_index);
if (!forward_expected && params_->forward_from_array != nullptr) {
// Check for possibly conflicting forward.
for (int i = 0; i < num_outputs(); ++i) {
if (params_->forward_from_array[i] == input_index) {
// This input is reserved for output i.
return nullptr;
}
}
}
// Check that input tensor exists and is not a ref.
if (input.tensor == nullptr || input.is_ref()) {
CHECK(!forward_expected);
return nullptr;
}
// Check that input type matches.
if (input_dtype(input_index) != output_dtype) {
CHECK(!forward_expected);
return nullptr;
}
// Check that the input and output sizes are compatible.
if (input.tensor->shape().num_elements() != output_shape.num_elements()) {
CHECK(!forward_expected);
return nullptr;
}
// Check that input and output memory types match, i.e.
// that they either both live in host or both live in device memory.
if (input_memory_type(input_index) != output_memory_type) {
CHECK(!forward_expected);
return nullptr;
}
if (!forward_expected) {
if (!input->RefCountIsOne()) {
return nullptr;
}
// Check that output allocator attributes are not more restrictive than
// input allocator attributes.
const auto input_attr = params_->input_alloc_attrs == nullptr
? AllocatorAttributes()
: input_alloc_attr(input_index);
if (!output_attr.IsEqualOrLessRestrictiveThan(input_attr)) {
return nullptr;
}
}
auto output_tensor = MakeUnique<Tensor>();
CHECK(output_tensor->CopyFrom(*input.tensor, output_shape));
return output_tensor;
}
Status OpKernelContext::forward_input_or_allocate_temp(
gtl::ArraySlice<int> candidate_input_indices, DataType type,
const TensorShape& shape, const AllocatorAttributes& allocator_attr,
Tensor* out_temp) {
for (int input_index : candidate_input_indices) {
std::unique_ptr<Tensor> new_tensor =
forward_input(input_index, Params::kNoReservation /*output_index*/,
type, shape, DEVICE_MEMORY, allocator_attr);
if (new_tensor != nullptr) {
*out_temp = std::move(*new_tensor);
return Status::OK();
}
}
return allocate_temp(type, shape, out_temp, allocator_attr);
}
void OpKernelContext::delete_ref_input(int index, bool lock_held) {
CHECK_GE(index, 0);
CHECK_LT(index, num_inputs());
CHECK(input_is_ref(index));
// should only modify the tensor while holding the mutex
if (lock_held) {
delete (*params_->inputs)[index].tensor;
} else {
mutex_lock l(*input_ref_mutex(index));
delete (*params_->inputs)[index].tensor;
}
}
Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor,
bool lock_held) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
name,
"' when single-valued input was expected");
}
if (!input_is_ref(start)) {
return errors::InvalidArgument("OpKernel used non-ref input name '", name,
"' when ref input was expected");
}
// return a copy of the Ref acquired while holding the mutex
if (lock_held) {
*tensor = *(*params_->inputs)[start].tensor;
} else {
tf_shared_lock l(*input_ref_mutex(start));
*tensor = *(*params_->inputs)[start].tensor;
}
record_tensor_reference(*tensor);
return Status::OK();
}
Status OpKernelContext::replace_ref_input(StringPiece name,
const Tensor& tensor,
bool lock_held) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
name,
"' when single-valued input was expected");
}
if (!input_is_ref(start)) {
return errors::InvalidArgument("OpKernel used immutable input name '", name,
"' when ref input was expected");
}
replace_ref_input(start, tensor, lock_held);
return Status::OK();
}
Status OpKernelContext::input_list(StringPiece name, OpInputList* list) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
*list = OpInputList(this, start, stop);
return Status::OK();
}
Status OpKernelContext::mutable_input_list(StringPiece name,
OpMutableInputList* list) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
*list = OpMutableInputList(this, start, stop);
return Status::OK();
}
Status OpKernelContext::output_list(StringPiece name, OpOutputList* list) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
*list = OpOutputList(this, start, stop);
return Status::OK();
}
void OpKernelContext::maybe_initialize_scope_id_set() {
if (allocated_scope_ids_ == nullptr) {
allocated_scope_ids_ = absl::make_unique<std::unordered_set<int32>>();
}
}
Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
Tensor** tensor) {
if (index < 0) {
return errors::Internal("allocate_output with bad index=", index,
" kernel=", params_->op_kernel->name());
}
if (index >= num_outputs()) {
return errors::Internal("allocate_output with bad index=", index,
" num_outputs=", num_outputs(),
" kernel=", params_->op_kernel->name());
}
bool forward_expected =
(params_->forward_from_array != nullptr && index >= 0 &&
params_->forward_from_array[index] >= 0);
if (forward_expected) {
return errors::Internal(
"Explicit allocate_output call where input forwarding required. Try "
"turning off the ScopedAllocator optimizer.");
}
AllocatorAttributes attr = output_alloc_attr(index);
return allocate_output(index, shape, tensor, attr);
}
Status OpKernelContext::allocate_output(StringPiece name,
const TensorShape& shape,
Tensor** tensor) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued output name '",
name,
"' when single-valued output was "
"expected");
}
return allocate_output(start, shape, tensor);
}
Status OpKernelContext::allocate_output(StringPiece name,
const TensorShape& shape,
Tensor** tensor,
AllocatorAttributes attr) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued output name '",
name,
"' when single-valued output was "
"expected");
}
return allocate_output(start, shape, tensor, attr);
}
Status OpKernelContext::allocate_tensor(
DataType type, const TensorShape& shape, Tensor* out_tensor,
AllocatorAttributes attr, const AllocationAttributes& allocation_attr) {
Allocator* a = get_allocator(attr);
Tensor new_tensor(a, type, shape,
AllocationAttributes(allocation_attr.no_retry_on_failure,
/* allocation_will_be_logged= */ true,
allocation_attr.freed_by_func));
if (!new_tensor.IsInitialized()) {
return errors::ResourceExhausted(
"OOM when allocating tensor with shape", shape.DebugString(),
" and type ", DataTypeString(type), " on ", params_->device->name(),
" by allocator ", a->Name());
}
if (params_->log_memory) {
LogMemory::RecordTensorAllocation(params_->op_kernel->name(),
params_->step_id, new_tensor);
}
record_tensor_reference(new_tensor);
*out_tensor = std::move(new_tensor);
return Status::OK();
}
Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
Tensor** output,
AllocatorAttributes attr) {
if (index < 0) {
return errors::Internal("allocate_output with bad index=", index,
" kernel=", params_->op_kernel->name());
}
if (index >= num_outputs()) {
return errors::Internal("allocate_output with bad index=", index,
" num_outputs=", outputs_.size(),
" kernel=", params_->op_kernel->name());
}
const DataType type = params_->op_kernel->output_type(index);
if (IsRefType(type)) {
return errors::Internal("allocate_output with ref type. index=", index,
" type=", type,
" kernel=", params_->op_kernel->name());
}
if (mutable_output(index) != nullptr) {
return errors::Internal("allocate_output on same index multiple times.",
" index = ", index,
" mutable_output(index) = ", mutable_output(index),
" kernel=", params_->op_kernel->name());
}
if (attr.scope_id > 0) {
maybe_initialize_scope_id_set();
if (!allocated_scope_ids_->insert(attr.scope_id).second) {
return errors::Internal(
"OpKernel ", params_->op_kernel->name(),
" called allocate_output at index ", index, " with scope_id ",
attr.scope_id,
" more than once. Try turning off the ScopedAllocator optimizer.");
}
}
auto output_tensor = MakeUnique<Tensor>();
Status s = allocate_tensor(type, shape, output_tensor.get(), attr);
if (s.ok()) {
outputs_[index] = TensorValue(output_tensor.release());
*output = outputs_[index].tensor;
}
return s;
}
Status OpKernelContext::allocate_temp(
DataType type, const TensorShape& shape, Tensor* out_temp,
AllocatorAttributes allocator_attr,
const AllocationAttributes& allocation_attr) {
if (allocator_attr.scope_id > 0) {
// We do not allow ScopedAllocator calls from allocate_temp. Unlike
// allocate_persistent where we return an error if a kernel provides a
// meaningful scope_id, here we clear the scope_id and return a temporary
// buffer. This is because it is legal for a kernel to call allocate_temp
// and then set_output with the temp tensor.
//
// We achieve memory correctness by forcing an allocation in set_output and
// copying over the tensor from the temp buffer. Kernels which would like
// to avoid this performance penalty should switch to calling
// allocate_output.
VLOG(2) << "Warning: OpKernel " << params_->op_kernel->name()
<< " called allocate_temp with scope_id " << allocator_attr.scope_id
<< ". Switch to allocate_output to avoid performance penalty.";
allocator_attr.scope_id = -1;
}
Status s =
allocate_tensor(type, shape, out_temp, allocator_attr, allocation_attr);
if (track_allocations() && s.ok() && out_temp->TotalBytes() > 0) {
Allocator* a = get_allocator(allocator_attr);
if (a->TracksAllocationSizes()) {
int64 alloc_size = a->AllocatedSize(out_temp->tensor_data().data());
record_temp_memory_allocation(alloc_size, *out_temp);
}
} else if (record_memory_consumption_) {
mutex_lock l(stats_mu_);
temp_memory_allocated_ += out_temp->TotalBytes();
}
return s;
}
Status OpKernelContext::allocate_persistent(DataType type,
const TensorShape& shape,
PersistentTensor* out_persistent,
Tensor** out_tensor,
AllocatorAttributes attr) {
if (attr.scope_id > 0) {
// ScopedAllocator cannot be used for persistent tensors, because these
// tensors may persist across kernel invocations/steps, whereas the backing
// tensor for the scoped allocator will be reallocated every step.
return errors::Internal(
"Unexpected call to allocate_persistent with scope_id ", attr.scope_id);
}
Tensor persistent;
Status s = allocate_tensor(type, shape, &persistent, attr);
if (s.ok()) {
*out_persistent = PersistentTensor(persistent);
Tensor* t = out_persistent->AccessTensor(this);
if (out_tensor) {
*out_tensor = t;
}
if (track_allocations()) {
Allocator* a = get_allocator(attr);
if (a->TracksAllocationSizes()) {
// Zero-byte Tensors don't use allocators: check and skip tracking.
AllocationDescription alloc_desc;
TensorReference tensor_ref(*t);
tensor_ref.FillDescription(&alloc_desc);
tensor_ref.Unref();
if (alloc_desc.allocated_bytes()) { // Non-zero sized tensor.
int64 alloc_size = a->AllocatedSize(t->tensor_data().data());
int64 alloc_id = a->AllocationId(t->tensor_data().data());
record_persistent_memory_allocation(alloc_size, alloc_id);
}
}
} else if (record_memory_consumption_) {
record_persistent_memory_allocation(t->TotalBytes());
}
}
return s;
}
Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued output name '",
name,
"' when single-valued output was "
"expected");
}
set_output(start, tensor);
return Status::OK();
}
void OpKernelContext::set_output(int index, const Tensor& tensor) {
CHECK_GE(index, 0);
CHECK_LT(index, outputs_.size());
const DataType type = params_->op_kernel->output_type(index);
CHECK(!IsRefType(type));
CHECK_EQ(mutable_output(index), nullptr);
bool allocate_and_copy = false;
const bool never_forward =
(params_->forward_from_array != nullptr &&
params_->forward_from_array[index] == Params::kNeverForward);
if (never_forward) {
maybe_initialize_scope_id_set();
if (allocated_scope_ids_->find(output_alloc_attr(index).scope_id) ==
allocated_scope_ids_->end()) {
allocate_and_copy = true;
} else {
// The output at `index` must have been previously allocated via a call to
// `allocate_output(index, ...)`. That call would ensure that we return
// the correct slice of the ScopedAllocated buffer, so we do not
// re-allocate and copy here.
LOG(WARNING)
<< "OpKernel " << params_->op_kernel->name()
<< " called both allocate_output and set_output with scope_id "
<< output_alloc_attr(index).scope_id;
}
}
if (allocate_and_copy) {
// This output was marked to not be forwarded either during graph
// construction or grappler passes. Force an allocation and copy input to
// output.
VLOG(1) << "OpKernelContext set_output index " << index << " tensor "
<< tensor.DebugString() << " never_forward " << never_forward
<< " params_->forward_from_array[index] "
<< params_->forward_from_array[index] << " alloc_attr.scope_id "
<< output_alloc_attr(index).scope_id;
auto new_tensor = MakeUnique<Tensor>();
Status s = allocate_tensor(type, tensor.shape(), new_tensor.get(),
output_alloc_attr(index));
TF_CHECK_OK(s);
device()->CopyTensorInSameDevice(&tensor, new_tensor.get(),
op_device_context(), [](const Status&) {});
outputs_[index] = TensorValue(new_tensor.release());
} else {
// Input can be forwarded to output; incref on `tensor` and set output at
// `index` to this tensor.
record_tensor_reference(tensor);
outputs_[index] = TensorValue(new Tensor(tensor));
if (track_allocations() && tensor.TotalBytes() > 0) {
mutex_lock l(stats_mu_);
if (!temp_tensor_buffer_and_size_) {
return;
}
const auto it = std::find_if(
temp_tensor_buffer_and_size_->begin(),
temp_tensor_buffer_and_size_->end(),
[&tensor](const std::pair<const void*, int64>& e) {
return e.first ==
static_cast<const void*>(tensor.tensor_data().data());
});
if (it != temp_tensor_buffer_and_size_->end()) {
temp_memory_allocated_ -= it->second;
temp_tensor_buffer_and_size_->erase(it);
}
}
}
}
void OpKernelContext::set_output_ref(int index, mutex* mu,
Tensor* tensor_for_ref) {
CHECK_GE(index, 0);
CHECK_LT(index, outputs_.size());
CHECK(IsRefType(params_->op_kernel->output_type(index)));
record_tensor_reference(*tensor_for_ref);
outputs_[index] = TensorValue(mu, tensor_for_ref);
}
Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu,
Tensor* tensor_for_ref) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued output name '",
name,
"' when single-valued output was "
"expected");
}
set_output_ref(start, mu, tensor_for_ref);
return Status::OK();
}
Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued output name '",
name,
"' when single-valued output was "
"expected");
}
*tensor = mutable_output(start);
return Status::OK();
}
bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
const auto& inputs = *params_->inputs;
for (size_t i = 1; i < inputs.size(); ++i) {
if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) {
SetStatus(errors::InvalidArgument(
"Inputs to operation ", op->name(), " of type ", op->type_string(),
" must have the same size and shape. Input 0: ",
inputs[0]->shape().DebugString(), " != input ", i, ": ",
inputs[i]->shape().DebugString()));
return false;
}
}
return true;
}
Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs,
const DataTypeSlice expected_outputs) {
DataTypeVector inputs;
for (const TensorValue& t : *params_->inputs) {
inputs.push_back(t.dtype());
}
DataTypeVector outputs = params_->op_kernel->output_types();
return MatchSignatureHelper(expected_inputs, expected_outputs, inputs,
outputs);
}
void OpKernelContext::record_temp_memory_allocation(int64 size,
const Tensor& t) {
mutex_lock l(stats_mu_);
temp_memory_allocated_ += size;
if (!temp_tensor_buffer_and_size_) {
temp_tensor_buffer_and_size_.reset(
new gtl::InlinedVector<std::pair<const void*, int64>, 2>());
}
temp_tensor_buffer_and_size_->emplace_back(
static_cast<const void*>(t.tensor_data().data()), size);
}
int64 OpKernelContext::temp_memory_allocated() const {
mutex_lock l(stats_mu_);
return temp_memory_allocated_;
}
void OpKernelContext::record_persistent_memory_allocation(int64 size,
int64 alloc_id) {
mutex_lock l(stats_mu_);
persistent_memory_allocated_ += size;
if (alloc_id >= 0) {
if (!persistent_alloc_ids_) {
persistent_alloc_ids_.reset(new gtl::InlinedVector<int64, 2>());
}
persistent_alloc_ids_->push_back(alloc_id);
}
}
int64 OpKernelContext::persistent_memory_allocated() const {
mutex_lock l(stats_mu_);
return persistent_memory_allocated_;
}
std::vector<int64> OpKernelContext::persistent_alloc_ids() const {
mutex_lock l(stats_mu_);
if (persistent_alloc_ids_) {
return std::vector<int64>(persistent_alloc_ids_->begin(),
persistent_alloc_ids_->end());
} else {
return std::vector<int64>();
}
}
void OpKernelContext::clear_recorded_memory() {
mutex_lock l(stats_mu_);
temp_memory_allocated_ = 0;
persistent_memory_allocated_ = 0;
if (temp_tensor_buffer_and_size_) {
temp_tensor_buffer_and_size_->clear();
}
if (persistent_alloc_ids_) {
persistent_alloc_ids_->clear();
}
}
// OpKernel registration ------------------------------------------------------
struct KernelRegistration {
KernelRegistration(const KernelDef& d, StringPiece c,
std::unique_ptr<kernel_factory::OpKernelFactory> f)
: def(d), kernel_class_name(c), factory(std::move(f)) {}
const KernelDef def;
const string kernel_class_name;
std::unique_ptr<kernel_factory::OpKernelFactory> factory;
};
// This maps from 'op_type' + DeviceType to the set of KernelDefs and
// factory functions for instantiating the OpKernel that matches the
// KernelDef.
struct KernelRegistry {
mutex mu;
std::unordered_multimap<string, KernelRegistration> registry GUARDED_BY(mu);
};
#if defined(_WIN32)
static const char kKernelLibPattern[] = "libtfkernel*.dll";
#elif defined(__APPLE__)
static const char kKernelLibPattern[] = "libtfkernel*.dylib";
#else
static const char kKernelLibPattern[] = "libtfkernel*.so";
#endif
#define FEATURE(x) \
{ x, #x }
// Returns Status::OK if the dynamic library at the given path is safe to
// load with some level of confidence.
static Status IsProbablySafeToLoad(const string& path) {
// A map of platform string to required CPU feature.
using port::CPUFeature;
static const auto* feature_map =
new std::map<string, std::pair<CPUFeature, string>>{
{"__AVX512VL__=1", FEATURE(CPUFeature::AVX512VL)},
};
std::vector<std::string> platform_strings;
int result = GetPlatformStrings(path, &platform_strings);
if (result) {
return Status(error::Code::UNKNOWN, strerror(result));
}
if (platform_strings.empty()) {
return Status(error::Code::FAILED_PRECONDITION,
"Didn't find any platform strings");
}
std::vector<std::string> missing_features;
for (const auto& platform_string : platform_strings) {
const auto& entry = feature_map->find(platform_string);
if (entry != feature_map->end() &&
!port::TestCPUFeature(entry->second.first)) {
missing_features.emplace_back(entry->second.second);
}
}
if (!missing_features.empty()) {
string errmsg = "Missing CPU features: ";
errmsg.append(absl::StrJoin(missing_features, ", "));
return Status(errors::Code::FAILED_PRECONDITION, errmsg);
}
return Status::OK();
}
void LoadDynamicKernelsInternal() {
Env* env = Env::Default();
// Override to allow loading unsafe packages for development.
// DO NOT USE UNLESS YOU KNOW WHAT ABI ISSUES YOU CAN ENCOUNTER.
bool override_abi_check =
strcmp(getenv("TF_REALLY_LOAD_UNSAFE_PACKAGES"), "1") == 0;
string bazel_kernel_dir =
io::JoinPath(env->GetRunfilesDir(), "tensorflow", "core", "kernels");
std::vector<string> files;
Status s_kernel_dir = env->GetChildren(bazel_kernel_dir, &files);
if (s_kernel_dir.ok()) {
string dll_spec = io::JoinPath(bazel_kernel_dir, kKernelLibPattern);
for (const auto& file : files) {
string fullpath = io::JoinPath(bazel_kernel_dir, file);
if (env->MatchPath(fullpath, dll_spec)) {
Status s = IsProbablySafeToLoad(fullpath);
if (!s.ok() && override_abi_check) {
LOG(WARNING) << "Loading UNSAFE library " << fullpath
<< " because ABI check override is set: "
<< s.error_message();
}
if (s.ok() || override_abi_check) {
// TODO(gunan): Store the handles to the opened files.
void* unused_filehandle;
TF_CHECK_OK(env->LoadLibrary(fullpath.c_str(), &unused_filehandle));
} else {
LOG(WARNING) << "Not loading plugin library " << fullpath << ": "
<< s.error_message();
}
}
}
}
}
// Mechanism for loading existing kernel libraries.
void LoadDynamicKernels() {
// TODO(gunan): As more features are available, add intelligent kernel
// selection, and dropping unsuitable kernel logic here.
static std::once_flag dll_loader_flag;
std::call_once(dll_loader_flag, LoadDynamicKernelsInternal);
}
void* GlobalKernelRegistry() {
static KernelRegistry* global_kernel_registry = new KernelRegistry;
return global_kernel_registry;
}
static KernelRegistry* GlobalKernelRegistryTyped() {
#ifdef AUTOLOAD_DYNAMIC_KERNELS
LoadDynamicKernels();
#endif // AUTOLOAD_DYNAMIC_KERNELS
return reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
}
static string Key(StringPiece op_type, const DeviceType& device_type,
StringPiece label) {
return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":",
label);
}
namespace kernel_factory {
void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
StringPiece kernel_class_name,
std::unique_ptr<OpKernelFactory> factory) {
// See comments in register_kernel::Name in header for info on _no_register.
if (kernel_def->op() != "_no_register") {
const string key =
Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
kernel_def->label());
// To avoid calling LoadDynamicKernels DO NOT CALL GlobalKernelRegistryTyped
// here.
// InitInternal gets called by static initializers, so it ends up executing
// before main. This causes LoadKernelLibraries function to get called
// before some file libraries can initialize, which in turn crashes the
// program flakily. Until we get rid of static initializers in kernel
// registration mechanism, we have this workaround here.
auto global_registry =
reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
mutex_lock l(global_registry->mu);
global_registry->registry.emplace(
key,
KernelRegistration(*kernel_def, kernel_class_name, std::move(factory)));
}
delete kernel_def;
}
OpKernel* OpKernelRegistrar::PtrOpKernelFactory::Create(
OpKernelConstruction* context) {
return (*create_func_)(context);
}
} // namespace kernel_factory
namespace {
static const StringPiece kKernelAttr("_kernel");
// TODO(irving): Replace with const Node& version below.
Status FindKernelRegistration(
const DeviceType& device_type, StringPiece node_name,
bool has_experimental_debug_info,
const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
StringPiece node_op, AttrSlice node_attrs, const KernelRegistration** reg,
bool* was_attr_mismatch) {
*reg = nullptr;
*was_attr_mismatch = false;
// Label defaults to empty if not found in NodeDef.
const string& label = GetNodeAttrString(node_attrs, kKernelAttr);
const string key = Key(node_op, device_type, label);
auto typed_registry = GlobalKernelRegistryTyped();
tf_shared_lock lock(typed_registry->mu);
auto regs = typed_registry->registry.equal_range(key);
for (auto iter = regs.first; iter != regs.second; ++iter) {
// If there is a kernel registered for the op and device_type,
// check that the attrs match.
bool match;
TF_RETURN_IF_ERROR(KernelAttrsMatch(iter->second.def, node_attrs, &match));
if (match) {
if (*reg != nullptr) {
return errors::InvalidArgument(
"Multiple OpKernel registrations match NodeDef '",
FormatNodeDefForError(node_name, has_experimental_debug_info,
experimental_debug_info),
"': '", ProtoShortDebugString((*reg)->def), "' and '",
ProtoShortDebugString(iter->second.def), "'");
}
*reg = &iter->second;
} else {
*was_attr_mismatch = true;
}
}
// Check if no device specific registrations found. If not, try finding a
// default kernel.
if (*reg == nullptr &&
!IsSymbolicExecutionDevice(device_type.type_string())) {
const string default_key = Key(node_op, DEVICE_DEFAULT, label);
auto regs = typed_registry->registry.equal_range(default_key);
for (auto iter = regs.first; iter != regs.second; ++iter) {
// If there is a kernel registered for the op and device_type,
// check that the attrs match.
bool match;
TF_RETURN_IF_ERROR(
KernelAttrsMatch(iter->second.def, node_attrs, &match));
if (match) {
if (*reg != nullptr) {
return errors::InvalidArgument(
"Multiple Default OpKernel registrations match NodeDef '",
FormatNodeDefForError(node_name, has_experimental_debug_info,
experimental_debug_info),
"': '", ProtoShortDebugString((*reg)->def), "' and '",
ProtoShortDebugString(iter->second.def), "'");
}
*reg = &iter->second;
} else {
*was_attr_mismatch = true;
}
}
if (*reg != nullptr) {
VLOG(1) << "No device-specific kernels found for NodeDef '"
<< FormatNodeDefForError(node_name, has_experimental_debug_info,
experimental_debug_info)
<< "'"
<< "Will fall back to a default kernel." << std::endl;
}
}
return Status::OK();
}
Status FindKernelRegistration(const DeviceType& device_type,
const NodeDef& node_def,
const KernelRegistration** reg,
bool* was_attr_mismatch) {
return FindKernelRegistration(
device_type, node_def.name(), node_def.has_experimental_debug_info(),
node_def.experimental_debug_info(), node_def.op(),
AttrSlice(&node_def.attr()), reg, was_attr_mismatch);
}
} // namespace
bool KernelDefAvailable(const DeviceType& device_type,
const NodeDef& node_def) {
const KernelRegistration* reg = nullptr;
bool was_attr_mismatch;
Status result =
FindKernelRegistration(device_type, node_def, &reg, &was_attr_mismatch);
return result.ok() && reg != nullptr;
}
// TODO(irving): Change const NodeDef& to const Node&
Status FindKernelDef(
const DeviceType& device_type, StringPiece node_name,
bool has_experimental_debug_info,
const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
const KernelDef** def, string* kernel_class_name) {
const KernelRegistration* reg = nullptr;
bool was_attr_mismatch;
TF_RETURN_IF_ERROR(FindKernelRegistration(
device_type, node_name, has_experimental_debug_info,
experimental_debug_info, node_op, node_attrs, &reg, &was_attr_mismatch));
if (reg == nullptr) {
Status s = errors::NotFound(
"No registered '", node_op, "' OpKernel for ",
DeviceTypeString(device_type), " devices compatible with node ",
FormatNodeDefForError(node_name, has_experimental_debug_info,
experimental_debug_info));
if (was_attr_mismatch) {
errors::AppendToMessage(
&s, " (OpKernel was found, but attributes didn't match) ",
"Requested Attributes: ",
SummarizeAttrsHelper(node_attrs, node_device));
}
errors::AppendToMessage(&s,
". Registered:", KernelsRegisteredForOp(node_op));
return s;
}
if (def != nullptr) *def = &reg->def;
if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name;
return Status::OK();
}
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
const KernelDef** def, string* kernel_class_name) {
return FindKernelDef(
device_type, node_def.name(), node_def.has_experimental_debug_info(),
node_def.experimental_debug_info(), node_def.op(), node_def.device(),
AttrSlice(&node_def.attr()), def, kernel_class_name);
}
Status SupportedDeviceTypesForNode(
const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
PrioritizedDeviceTypeVector* prioritized_device_types) {
// TODO(zhifengc): Changes the callers (SimplePlacer and
// DynamicPlacer) to consider the possibility that 'def' is call to
// a user-defined function and only calls this
// SupportedDeviceTypesForNode for primitive ops.
const OpRegistrationData* op_reg_data;
const Status s = OpRegistry::Global()->LookUp(def.op(), &op_reg_data);
if (s.ok()) {
for (const DeviceType& device_type : prioritized_types) {
const KernelRegistration* reg = nullptr;
bool was_attr_mismatch;
TF_RETURN_IF_ERROR(
FindKernelRegistration(device_type, def, &reg, &was_attr_mismatch));
if (reg != nullptr) {
int32 priority = reg->def.priority();
prioritized_device_types->emplace_back(device_type, priority);
}
}
std::sort(prioritized_device_types->begin(),
prioritized_device_types->end(),
[](const std::pair<DeviceType, int32>& a,
const std::pair<DeviceType, int32>& b) {
return a.second > b.second;
});
} else {
// Assumes that all device types support this node.
for (const DeviceType& device_type : prioritized_types) {
prioritized_device_types->push_back(std::make_pair(device_type, 0));
}
}
return Status::OK();
}
void LogAllRegisteredKernels() {
KernelList kernel_list = GetAllRegisteredKernels();
for (const auto& kernel_def : kernel_list.kernel()) {
LOG(INFO) << "OpKernel ('" << ProtoShortDebugString(kernel_def) << "')";
}
}
KernelList GetAllRegisteredKernels() {
return GetFilteredRegisteredKernels([](const KernelDef& k) { return true; });
}
KernelList GetFilteredRegisteredKernels(
const std::function<bool(const KernelDef&)>& predicate) {
KernelRegistry* const typed_registry = GlobalKernelRegistryTyped();
KernelList kernel_list;
tf_shared_lock lock(typed_registry->mu);
kernel_list.mutable_kernel()->Reserve(typed_registry->registry.size());
for (const auto& p : typed_registry->registry) {
const KernelDef& kernel_def = p.second.def;
if (predicate(kernel_def)) {
*kernel_list.add_kernel() = kernel_def;
}
}
return kernel_list;
}
KernelList GetRegisteredKernelsForOp(StringPiece op_name) {
auto op_pred = [op_name](const KernelDef& k) { return k.op() == op_name; };
return GetFilteredRegisteredKernels(op_pred);
}
string KernelsRegisteredForOp(StringPiece op_name) {
KernelList kernel_list = GetRegisteredKernelsForOp(op_name);
if (kernel_list.kernel_size() == 0) return " <no registered kernels>\n";
string ret;
for (const auto& kernel_def : kernel_list.kernel()) {
strings::StrAppend(&ret, " device='", kernel_def.device_type(), "'");
if (!kernel_def.label().empty()) {
strings::StrAppend(&ret, "; label='", kernel_def.label(), "'");
}
for (int i = 0; i < kernel_def.constraint_size(); ++i) {
strings::StrAppend(
&ret, "; ", kernel_def.constraint(i).name(), " in ",
SummarizeAttrValue(kernel_def.constraint(i).allowed_values()));
}
strings::StrAppend(&ret, "\n");
}
return ret;
}
std::unique_ptr<OpKernel> CreateOpKernel(
DeviceType device_type, DeviceBase* device, Allocator* allocator,
const NodeDef& node_def, int graph_def_version, Status* status) {
OpKernel* kernel = nullptr;
*status = CreateOpKernel(std::move(device_type), device, allocator, nullptr,
node_def, graph_def_version, &kernel);
return std::unique_ptr<OpKernel>(kernel);
}
Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
Allocator* allocator, FunctionLibraryRuntime* flib,
const NodeDef& node_def, int graph_def_version,
OpKernel** kernel) {
VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def);
// Look up the Op registered for this op name.
const OpDef* op_def = nullptr;
Status s = OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def);
if (!s.ok()) return s;
// Validate node_def against OpDef.
s = ValidateNodeDef(node_def, *op_def);
if (!s.ok()) return s;
// Look up kernel registration.
const KernelRegistration* registration;
bool was_attr_mismatch;
s = FindKernelRegistration(device_type, node_def, &registration,
&was_attr_mismatch);
if (!s.ok()) {
errors::AppendToMessage(&s, " when instantiating ", node_def.op());
return s;
}
if (registration == nullptr) {
s.Update(errors::NotFound("No registered '", node_def.op(),
"' OpKernel for '", DeviceTypeString(device_type),
"' devices compatible with node ",
FormatNodeDefForError(node_def)));
if (was_attr_mismatch) {
errors::AppendToMessage(
&s, " (OpKernel was found, but attributes didn't match) ",
"Requested Attributes: ", SummarizeAttrs(node_def));
}
errors::AppendToMessage(
&s, ". Registered:", KernelsRegisteredForOp(node_def.op()));
return s;
}
// Get signature from the OpDef & NodeDef
DataTypeVector inputs;
DataTypeVector outputs;
s.Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs));
if (!s.ok()) {
errors::AppendToMessage(&s, " for node: ", FormatNodeDefForError(node_def));
return s;
}
// We are creating a kernel for an op registered in
// OpRegistry::Global(), we consult the kernel registry to decide
// the kernel's input and output memory types.
MemoryTypeVector input_memory_types;
MemoryTypeVector output_memory_types;
TF_RETURN_IF_ERROR(MemoryTypesForNode(OpRegistry::Global(), device_type,
node_def, &input_memory_types,
&output_memory_types));
// Everything needed for OpKernel construction.
OpKernelConstruction context(
device_type, device, allocator, &node_def, op_def, flib, inputs,
input_memory_types, outputs, output_memory_types, graph_def_version, &s);
*kernel = registration->factory->Create(&context);
if (!s.ok()) {
delete *kernel;
*kernel = nullptr;
}
return s;
}
namespace {
bool FindArgInOp(StringPiece arg_name,
const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
for (const auto& arg : args) {
if (arg_name == arg.name()) {
return true;
}
}
return false;
}
} // namespace
Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry) {
auto typed_registry = GlobalKernelRegistryTyped();
tf_shared_lock lock(typed_registry->mu);
for (const auto& key_registration : typed_registry->registry) {
const KernelDef& kernel_def(key_registration.second.def);
const OpRegistrationData* op_reg_data;
const Status status = op_registry.LookUp(kernel_def.op(), &op_reg_data);
if (!status.ok()) {
// TODO(josh11b): Make this a hard error.
LOG(ERROR) << "OpKernel ('" << ProtoShortDebugString(kernel_def)
<< "') for unknown op: " << kernel_def.op();
continue;
}
const OpDef& op_def = op_reg_data->op_def;
for (const auto& host_memory_arg : kernel_def.host_memory_arg()) {
if (!FindArgInOp(host_memory_arg, op_def.input_arg()) &&
!FindArgInOp(host_memory_arg, op_def.output_arg())) {
return errors::InvalidArgument(
"HostMemory arg '", host_memory_arg,
"' not found in OpDef: ", SummarizeOpDef(op_def));
}
}
}
return Status::OK();
}
template <>
const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const {
return eigen_cpu_device();
}
template <>
const Eigen::GpuDevice& OpKernelContext::eigen_device() const {
return eigen_gpu_device();
}
#ifdef TENSORFLOW_USE_SYCL
template <>
const Eigen::SyclDevice& OpKernelContext::eigen_device() const {
return eigen_sycl_device();
}
#endif
void OpKernelConstruction::CtxFailure(const Status& s) {
VLOG(1) << s;
SetStatus(s);
}
void OpKernelConstruction::CtxFailureWithWarning(const Status& s) {
LOG(WARNING) << s;
SetStatus(s);
}
void OpKernelConstruction::CtxFailure(const char* file, int line,
const Status& s) {
VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
<< " : " << s;
SetStatus(s);
}
void OpKernelConstruction::CtxFailureWithWarning(const char* file, int line,
const Status& s) {
LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
<< " : " << s;
SetStatus(s);
}
void OpKernelContext::CtxFailure(const Status& s) {
VLOG(1) << s;
SetStatus(s);
}
void OpKernelContext::CtxFailureWithWarning(const Status& s) {
LOG(WARNING) << s;
SetStatus(s);
}
void OpKernelContext::CtxFailure(const char* file, int line, const Status& s) {
VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
<< " : " << s;
SetStatus(s);
}
void OpKernelContext::CtxFailureWithWarning(const char* file, int line,
const Status& s) {
LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
<< " : " << s;
SetStatus(s);
}
void CheckNotInComputeAsync(OpKernelContext* ctx,
const char* correct_macro_name) {
CHECK_EQ(nullptr, ctx->op_kernel().AsAsync())
<< "Use " << correct_macro_name << " in AsyncOpKernel implementations.";
}
} // namespace tensorflow