blob: bbd3dd3e57b024d16af8d1080d0347e7f8dd14cf [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/kernel_def_util.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/kernel_def.pb_text.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.h"
namespace tensorflow {
namespace {

// Helper for KernelAttrsMatch(): reports whether `dt` appears among the
// enum values stored in `type_list.list().type()`. An empty list never
// contains `dt`.
bool InTypeList(DataType dt, const AttrValue& type_list) {
  bool present = false;
  for (const int candidate : type_list.list().type()) {
    if (candidate == dt) {
      present = true;
      break;
    }
  }
  return present;
}

}  // namespace
// Checks the type constraints declared in `kernel_def` against the attr
// values in `attrs`.
//
// `*match` is cleared up front, so it stays false on every early return,
// including error returns. It is set to true only after every constraint
// has been checked and satisfied.
//
// Each constraint names an attr and lists the permitted types in
// `allowed_values().list().type()`. The attr value looked up in `attrs`
// satisfies the constraint when either:
//   - it holds a single type (`type() != DT_INVALID`) that is in the
//     allowed list, or
//   - it holds a `list(type)` whose every element is in the allowed list.
//
// Returns:
//   - Unimplemented if a constraint's allowed type list is empty.
//   - InvalidArgument if a constrained attr is absent from `attrs`, or if
//     its value is neither a type nor a list(type).
//   - OK otherwise. `*match` then reports whether all constraints held.
Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs,
                        bool* match) {
  *match = false;
  for (const auto& constraint : kernel_def.constraint()) {
    // Only constraints expressed as a non-empty list of types are handled.
    if (constraint.allowed_values().list().type_size() == 0) {
      // Closing quote after the KernelDef debug string was previously
      // missing; keep this message consistent with the InvalidArgument
      // message below.
      return errors::Unimplemented(
          "KernelDef '", ProtoShortDebugString(kernel_def),
          "' has constraint on attr '", constraint.name(),
          "' with unsupported type: ",
          SummarizeAttrValue(constraint.allowed_values()));
    }

    const AttrValue* found = attrs.Find(constraint.name());
    if (found == nullptr) {
      return errors::InvalidArgument(
          "OpKernel '", kernel_def.op(), "' has constraint on attr '",
          constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(),
          "', KernelDef: '", ProtoShortDebugString(kernel_def), "'");
    }

    if (found->type() != DT_INVALID) {
      // Single-type attr: a disallowed type means "no match", not an error.
      if (!InTypeList(found->type(), constraint.allowed_values())) {
        return Status::OK();
      }
      continue;
    }

    // Otherwise the attr must be a list(type).
    if (!AttrValueHasType(*found, "list(type)").ok()) {
      return errors::InvalidArgument(
          "KernelDef '", ProtoShortDebugString(kernel_def),
          "' has constraint on attr '", constraint.name(),
          "' that has value '", SummarizeAttrValue(*found),
          "' that does not have type 'type' or 'list(type)' in NodeDef "
          "'",
          attrs.SummarizeNode(), "'");
    }
    // Every element of the list must be allowed.
    for (int t : found->list().type()) {
      if (!InTypeList(static_cast<DataType>(t),
                      constraint.allowed_values())) {
        return Status::OK();
      }
    }
  }
  *match = true;
  return Status::OK();
}
} // namespace tensorflow