| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/jit/node_matchers.h" |
| |
| #include <utility> |
| #include "absl/algorithm/container.h" |
| #include "absl/strings/str_cat.h" |
| #include "absl/strings/str_join.h" |
| #include "absl/strings/str_split.h" |
| #include "tensorflow/core/framework/tensor.pb.h" |
| |
| namespace tensorflow { |
| namespace testing { |
| namespace matchers { |
| namespace { |
| |
| using impl::NodeMatcherProperties; |
| |
| string IndentAllButFirstLine(absl::string_view text) { |
| std::vector<std::string> lines = absl::StrSplit(text, '\n'); |
| for (int i = 1; i < lines.size(); i++) { |
| lines[i].insert(0, " "); |
| } |
| return absl::StrJoin(lines, "\n"); |
| } |
| |
| template <typename T> |
| bool CompareTensor(const Tensor& actual, const Tensor& expected, |
| ::testing::MatchResultListener* listener) { |
| if (actual.NumElements() != expected.NumElements()) { |
| if (listener->IsInterested()) { |
| *listener << "\nwas looking for tensor with " << expected.NumElements() |
| << " elements, found tensor with " << actual.NumElements() |
| << " elements"; |
| return false; |
| } |
| } |
| |
| for (int64 i = 0, e = actual.NumElements(); i < e; i++) { |
| if (actual.flat<T>()(i) != expected.flat<T>()(i)) { |
| *listener << "\nmismatch in constant tensor at index " << i |
| << " expected = " << expected.flat<T>()(i) |
| << " actual = " << actual.flat<T>()(i); |
| return false; |
| } |
| } |
| |
| return true; |
| } |
| |
| bool MatchAndExplainTensor(const Tensor& tensor, const Tensor& expected_tensor, |
| ::testing::MatchResultListener* listener) { |
| if (tensor.dtype() != expected_tensor.dtype()) { |
| if (listener->IsInterested()) { |
| *listener << "\nexpected tensor of type " |
| << DataType_Name(expected_tensor.dtype()) |
| << " but found one of type " << DataType_Name(tensor.dtype()); |
| return false; |
| } |
| } |
| |
| switch (tensor.dtype()) { |
| case DT_FLOAT: |
| return CompareTensor<float>(tensor, expected_tensor, listener); |
| case DT_DOUBLE: |
| return CompareTensor<double>(tensor, expected_tensor, listener); |
| case DT_INT8: |
| return CompareTensor<int8>(tensor, expected_tensor, listener); |
| case DT_INT16: |
| return CompareTensor<int16>(tensor, expected_tensor, listener); |
| case DT_INT32: |
| return CompareTensor<int32>(tensor, expected_tensor, listener); |
| case DT_INT64: |
| return CompareTensor<int64>(tensor, expected_tensor, listener); |
| case DT_UINT8: |
| return CompareTensor<uint8>(tensor, expected_tensor, listener); |
| case DT_UINT16: |
| return CompareTensor<uint16>(tensor, expected_tensor, listener); |
| case DT_UINT32: |
| return CompareTensor<uint32>(tensor, expected_tensor, listener); |
| case DT_UINT64: |
| return CompareTensor<uint64>(tensor, expected_tensor, listener); |
| default: |
| LOG(FATAL) << "Unsupported dtype " // Crash ok: testonly. |
| << DataType_Name(tensor.dtype()); |
| } |
| } |
| |
| using Input = std::pair<const Node*, int>; |
| |
| struct NodeMatcher : public ::testing::MatcherInterface<const Node*> { |
| bool MatchAndExplain( |
| const Node* node, |
| ::testing::MatchResultListener* listener) const override { |
| if (op && node->type_string() != *op) { |
| if (listener->IsInterested()) { |
| *listener << "\nexpected op " << *op << " but found " |
| << node->type_string(); |
| } |
| return false; |
| } |
| |
| if (assigned_device && node->assigned_device_name() != *assigned_device) { |
| if (listener->IsInterested()) { |
| *listener << "\nexpected assigned_device " << *assigned_device |
| << " but found \"" << node->assigned_device_name() << "\""; |
| } |
| return false; |
| } |
| |
| if (name && node->name() != *name) { |
| if (listener->IsInterested()) { |
| *listener << "\nexpected name " << *name << " but found " |
| << node->name(); |
| } |
| return false; |
| } |
| |
| if (constant_value) { |
| const TensorProto* proto = nullptr; |
| if (!GetNodeAttr(node->def(), "value", &proto).ok()) { |
| if (listener->IsInterested()) { |
| *listener << "\ncould not find \"value\" attribute in node"; |
| } |
| return false; |
| } |
| |
| Tensor tensor(proto->dtype()); |
| if (!tensor.FromProto(*proto)) { |
| if (listener->IsInterested()) { |
| *listener << "\ncould not convert TensorProto in \"value\" attribute " |
| "to Tensor"; |
| } |
| return false; |
| } |
| |
| if (!MatchAndExplainTensor(/*tensor=*/tensor, |
| /*expected_tensor=*/*constant_value, |
| listener)) { |
| return false; |
| } |
| } |
| |
| if (input_matchers) { |
| if (input_matchers->size() != node->num_inputs()) { |
| if (listener->IsInterested()) { |
| *listener << "\nexpected " << input_matchers->size() |
| << " inputs but node has " << node->num_inputs(); |
| } |
| return false; |
| } |
| |
| for (int input_idx = 0, e = input_matchers->size(); input_idx < e; |
| input_idx++) { |
| if (!MatchAndExplainInput(node, input_idx, listener)) { |
| return false; |
| } |
| } |
| } |
| |
| std::vector<const Node*> control_deps; |
| for (const Edge* e : node->in_edges()) { |
| if (e->IsControlEdge()) { |
| control_deps.push_back(e->src()); |
| } |
| } |
| |
| ::testing::StringMatchResultListener inner_listener; |
| if (control_dep_set && |
| !control_dep_set->MatchAndExplain(control_deps, &inner_listener)) { |
| if (listener->IsInterested()) { |
| string explanation = inner_listener.str(); |
| if (!explanation.empty()) { |
| explanation = absl::StrCat(", ", explanation, ","); |
| } |
| *listener << "ctrl_deps" << explanation << " does not match expected: "; |
| control_dep_set->DescribeTo(listener->stream()); |
| } |
| return false; |
| } |
| return true; |
| } |
| |
| void DescribeTo(::std::ostream* os) const override { |
| std::vector<string> predicates; |
| |
| if (name) { |
| predicates.push_back(absl::StrCat("name: ", *name)); |
| } |
| |
| if (op) { |
| predicates.push_back(absl::StrCat("op: ", *op)); |
| } |
| |
| if (assigned_device) { |
| predicates.push_back(absl::StrCat("assigned device: ", *assigned_device)); |
| } |
| |
| bool printed_something = !predicates.empty(); |
| |
| *os << absl::StrJoin(predicates, ", "); |
| |
| if (constant_value) { |
| printed_something = true; |
| *os << "constant value: " << constant_value->DebugString(); |
| } |
| |
| if (input_matchers) { |
| if (!input_matchers->empty()) { |
| printed_something = true; |
| *os << " with " << (input_matchers->size() == 1 ? "only " : "") |
| << "input" << (input_matchers->size() == 1 ? "" : "s") << " "; |
| } |
| |
| if (input_matchers->size() == 1) { |
| ::std::stringstream ss; |
| input_matchers->front().DescribeTo(&ss); |
| printed_something = true; |
| *os << "matching " << ss.str(); |
| } else { |
| int edge_idx = 0; |
| for (const ::testing::Matcher<Input>& matcher : (*input_matchers)) { |
| *os << "\n [" << edge_idx << "] matching ("; |
| ::std::stringstream ss; |
| matcher.DescribeTo(&ss); |
| printed_something = true; |
| *os << IndentAllButFirstLine(ss.str()); |
| *os << ")"; |
| edge_idx++; |
| } |
| } |
| } |
| |
| if (control_dep_set) { |
| printed_something = true; |
| *os << " and control deps "; |
| control_dep_set->DescribeTo(os); |
| } |
| |
| if (!printed_something) { |
| *os << "is any node"; |
| } |
| } |
| |
| bool MatchAndExplainInput(const Node* node, int input_idx, |
| ::testing::MatchResultListener* listener) const { |
| const Edge* edge; |
| if (!node->input_edge(input_idx, &edge).ok()) { |
| if (listener->IsInterested()) { |
| *listener << "\ncould not find incoming edge for input " << input_idx; |
| } |
| return false; |
| } |
| |
| ::testing::StringMatchResultListener inner_listener; |
| Input input = {edge->src(), edge->src_output()}; |
| if ((*input_matchers)[input_idx].MatchAndExplain(input, &inner_listener)) { |
| return true; |
| } |
| |
| if (listener->IsInterested()) { |
| *listener << "\ninput " << input_idx << " does not match expected:\n"; |
| (*input_matchers)[input_idx].DescribeTo(listener->stream()); |
| string explanation = inner_listener.str(); |
| if (!explanation.empty()) { |
| *listener << ", " << explanation; |
| } |
| } |
| return false; |
| } |
| |
| absl::optional<string> op; |
| absl::optional<string> name; |
| absl::optional<string> assigned_device; |
| absl::optional<Tensor> constant_value; |
| absl::optional<std::vector<::testing::Matcher<Input>>> input_matchers; |
| absl::optional<::testing::Matcher<absl::Span<const Node* const>>> |
| control_dep_set; |
| }; |
| |
| // Matches a dst and dst_output on an input edge. Today we only use this with |
| // dst_output=0 but we will eventually need to support multi-output operations. |
| class InputMatcher : public ::testing::MatcherInterface<Input> { |
| public: |
| InputMatcher(::testing::Matcher<const Node*> src_matcher, int src_output) |
| : src_matcher_(std::move(src_matcher)), src_output_(src_output) {} |
| |
| bool MatchAndExplain( |
| Input input, ::testing::MatchResultListener* listener) const override { |
| ::testing::StringMatchResultListener inner_listener; |
| if (!src_matcher_.MatchAndExplain(input.first, &inner_listener)) { |
| if (listener->IsInterested()) { |
| *listener << "\nsource does not match expected "; |
| src_matcher_.DescribeTo(listener->stream()); |
| string explanation = inner_listener.str(); |
| if (!explanation.empty()) { |
| *listener << "\n\t" << explanation; |
| } |
| } |
| return false; |
| } |
| if (input.second != src_output_) { |
| if (listener->IsInterested()) { |
| *listener << "\nexpected output slot to be " << src_output_ |
| << " but found " << input.second; |
| } |
| return false; |
| } |
| |
| return true; |
| } |
| |
| void DescribeTo(::std::ostream* os) const override { |
| if (src_output_) { |
| *os << "output slot: " << src_output_ << ", source: ("; |
| } |
| |
| src_matcher_.DescribeTo(os); |
| |
| if (src_output_) { |
| *os << ")"; |
| } |
| } |
| |
| private: |
| ::testing::Matcher<const Node*> src_matcher_; |
| int src_output_; |
| }; |
| |
| std::vector<::testing::Matcher<Input>> NodeMatchersToInputMatchers( |
| absl::Span<const ::testing::Matcher<const Node*>> node_matchers) { |
| std::vector<::testing::Matcher<Input>> result; |
| absl::c_transform(node_matchers, std::back_inserter(result), |
| [](::testing::Matcher<const Node*> n) { |
| return ::testing::MakeMatcher(new InputMatcher(n, 0)); |
| }); |
| return result; |
| } |
| } // namespace |
| |
| ::testing::Matcher<const Node*> impl::NodeWith( |
| absl::Span<const NodeMatcherProperties> props) { |
| NodeMatcher* matcher = new NodeMatcher(); |
| for (const NodeMatcherProperties& prop : props) { |
| if (prop.name()) { |
| DCHECK(!matcher->name); |
| matcher->name = prop.name(); |
| } |
| |
| if (prop.op()) { |
| DCHECK(!matcher->op); |
| matcher->op = prop.op(); |
| } |
| |
| if (prop.constant_value()) { |
| DCHECK(!matcher->constant_value); |
| matcher->constant_value = prop.constant_value(); |
| } |
| |
| if (prop.assigned_device()) { |
| DCHECK(!matcher->assigned_device); |
| matcher->assigned_device = prop.assigned_device(); |
| } |
| |
| if (prop.input_nodes()) { |
| DCHECK(!matcher->input_matchers); |
| matcher->input_matchers = |
| NodeMatchersToInputMatchers(*prop.input_nodes()); |
| } |
| |
| if (prop.control_deps()) { |
| DCHECK(!matcher->control_dep_set); |
| matcher->control_dep_set = |
| ::testing::UnorderedElementsAreArray(*prop.control_deps()); |
| } |
| } |
| |
| return ::testing::MakeMatcher(matcher); |
| } |
| |
| impl::NodeMatcherProperties Name(string name) { |
| impl::NodeMatcherProperties props; |
| props.set_name(std::move(name)); |
| return props; |
| } |
| |
| // Matches a node with op `op`. |
| impl::NodeMatcherProperties Op(string op) { |
| impl::NodeMatcherProperties props; |
| props.set_op(std::move(op)); |
| return props; |
| } |
| |
| // Matches a node with assigned device `assigned_device`. |
| impl::NodeMatcherProperties AssignedDevice(string assigned_device) { |
| impl::NodeMatcherProperties props; |
| props.set_assigned_device(std::move(assigned_device)); |
| return props; |
| } |
| |
| impl::NodeMatcherProperties impl::Inputs( |
| absl::Span<const ::testing::Matcher<const Node*>> inputs) { |
| std::vector<::testing::Matcher<const Node*>> inputs_vector; |
| absl::c_copy(inputs, std::back_inserter(inputs_vector)); |
| |
| impl::NodeMatcherProperties props; |
| props.set_input_nodes(std::move(inputs_vector)); |
| return props; |
| } |
| |
| impl::NodeMatcherProperties impl::CtrlDeps( |
| absl::Span<const ::testing::Matcher<const Node*>> control_deps) { |
| std::vector<::testing::Matcher<const Node*>> control_deps_vector; |
| absl::c_copy(control_deps, std::back_inserter(control_deps_vector)); |
| |
| impl::NodeMatcherProperties props; |
| props.set_control_deps(std::move(control_deps_vector)); |
| return props; |
| } |
| |
| NodeMatcherProperties ConstantValue( |
| const ::tensorflow::Input::Initializer& val) { |
| TF_CHECK_OK(val.status); |
| NodeMatcherProperties props; |
| props.set_constant_value(val.tensor); |
| return props; |
| } |
| |
| ::testing::Matcher<const Node*> Const( |
| const ::tensorflow::Input::Initializer& val) { |
| return NodeWith(ConstantValue(val)); |
| } |
| } // namespace matchers |
| |
| Node* FindNodeByName(Graph* g, absl::string_view name) { |
| for (Node* n : g->nodes()) { |
| if (n->name() == name) { |
| return n; |
| } |
| } |
| |
| return nullptr; |
| } |
| } // namespace testing |
| } // namespace tensorflow |