| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // Provides a set of matchers for tensorflow nodes. |
| // |
| // Example usage: |
| // |
| // tensorflow::Node* node = ...; |
| // EXPECT_THAT(node, NodeWith(Name("name"), Op("op"), |
| // Inputs(NodeWith(Name("input"))))) |
| // |
| // Matchable node properties (the expressions that go inside NodeWith(...)) |
| // are: |
| // |
// - Name(string): matches the node name exactly. This will probably need to
//   accept a string matcher in the future.
| // |
| // - Op(string): matches the op exactly. |
| // |
| // - AssignedDevice(string): matches the assigned device exactly. |
| // |
| // - Inputs(<ordered list>): matches the list of non-control inputs to the node |
| // exactly (i.e. does not match a suffix or a prefix). |
| // |
| // - CtrlDeps(<unordered list>): matches the list of control dependences on the |
| // node exactly but in any order. |
| // |
| // - ConstantValue(tensorflow::Input::Initializer init): matches a Const node |
| // with the constant value `init`. Implies Op("Const"). |
| // |
| // Node properties may not be repeated in a single NodeWith(...) matcher. |
| // E.g. NodeWith(Op("Foo"), Op("Bar")) will CHECK-fail. Since ConstantValue |
| // implies Op("Const"), a single NodeWith matcher can't have both |
| // ConstantValue(...) and Op(...). |
| |
| #ifndef TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ |
| #define TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ |
| |
| #include <array> |
| #include <string> |
| #include <vector> |
| |
| #include "absl/algorithm/container.h" |
| #include "absl/strings/string_view.h" |
| #include "absl/types/optional.h" |
| #include "absl/types/span.h" |
| #include "tensorflow/cc/framework/ops.h" |
| #include "tensorflow/compiler/xla/test.h" |
| #include "tensorflow/core/graph/graph.h" |
| |
| namespace tensorflow { |
| namespace testing { |
| namespace matchers { |
| |
| namespace impl { |
| |
| // ----------------------------------------------------------------------------- |
| // Implementation details. |
| |
| // Properties that we match on for a particular Node. If a particular property |
| // is nullopt then any value for it is allowed. |
| class NodeMatcherProperties { |
| public: |
| using NodeSeqMatcher = std::vector<::testing::Matcher<const Node*>>; |
| |
| const absl::optional<string>& name() const { return name_; } |
| const absl::optional<string>& op() const { return op_; } |
| const absl::optional<string>& assigned_device() const { |
| return assigned_device_; |
| } |
| const absl::optional<Tensor>& constant_value() const { |
| return constant_value_; |
| } |
| const absl::optional<NodeSeqMatcher>& input_nodes() const { |
| return input_nodes_; |
| } |
| const absl::optional<NodeSeqMatcher>& control_deps() const { |
| return control_deps_; |
| } |
| |
| void set_name(string name) { |
| DCHECK(IsEmpty()); |
| name_ = std::move(name); |
| } |
| |
| void set_op(string op) { |
| DCHECK(IsEmpty()); |
| op_ = std::move(op); |
| } |
| |
| void set_assigned_device(string assigned_device) { |
| DCHECK(IsEmpty()); |
| assigned_device_ = std::move(assigned_device); |
| } |
| |
| void set_constant_value(Tensor constant_value) { |
| DCHECK(IsEmpty()); |
| constant_value_ = std::move(constant_value); |
| op_ = "Const"; |
| } |
| |
| void set_input_nodes(NodeSeqMatcher input_nodes) { |
| DCHECK(IsEmpty()); |
| input_nodes_ = std::move(input_nodes); |
| } |
| |
| void set_control_deps(NodeSeqMatcher control_deps) { |
| DCHECK(IsEmpty()); |
| control_deps_ = std::move(control_deps); |
| } |
| |
| bool IsEmpty() const { |
| return !name().has_value() && !op().has_value() && |
| !input_nodes().has_value() && !control_deps().has_value(); |
| } |
| |
| private: |
| absl::optional<string> name_; |
| absl::optional<string> op_; |
| absl::optional<string> assigned_device_; |
| absl::optional<Tensor> constant_value_; |
| absl::optional<NodeSeqMatcher> input_nodes_; |
| absl::optional<NodeSeqMatcher> control_deps_; |
| }; |
| |
| ::testing::Matcher<const Node*> NodeWith( |
| absl::Span<const NodeMatcherProperties> props); |
| |
| impl::NodeMatcherProperties Inputs( |
| absl::Span<const ::testing::Matcher<const Node*>> inputs); |
| |
| impl::NodeMatcherProperties CtrlDeps( |
| absl::Span<const ::testing::Matcher<const Node*>> control_deps); |
| } // namespace impl |
| |
| // ----------------------------------------------------------------------------- |
| // Public interface. |
| |
| // Matches a node with name `name`. |
| impl::NodeMatcherProperties Name(string name); |
| |
| // Matches a node with op `op`. |
| impl::NodeMatcherProperties Op(string op); |
| |
| // Matches a node with assigned device `assigned_device`. |
| impl::NodeMatcherProperties AssignedDevice(string assigned_device); |
| |
| // Matches a node with inputs `inputs`. |
| // |
| // `inputs` are ordered; `inputs`[i] must match input i. |
| template <typename... Ts> |
| impl::NodeMatcherProperties Inputs(Ts... inputs) { |
| return impl::Inputs({inputs...}); |
| } |
| |
| // Matches a node with control dependences `control_deps`. |
| // |
| // `control_deps` are unordered and will match the control deps of a node in any |
| // order. |
| template <typename... Ts> |
| impl::NodeMatcherProperties CtrlDeps(Ts... control_deps) { |
| return impl::CtrlDeps({control_deps...}); |
| } |
| |
| // Matches a constant node with value `val`. |
| impl::NodeMatcherProperties ConstantValue( |
| const ::tensorflow::Input::Initializer& val); |
| |
| // The main gmock matcher. See file comment for example usage. |
| template <typename... Ts> |
| ::testing::Matcher<const Node*> NodeWith(Ts... args) { |
| std::array<impl::NodeMatcherProperties, sizeof...(Ts)> array = {args...}; |
| return impl::NodeWith(array); |
| } |
| |
| ::testing::Matcher<const Node*> Const( |
| const ::tensorflow::Input::Initializer& val); |
| } // namespace matchers |
| |
| // If `g` has a node named `name` returns it, otherwise returns null. |
| Node* FindNodeByName(Graph* g, absl::string_view name); |
| } // namespace testing |
| } // namespace tensorflow |
| |
| #endif // TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ |