[tf.data] Make the unit tests mode readable.

PiperOrigin-RevId: 449499809
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
index 64ad8db..d3b2a77 100644
--- a/tensorflow/core/framework/model.cc
+++ b/tensorflow/core/framework/model.cc
@@ -2337,7 +2337,7 @@
   ComputeTotalTimes(nodes);
 }
 
-const ModelTiming::NodeTiming* ModelTiming::GetTiming(Node* node) const {
+const ModelTiming::NodeTiming* ModelTiming::GetTiming(const Node* node) const {
   if (timing_nodes_.find(node) == timing_nodes_.end()) {
     return nullptr;
   }
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
index a6666d6..075b7d7 100644
--- a/tensorflow/core/framework/model.h
+++ b/tensorflow/core/framework/model.h
@@ -865,7 +865,7 @@
   explicit ModelTiming(std::shared_ptr<Model> model);
 
   // Returns the timing data for `node`.
-  const NodeTiming* GetTiming(Node* node) const;
+  const NodeTiming* GetTiming(const Node* node) const;
 
   // Returns the root nodes of all stages.
   std::vector<std::shared_ptr<Node>> GetStageRoots() const;
diff --git a/tensorflow/core/framework/model_test.cc b/tensorflow/core/framework/model_test.cc
index dbc062a..acf3dcf 100644
--- a/tensorflow/core/framework/model_test.cc
+++ b/tensorflow/core/framework/model_test.cc
@@ -17,6 +17,7 @@
 
 #include <memory>
 #include <string>
+#include <utility>
 
 #include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -1322,467 +1323,564 @@
 
 class ModelTimingTest : public ::testing::Test {
  public:
-  void RecordNumElements(std::shared_ptr<Node> node, int num_elements) {
-    for (int i = 0; i < num_elements; ++i) {
-      node->record_element();
+  // Computes the timing given a Model text proto.
+  void ComputeModelTiming(const std::string& model_pbtxt) {
+    ModelProto model_proto;
+    protobuf::TextFormat::ParseFromString(model_pbtxt, &model_proto);
+    std::unique_ptr<Model> model;
+    TF_CHECK_OK(Model::FromProto(model_proto, &model));
+    auto nodes =
+        model->CollectNodes(model->output(), TraversalOrder::BFS,
+                            [](const std::shared_ptr<Node>) { return true; });
+    node_map_.clear();
+    for (const auto& node : nodes) {
+      node_map_[node->id()] = node.get();
     }
+    model_timing_ = absl::make_unique<ModelTiming>(std::move(model));
   }
+
+  // Gets the timing information of a node given its id.
+  const ModelTiming::NodeTiming* GetNodeTiming(int64_t node_id) const {
+    return model_timing_->GetTiming(node_map_.at(node_id));
+  }
+
+ protected:
+  std::unique_ptr<ModelTiming> model_timing_;
+  absl::flat_hash_map<int64_t, const Node*> node_map_;
 };
 
 TEST_F(ModelTimingTest, Interleave) {
-  auto batch_1 = model::MakeKnownRatioNode(
-      {/*id=*/1, /*name=*/"Batch", /*output=*/nullptr}, /*ratio=*/1);
-  auto interleave = model::MakeInterleaveManyNode(
-      {/*id=*/2, /*name=*/"Interleave", /*output=*/batch_1},
-      {model::MakeParameter("cycle_length", nullptr, /*min=*/2, /*max=*/2)});
-  auto batch_2 = model::MakeKnownRatioNode({
-                                               /*id=*/3,
-                                               /*name=*/"Batch",
-                                               /*output=*/interleave,
-                                           },
-                                           /*ratio=*/1);
-  auto batch_3 = model::MakeKnownRatioNode({
-                                               /*id=*/4,
-                                               /*name=*/"Batch",
-                                               /*output=*/interleave,
-                                           },
-                                           /*ratio=*/1);
-  RecordNumElements(batch_1, 100);
-  batch_1->add_processing_time(1000);
-  RecordNumElements(interleave, 100);
-  interleave->add_processing_time(1000);
-  RecordNumElements(batch_2, 60);
-  batch_2->add_processing_time(1200);
-  RecordNumElements(batch_3, 40);
-  batch_3->add_processing_time(800);
+  ComputeModelTiming(R"pb(
+    nodes: {
+      key: 1
+      value: {
+        id: 1
+        name: "Batch"
+        autotune: true
+        num_elements: 100
+        processing_time: 1000
+        node_class: KNOWN_RATIO
+        ratio: 1
+        inputs: 2
+      }
+    }
+    nodes: {
+      key: 2
+      value: {
+        id: 2
+        name: "Interleave"
+        autotune: true
+        num_elements: 100
+        processing_time: 1000
+        node_class: INTERLEAVE_MANY
+        inputs: 3
+        inputs: 4
+        parameters: { name: "cycle_length" value: 2 tunable: false }
+      }
+    }
+    nodes: {
+      key: 3
+      value: {
+        id: 3
+        name: "Batch"
+        autotune: true
+        num_elements: 60
+        processing_time: 1200
+        node_class: KNOWN_RATIO
+        ratio: 1
+      }
+    }
+    nodes: {
+      key: 4
+      value: {
+        id: 4
+        name: "Batch"
+        autotune: true
+        num_elements: 40
+        processing_time: 800
+        node_class: KNOWN_RATIO
+        ratio: 1
+      }
+    }
+    output: 1
+  )pb");
 
-  std::shared_ptr<model::Model> model = std::make_shared<model::Model>();
-  model->AddNode([&batch_1](model::Node::Args args) { return batch_1; },
-                 "batch_1", nullptr, &batch_1);
-  model->AddNode([&interleave](model::Node::Args args) { return interleave; },
-                 "interleave", batch_1, &interleave);
-  model->AddNode([&batch_2](model::Node::Args args) { return batch_2; },
-                 "batch_2", interleave, &batch_2);
-  model->AddNode([&batch_3](model::Node::Args args) { return batch_3; },
-                 "batch_3", interleave, &batch_3);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/1)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/2)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(0.5, GetNodeTiming(/*node_id=*/3)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(0.5, GetNodeTiming(/*node_id=*/4)->pipeline_ratio);
 
-  ModelTiming model_timing(model);
-  EXPECT_NE(nullptr, model_timing.GetTiming(batch_1.get()));
-  EXPECT_NE(nullptr, model_timing.GetTiming(interleave.get()));
-  EXPECT_NE(nullptr, model_timing.GetTiming(batch_2.get()));
-  EXPECT_NE(nullptr, model_timing.GetTiming(batch_3.get()));
+  EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/1)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/2)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/3)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/4)->self_time_nsec);
 
-  EXPECT_DOUBLE_EQ(1, model_timing.GetTiming(batch_1.get())->pipeline_ratio);
-  EXPECT_DOUBLE_EQ(1, model_timing.GetTiming(interleave.get())->pipeline_ratio);
-  EXPECT_DOUBLE_EQ(0.5, model_timing.GetTiming(batch_2.get())->pipeline_ratio);
-  EXPECT_DOUBLE_EQ(0.5, model_timing.GetTiming(batch_3.get())->pipeline_ratio);
-
-  EXPECT_DOUBLE_EQ(10, model_timing.GetTiming(batch_1.get())->self_time_nsec);
-  EXPECT_DOUBLE_EQ(10,
-                   model_timing.GetTiming(interleave.get())->self_time_nsec);
-  EXPECT_DOUBLE_EQ(20, model_timing.GetTiming(batch_2.get())->self_time_nsec);
-  EXPECT_DOUBLE_EQ(20, model_timing.GetTiming(batch_3.get())->self_time_nsec);
-
-  EXPECT_DOUBLE_EQ(40, model_timing.GetTiming(batch_1.get())->total_time_nsec);
-  EXPECT_DOUBLE_EQ(30,
-                   model_timing.GetTiming(interleave.get())->total_time_nsec);
-  EXPECT_DOUBLE_EQ(20, model_timing.GetTiming(batch_2.get())->total_time_nsec);
-  EXPECT_DOUBLE_EQ(20, model_timing.GetTiming(batch_3.get())->total_time_nsec);
+  EXPECT_DOUBLE_EQ(40, GetNodeTiming(/*node_id=*/1)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(30, GetNodeTiming(/*node_id=*/2)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/3)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/4)->total_time_nsec);
 }
 
 TEST_F(ModelTimingTest, ParallelInterleave_DeterministicFalse) {
-  auto batch_1 = model::MakeKnownRatioNode(
-      {/*id=*/1, /*name=*/"Batch", /*output=*/nullptr}, /*ratio=*/1);
-  auto parallel_interleave = model::MakeAsyncInterleaveManyNode(
-      {/*id=*/2,
-       /*name=*/"ParallelInterleaveV4", /*output=*/batch_1},
-      {model::MakeParameter("parallelism",
-                            std::make_shared<SharedState>(
-                                /*value=*/2, nullptr, nullptr),
-                            /*min=*/1, /*max=*/10),
-       model::MakeParameter(kCycleLength, nullptr,
-                            /*min=*/2,
-                            /*max=*/2),
-       model::MakeNonTunableParameter(kDeterministic, /*value=*/0.0)});
-  auto first_input =
-      model::MakeKnownRatioNode({
-                                    /*id=*/3,
-                                    /*name=*/"Batch",
-                                    /*output=*/parallel_interleave,
-                                },
-                                /*ratio=*/1);
-  auto batch_2 = model::MakeKnownRatioNode({
-                                               /*id=*/3,
-                                               /*name=*/"Batch",
-                                               /*output=*/parallel_interleave,
-                                           },
-                                           /*ratio=*/1);
-  auto batch_3 = model::MakeKnownRatioNode({
-                                               /*id=*/4,
-                                               /*name=*/"Batch",
-                                               /*output=*/parallel_interleave,
-                                           },
-                                           /*ratio=*/1);
-  RecordNumElements(batch_1, 100);
-  batch_1->add_processing_time(1000);
-  RecordNumElements(parallel_interleave, 100);
-  parallel_interleave->add_processing_time(2000);
-  RecordNumElements(first_input, 60);
-  first_input->add_processing_time(60);
-  RecordNumElements(batch_2, 60);
-  batch_2->add_processing_time(1200);
-  RecordNumElements(batch_3, 40);
-  batch_3->add_processing_time(800);
+  ComputeModelTiming(R"pb(
+    nodes: {
+      key: 1
+      value: {
+        id: 1
+        name: "Batch"
+        autotune: true
+        num_elements: 100
+        processing_time: 1000
+        node_class: KNOWN_RATIO
+        ratio: 1
+        inputs: 2
+      }
+    }
+    nodes: {
+      key: 2
+      value: {
+        id: 2
+        name: "ParallelInterleaveV4"
+        autotune: true
+        num_elements: 100
+        processing_time: 2000
+        node_class: ASYNC_INTERLEAVE_MANY
+        inputs: 3
+        inputs: 4
+        inputs: 5
+        parameters: {
+          name: "parallelism"
+          value: 2
+          min: 1
+          max: 10
+          tunable: true
+        }
+        parameters: { name: "cycle_length" value: 2 tunable: false }
+        parameters: { name: "deterministic" value: 0 tunable: false }
+      }
+    }
+    nodes: {
+      key: 3
+      value: {
+        id: 3
+        name: "Batch"
+        autotune: true
+        num_elements: 60
+        processing_time: 60
+        node_class: KNOWN_RATIO
+        ratio: 1
+      }
+    }
+    nodes: {
+      key: 4
+      value: {
+        id: 4
+        name: "Batch"
+        autotune: true
+        num_elements: 60
+        processing_time: 1200
+        node_class: KNOWN_RATIO
+        ratio: 1
+      }
+    }
+    nodes: {
+      key: 5
+      value: {
+        id: 5
+        name: "Batch"
+        autotune: true
+        num_elements: 40
+        processing_time: 800
+        node_class: KNOWN_RATIO
+        ratio: 1
+      }
+    }
+    output: 1
+  )pb");
 
-  std::shared_ptr<model::Model> model = std::make_shared<model::Model>();
-  model->AddNode([&batch_1](model::Node::Args args) { return batch_1; },
-                 "batch_1", nullptr, &batch_1);
-  model->AddNode([&parallel_interleave](
-                     model::Node::Args args) { return parallel_interleave; },
-                 "parallel_interleave", batch_1, &parallel_interleave);
-  model->AddNode([&first_input](model::Node::Args args) { return first_input; },
-                 "first_input", parallel_interleave, &first_input);
-  model->AddNode([&batch_2](model::Node::Args args) { return batch_2; },
-                 "batch_2", parallel_interleave, &batch_2);
-  model->AddNode([&batch_3](model::Node::Args args) { return batch_3; },
-                 "batch_3", parallel_interleave, &batch_3);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/1)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/2)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(0.5, GetNodeTiming(/*node_id=*/3)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(0.5, GetNodeTiming(/*node_id=*/4)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(0.5, GetNodeTiming(/*node_id=*/5)->pipeline_ratio);
 
-  ModelTiming model_timing(model);
-  EXPECT_NE(nullptr, model_timing.GetTiming(batch_1.get()));
-  EXPECT_NE(nullptr, model_timing.GetTiming(parallel_interleave.get()));
-  EXPECT_NE(nullptr, model_timing.GetTiming(batch_2.get()));
-  EXPECT_NE(nullptr, model_timing.GetTiming(batch_3.get()));
+  EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/1)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/2)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/3)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/4)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/5)->self_time_nsec);
 
-  EXPECT_DOUBLE_EQ(1, model_timing.GetTiming(batch_1.get())->pipeline_ratio);
-  EXPECT_DOUBLE_EQ(
-      1, model_timing.GetTiming(parallel_interleave.get())->pipeline_ratio);
-  EXPECT_DOUBLE_EQ(0.5, model_timing.GetTiming(batch_2.get())->pipeline_ratio);
-  EXPECT_DOUBLE_EQ(0.5, model_timing.GetTiming(batch_3.get())->pipeline_ratio);
-
-  EXPECT_DOUBLE_EQ(10, model_timing.GetTiming(batch_1.get())->self_time_nsec);
-  EXPECT_DOUBLE_EQ(
-      10, model_timing.GetTiming(parallel_interleave.get())->self_time_nsec);
-  EXPECT_DOUBLE_EQ(20, model_timing.GetTiming(batch_2.get())->self_time_nsec);
-  EXPECT_DOUBLE_EQ(20, model_timing.GetTiming(batch_3.get())->self_time_nsec);
-
-  EXPECT_DOUBLE_EQ(10, model_timing.GetTiming(batch_1.get())->total_time_nsec);
-  EXPECT_DOUBLE_EQ(
-      20, model_timing.GetTiming(parallel_interleave.get())->total_time_nsec);
-  EXPECT_DOUBLE_EQ(20, model_timing.GetTiming(batch_2.get())->total_time_nsec);
-  EXPECT_DOUBLE_EQ(20, model_timing.GetTiming(batch_3.get())->total_time_nsec);
+  EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/1)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/2)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/3)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/4)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/5)->total_time_nsec);
 }
 
 TEST_F(ModelTimingTest, ParallelInterleave_DeterministicTrue) {
-  auto batch_1 = model::MakeKnownRatioNode(
-      {/*id=*/1, /*name=*/"Batch", /*output=*/nullptr}, /*ratio=*/1);
-  auto parallel_interleave = model::MakeAsyncInterleaveManyNode(
-      {/*id=*/2,
-       /*name=*/"ParallelInterleaveV4", /*output=*/batch_1},
-      {model::MakeParameter("parallelism",
-                            std::make_shared<SharedState>(
-                                /*value=*/2, nullptr, nullptr),
-                            /*min=*/1, /*max=*/10),
-       model::MakeParameter(kCycleLength, nullptr,
-                            /*min=*/2,
-                            /*max=*/2),
-       model::MakeNonTunableParameter(kDeterministic, /*value=*/1.0)});
-  auto first_input =
-      model::MakeKnownRatioNode({
-                                    /*id=*/3,
-                                    /*name=*/"Batch",
-                                    /*output=*/parallel_interleave,
-                                },
-                                /*ratio=*/1);
-  auto batch_2 = model::MakeKnownRatioNode({
-                                               /*id=*/3,
-                                               /*name=*/"Batch",
-                                               /*output=*/parallel_interleave,
-                                           },
-                                           /*ratio=*/1);
-  auto batch_3 = model::MakeKnownRatioNode({
-                                               /*id=*/4,
-                                               /*name=*/"Batch",
-                                               /*output=*/parallel_interleave,
-                                           },
-                                           /*ratio=*/1);
-  RecordNumElements(batch_1, 100);
-  batch_1->add_processing_time(1000);
-  RecordNumElements(parallel_interleave, 100);
-  parallel_interleave->add_processing_time(2000);
-  RecordNumElements(first_input, 60);
-  first_input->add_processing_time(60);
-  RecordNumElements(batch_2, 60);
-  batch_2->add_processing_time(1200);
-  RecordNumElements(batch_3, 40);
-  batch_3->add_processing_time(1600);
+  ComputeModelTiming(R"pb(
+    nodes: {
+      key: 1
+      value: {
+        id: 1
+        name: "Batch"
+        autotune: true
+        num_elements: 100
+        processing_time: 1000
+        node_class: KNOWN_RATIO
+        ratio: 1
+        inputs: 2
+      }
+    }
+    nodes: {
+      key: 2
+      value: {
+        id: 2
+        name: "ParallelInterleaveV4"
+        autotune: true
+        num_elements: 100
+        processing_time: 2000
+        node_class: ASYNC_INTERLEAVE_MANY
+        inputs: 3
+        inputs: 4
+        inputs: 5
+        parameters: {
+          name: "parallelism"
+          value: 2
+          min: 1
+          max: 10
+          tunable: true
+        }
+        parameters: { name: "cycle_length" value: 2 tunable: false }
+        parameters: { name: "deterministic" value: 1 tunable: false }
+      }
+    }
+    nodes: {
+      key: 3
+      value: {
+        id: 3
+        name: "Batch"
+        autotune: true
+        num_elements: 60
+        processing_time: 60
+        node_class: KNOWN_RATIO
+        ratio: 1
+      }
+    }
+    nodes: {
+      key: 4
+      value: {
+        id: 4
+        name: "Batch"
+        autotune: true
+        num_elements: 60
+        processing_time: 1200
+        node_class: KNOWN_RATIO
+        ratio: 1
+      }
+    }
+    nodes: {
+      key: 5
+      value: {
+        id: 5
+        name: "Batch"
+        autotune: true
+        num_elements: 40
+        processing_time: 1600
+        node_class: KNOWN_RATIO
+        ratio: 1
+      }
+    }
+    output: 1
+  )pb");
 
-  std::shared_ptr<model::Model> model = std::make_shared<model::Model>();
-  model->AddNode([&batch_1](model::Node::Args args) { return batch_1; },
-                 "batch_1", nullptr, &batch_1);
-  model->AddNode([&parallel_interleave](
-                     model::Node::Args args) { return parallel_interleave; },
-                 "parallel_interleave", batch_1, &parallel_interleave);
-  model->AddNode([&first_input](model::Node::Args args) { return first_input; },
-                 "first_input", parallel_interleave, &first_input);
-  model->AddNode([&batch_2](model::Node::Args args) { return batch_2; },
-                 "batch_2", parallel_interleave, &batch_2);
-  model->AddNode([&batch_3](model::Node::Args args) { return batch_3; },
-                 "batch_3", parallel_interleave, &batch_3);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/1)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/2)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(0.5, GetNodeTiming(/*node_id=*/3)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(0.5, GetNodeTiming(/*node_id=*/4)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(0.5, GetNodeTiming(/*node_id=*/5)->pipeline_ratio);
 
-  ModelTiming model_timing(model);
-  EXPECT_NE(nullptr, model_timing.GetTiming(batch_1.get()));
-  EXPECT_NE(nullptr, model_timing.GetTiming(parallel_interleave.get()));
-  EXPECT_NE(nullptr, model_timing.GetTiming(batch_2.get()));
-  EXPECT_NE(nullptr, model_timing.GetTiming(batch_3.get()));
+  EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/1)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/2)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/3)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/4)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(40, GetNodeTiming(/*node_id=*/5)->self_time_nsec);
 
-  EXPECT_DOUBLE_EQ(1, model_timing.GetTiming(batch_1.get())->pipeline_ratio);
-  EXPECT_DOUBLE_EQ(
-      1, model_timing.GetTiming(parallel_interleave.get())->pipeline_ratio);
-  EXPECT_DOUBLE_EQ(0.5, model_timing.GetTiming(batch_2.get())->pipeline_ratio);
-  EXPECT_DOUBLE_EQ(0.5, model_timing.GetTiming(batch_3.get())->pipeline_ratio);
-
-  EXPECT_DOUBLE_EQ(10, model_timing.GetTiming(batch_1.get())->self_time_nsec);
-  EXPECT_DOUBLE_EQ(
-      10, model_timing.GetTiming(parallel_interleave.get())->self_time_nsec);
-  EXPECT_DOUBLE_EQ(20, model_timing.GetTiming(batch_2.get())->self_time_nsec);
-  EXPECT_DOUBLE_EQ(40, model_timing.GetTiming(batch_3.get())->self_time_nsec);
-
-  EXPECT_DOUBLE_EQ(10, model_timing.GetTiming(batch_1.get())->total_time_nsec);
-  EXPECT_DOUBLE_EQ(
-      30, model_timing.GetTiming(parallel_interleave.get())->total_time_nsec);
-  EXPECT_DOUBLE_EQ(20, model_timing.GetTiming(batch_2.get())->total_time_nsec);
-  EXPECT_DOUBLE_EQ(40, model_timing.GetTiming(batch_3.get())->total_time_nsec);
+  EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/1)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(30, GetNodeTiming(/*node_id=*/2)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/3)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/4)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(40, GetNodeTiming(/*node_id=*/5)->total_time_nsec);
 }
 
 TEST_F(ModelTimingTest, ParallelInterleave_CycleLength) {
-  auto batch_1 = model::MakeKnownRatioNode(
-      {/*id=*/1, /*name=*/"Batch", /*output=*/nullptr}, /*ratio=*/1);
-  auto parallel_interleave = model::MakeAsyncInterleaveManyNode(
-      {/*id=*/2,
-       /*name=*/"ParallelInterleaveV4", /*output=*/batch_1},
-      {model::MakeParameter("parallelism",
-                            std::make_shared<SharedState>(
-                                /*value=*/2, nullptr, nullptr),
-                            /*min=*/1, /*max=*/10),
-       model::MakeParameter("cycle_length",
-                            std::make_shared<SharedState>(
-                                /*value=*/1, nullptr, nullptr),
-                            /*min=*/2, /*max=*/2),
-       model::MakeNonTunableParameter(kDeterministic, /*value=*/0.0)});
-  auto first_input =
-      model::MakeKnownRatioNode({
-                                    /*id=*/3,
-                                    /*name=*/"Batch",
-                                    /*output=*/parallel_interleave,
-                                },
-                                /*ratio=*/1);
-  auto batch_2 = model::MakeKnownRatioNode({
-                                               /*id=*/3,
-                                               /*name=*/"Batch",
-                                               /*output=*/parallel_interleave,
-                                           },
-                                           /*ratio=*/1);
-  auto batch_3 = model::MakeKnownRatioNode({
-                                               /*id=*/4,
-                                               /*name=*/"Batch",
-                                               /*output=*/parallel_interleave,
-                                           },
-                                           /*ratio=*/1);
-  auto batch_4 = model::MakeKnownRatioNode({
-                                               /*id=*/5,
-                                               /*name=*/"Batch",
-                                               /*output=*/parallel_interleave,
-                                           },
-                                           /*ratio=*/1);
-  auto batch_5 = model::MakeKnownRatioNode({
-                                               /*id=*/6,
-                                               /*name=*/"Batch",
-                                               /*output=*/parallel_interleave,
-                                           },
-                                           /*ratio=*/1);
-  RecordNumElements(batch_1, 100);
-  batch_1->add_processing_time(1000);
-  RecordNumElements(parallel_interleave, 100);
-  parallel_interleave->add_processing_time(2000);
-  RecordNumElements(first_input, 60);
-  first_input->add_processing_time(60);
-  RecordNumElements(batch_2, 60);
-  batch_2->add_processing_time(1200);
-  RecordNumElements(batch_3, 40);
-  batch_3->add_processing_time(800);
-  RecordNumElements(batch_4, 60);
-  batch_4->add_processing_time(1200);
-  RecordNumElements(batch_5, 40);
-  batch_5->add_processing_time(800);
-  batch_4->set_autotune(false);
-  batch_5->set_autotune(false);
+  ComputeModelTiming(R"pb(
+    nodes: {
+      key: 1
+      value: {
+        id: 1
+        name: "Batch"
+        autotune: true
+        num_elements: 100
+        processing_time: 1000
+        node_class: KNOWN_RATIO
+        ratio: 1
+        inputs: 2
+      }
+    }
+    nodes: {
+      key: 2
+      value: {
+        id: 2
+        name: "ParallelInterleaveV4"
+        autotune: true
+        num_elements: 100
+        processing_time: 2000
+        node_class: ASYNC_INTERLEAVE_MANY
+        inputs: 3
+        inputs: 4
+        inputs: 5
+        inputs: 6
+        inputs: 7
+        parameters: {
+          name: "parallelism"
+          value: 2
+          min: 1
+          max: 10
+          tunable: true
+        }
+        parameters: { name: "cycle_length" value: 1 tunable: false }
+        parameters: { name: "deterministic" value: 0 tunable: false }
+      }
+    }
+    nodes: {
+      key: 3
+      value: {
+        id: 3
+        name: "Batch"
+        autotune: true
+        num_elements: 60
+        processing_time: 60
+        node_class: KNOWN_RATIO
+        ratio: 1
+      }
+    }
+    nodes: {
+      key: 4
+      value: {
+        id: 4
+        name: "Batch"
+        autotune: true
+        num_elements: 60
+        processing_time: 1200
+        node_class: KNOWN_RATIO
+        ratio: 1
+      }
+    }
+    nodes: {
+      key: 5
+      value: {
+        id: 5
+        name: "Batch"
+        autotune: true
+        num_elements: 40
+        processing_time: 800
+        node_class: KNOWN_RATIO
+        ratio: 1
+      }
+    }
+    nodes: {
+      key: 6
+      value: {
+        id: 6
+        name: "Batch"
+        autotune: false
+        num_elements: 60
+        processing_time: 1200
+        node_class: KNOWN_RATIO
+        ratio: 1
+      }
+    }
+    nodes: {
+      key: 7
+      value: {
+        id: 7
+        name: "Batch"
+        autotune: false
+        num_elements: 40
+        processing_time: 800
+        node_class: KNOWN_RATIO
+        ratio: 1
+      }
+    }
+    output: 1
+  )pb");
 
-  std::shared_ptr<model::Model> model = std::make_shared<model::Model>();
-  model->AddNode([&batch_1](model::Node::Args args) { return batch_1; },
-                 "batch_1", nullptr, &batch_1);
-  model->AddNode([&parallel_interleave](
-                     model::Node::Args args) { return parallel_interleave; },
-                 "parallel_interleave", batch_1, &parallel_interleave);
-  model->AddNode([&first_input](model::Node::Args args) { return first_input; },
-                 "first_input", parallel_interleave, &first_input);
-  model->AddNode([&batch_2](model::Node::Args args) { return batch_2; },
-                 "batch_2", parallel_interleave, &batch_2);
-  model->AddNode([&batch_3](model::Node::Args args) { return batch_3; },
-                 "batch_3", parallel_interleave, &batch_3);
-  model->AddNode([&batch_4](model::Node::Args args) { return batch_4; },
-                 "batch_4", parallel_interleave, &batch_4);
-  model->AddNode([&batch_5](model::Node::Args args) { return batch_5; },
-                 "batch_5", parallel_interleave, &batch_5);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/1)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/2)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/3)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/4)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/5)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(0, GetNodeTiming(/*node_id=*/6)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(0, GetNodeTiming(/*node_id=*/7)->pipeline_ratio);
 
-  ModelTiming model_timing(model);
-  EXPECT_NE(nullptr, model_timing.GetTiming(batch_1.get()));
-  EXPECT_NE(nullptr, model_timing.GetTiming(parallel_interleave.get()));
-  EXPECT_NE(nullptr, model_timing.GetTiming(batch_2.get()));
-  EXPECT_NE(nullptr, model_timing.GetTiming(batch_3.get()));
+  EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/1)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/2)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/3)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/4)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/5)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/6)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/7)->self_time_nsec);
 
-  EXPECT_DOUBLE_EQ(1.0, model_timing.GetTiming(batch_1.get())->pipeline_ratio);
-  EXPECT_DOUBLE_EQ(
-      1.0, model_timing.GetTiming(parallel_interleave.get())->pipeline_ratio);
-  EXPECT_DOUBLE_EQ(1.0, model_timing.GetTiming(batch_2.get())->pipeline_ratio);
-  EXPECT_DOUBLE_EQ(1.0, model_timing.GetTiming(batch_3.get())->pipeline_ratio);
-  EXPECT_DOUBLE_EQ(0.0, model_timing.GetTiming(batch_4.get())->pipeline_ratio);
-  EXPECT_DOUBLE_EQ(0.0, model_timing.GetTiming(batch_5.get())->pipeline_ratio);
-
-  EXPECT_DOUBLE_EQ(10, model_timing.GetTiming(batch_1.get())->self_time_nsec);
-  EXPECT_DOUBLE_EQ(
-      10, model_timing.GetTiming(parallel_interleave.get())->self_time_nsec);
-  EXPECT_DOUBLE_EQ(20, model_timing.GetTiming(batch_2.get())->self_time_nsec);
-  EXPECT_DOUBLE_EQ(20, model_timing.GetTiming(batch_3.get())->self_time_nsec);
-  EXPECT_DOUBLE_EQ(20, model_timing.GetTiming(batch_4.get())->self_time_nsec);
-  EXPECT_DOUBLE_EQ(20, model_timing.GetTiming(batch_5.get())->self_time_nsec);
-
-  EXPECT_DOUBLE_EQ(10, model_timing.GetTiming(batch_1.get())->total_time_nsec);
-  EXPECT_DOUBLE_EQ(
-      20, model_timing.GetTiming(parallel_interleave.get())->total_time_nsec);
-  EXPECT_DOUBLE_EQ(20, model_timing.GetTiming(batch_2.get())->total_time_nsec);
-  EXPECT_DOUBLE_EQ(20, model_timing.GetTiming(batch_3.get())->total_time_nsec);
-  EXPECT_DOUBLE_EQ(0, model_timing.GetTiming(batch_4.get())->total_time_nsec);
-  EXPECT_DOUBLE_EQ(0, model_timing.GetTiming(batch_5.get())->total_time_nsec);
+  EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/1)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/2)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/3)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/4)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/5)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(0, GetNodeTiming(/*node_id=*/6)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(0, GetNodeTiming(/*node_id=*/7)->total_time_nsec);
 }
 
 TEST_F(ModelTimingTest, ParallelInterleave_Batch_ParallelMap) {
-  auto batch_1 = model::MakeKnownRatioNode(
-      {/*id=*/1, /*name=*/"Batch", /*output=*/nullptr}, /*ratio=*/1);
-  auto parallel_interleave = model::MakeAsyncInterleaveManyNode(
-      {/*id=*/2,
-       /*name=*/"ParallelInterleaveV4", /*output=*/batch_1},
-      {model::MakeParameter("parallelism",
-                            std::make_shared<SharedState>(
-                                /*value=*/2, nullptr, nullptr),
-                            /*min=*/1, /*max=*/10),
-       model::MakeParameter(kCycleLength, nullptr,
-                            /*min=*/2,
-                            /*max=*/2),
-       model::MakeNonTunableParameter(kDeterministic, /*value=*/0.0)});
-  auto first_input =
-      model::MakeKnownRatioNode({
-                                    /*id=*/3,
-                                    /*name=*/"Batch",
-                                    /*output=*/parallel_interleave,
-                                },
-                                /*ratio=*/1);
-  auto batch_2 = model::MakeKnownRatioNode({
-                                               /*id=*/3,
-                                               /*name=*/"Batch",
-                                               /*output=*/parallel_interleave,
-                                           },
-                                           /*ratio=*/2);
-  auto batch_3 = model::MakeKnownRatioNode({
-                                               /*id=*/4,
-                                               /*name=*/"Batch",
-                                               /*output=*/parallel_interleave,
-                                           },
-                                           /*ratio=*/2);
-  std::shared_ptr<Node> parallel_map_1 = model::MakeAsyncKnownRatioNode(
-      {/*id=*/1, /*name=*/"ParallelMapV2", /*output=*/batch_2}, /*ratio=*/1,
-      {model::MakeParameter(
-          "parallelism",
-          std::make_shared<SharedState>(/*value=*/2, nullptr, nullptr),
-          /*min=*/1,
-          /*max=*/16)});
-  std::shared_ptr<Node> parallel_map_2 = model::MakeAsyncKnownRatioNode(
-      {/*id=*/1, /*name=*/"ParallelMapV2", /*output=*/batch_3}, /*ratio=*/1,
-      {model::MakeParameter(
-          "parallelism",
-          std::make_shared<SharedState>(/*value=*/2, nullptr, nullptr),
-          /*min=*/1,
-          /*max=*/16)});
-  RecordNumElements(batch_1, 100);
-  batch_1->add_processing_time(1000);
-  RecordNumElements(parallel_interleave, 100);
-  parallel_interleave->add_processing_time(2000);
-  RecordNumElements(first_input, 60);
-  first_input->add_processing_time(60);
-  RecordNumElements(batch_2, 60);
-  batch_2->add_processing_time(1200);
-  RecordNumElements(batch_3, 40);
-  batch_3->add_processing_time(800);
-  RecordNumElements(parallel_map_1, 120);
-  parallel_map_1->add_processing_time(2400);
-  RecordNumElements(parallel_map_2, 120);
-  parallel_map_2->add_processing_time(2400);
+  ComputeModelTiming(R"pb(
+    nodes: {
+      key: 1
+      value: {
+        id: 1
+        name: "Batch"
+        autotune: true
+        num_elements: 100
+        processing_time: 1000
+        node_class: KNOWN_RATIO
+        ratio: 1
+        inputs: 2
+      }
+    }
+    nodes: {
+      key: 2
+      value: {
+        id: 2
+        name: "ParallelInterleaveV4"
+        autotune: true
+        num_elements: 100
+        processing_time: 2000
+        node_class: ASYNC_INTERLEAVE_MANY
+        inputs: 3
+        inputs: 4
+        inputs: 5
+        parameters: {
+          name: "parallelism"
+          value: 2
+          min: 1
+          max: 10
+          tunable: true
+        }
+        parameters: { name: "cycle_length" value: 2 tunable: false }
+        parameters: { name: "deterministic" value: 0 tunable: false }
+      }
+    }
+    nodes: {
+      key: 3
+      value: {
+        id: 3
+        name: "Batch"
+        autotune: true
+        num_elements: 60
+        processing_time: 60
+        node_class: KNOWN_RATIO
+        ratio: 1
+      }
+    }
+    nodes: {
+      key: 4
+      value: {
+        id: 4
+        name: "Batch"
+        autotune: true
+        num_elements: 60
+        processing_time: 1200
+        node_class: KNOWN_RATIO
+        ratio: 2
+        inputs: 6
+      }
+    }
+    nodes: {
+      key: 5
+      value: {
+        id: 5
+        name: "Batch"
+        autotune: true
+        num_elements: 40
+        processing_time: 800
+        node_class: KNOWN_RATIO
+        ratio: 2
+        inputs: 7
+      }
+    }
+    nodes: {
+      key: 6
+      value: {
+        id: 6
+        name: "ParallelMapV2"
+        autotune: true
+        num_elements: 120
+        processing_time: 2400
+        node_class: ASYNC_KNOWN_RATIO
+        ratio: 1
+        parameters: {
+          name: "parallelism"
+          value: 2
+          min: 1
+          max: 16
+          tunable: true
+        }
+      }
+    }
+    nodes: {
+      key: 7
+      value: {
+        id: 7
+        name: "ParallelMapV2"
+        autotune: true
+        num_elements: 120
+        processing_time: 2400
+        node_class: ASYNC_KNOWN_RATIO
+        ratio: 1
+        parameters: {
+          name: "parallelism"
+          value: 2
+          min: 1
+          max: 16
+          tunable: true
+        }
+      }
+    }
+    output: 1
+  )pb");
 
-  std::shared_ptr<model::Model> model = std::make_shared<model::Model>();
-  model->AddNode([&batch_1](model::Node::Args args) { return batch_1; },
-                 "batch_1", nullptr, &batch_1);
-  model->AddNode([&parallel_interleave](
-                     model::Node::Args args) { return parallel_interleave; },
-                 "parallel_interleave", batch_1, &parallel_interleave);
-  model->AddNode([&first_input](model::Node::Args args) { return first_input; },
-                 "first_input", parallel_interleave, &first_input);
-  model->AddNode([&batch_2](model::Node::Args args) { return batch_2; },
-                 "batch_2", parallel_interleave, &batch_2);
-  model->AddNode([&batch_3](model::Node::Args args) { return batch_3; },
-                 "batch_3", parallel_interleave, &batch_3);
-  model->AddNode(
-      [&parallel_map_1](model::Node::Args args) { return parallel_map_1; },
-      "parallel_map_1", batch_2, &parallel_map_1);
-  model->AddNode(
-      [&parallel_map_2](model::Node::Args args) { return parallel_map_2; },
-      "parallel_map_2", batch_3, &parallel_map_2);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/1)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/2)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(0.5, GetNodeTiming(/*node_id=*/3)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(0.5, GetNodeTiming(/*node_id=*/4)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(0.5, GetNodeTiming(/*node_id=*/5)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/6)->pipeline_ratio);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/7)->pipeline_ratio);
 
-  ModelTiming model_timing(model);
-  EXPECT_NE(nullptr, model_timing.GetTiming(batch_1.get()));
-  EXPECT_NE(nullptr, model_timing.GetTiming(parallel_interleave.get()));
-  EXPECT_NE(nullptr, model_timing.GetTiming(batch_2.get()));
-  EXPECT_NE(nullptr, model_timing.GetTiming(batch_3.get()));
-  EXPECT_NE(nullptr, model_timing.GetTiming(parallel_map_1.get()));
-  EXPECT_NE(nullptr, model_timing.GetTiming(parallel_map_2.get()));
+  EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/1)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/2)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/3)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/4)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/5)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/6)->self_time_nsec);
+  EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/7)->self_time_nsec);
 
-  EXPECT_DOUBLE_EQ(1, model_timing.GetTiming(batch_1.get())->pipeline_ratio);
-  EXPECT_DOUBLE_EQ(
-      1, model_timing.GetTiming(parallel_interleave.get())->pipeline_ratio);
-  EXPECT_DOUBLE_EQ(0.5, model_timing.GetTiming(batch_2.get())->pipeline_ratio);
-  EXPECT_DOUBLE_EQ(0.5, model_timing.GetTiming(batch_3.get())->pipeline_ratio);
-  EXPECT_DOUBLE_EQ(
-      1, model_timing.GetTiming(parallel_map_1.get())->pipeline_ratio);
-  EXPECT_DOUBLE_EQ(
-      1, model_timing.GetTiming(parallel_map_2.get())->pipeline_ratio);
-
-  EXPECT_DOUBLE_EQ(10, model_timing.GetTiming(batch_1.get())->self_time_nsec);
-  EXPECT_DOUBLE_EQ(
-      10, model_timing.GetTiming(parallel_interleave.get())->self_time_nsec);
-  EXPECT_DOUBLE_EQ(20, model_timing.GetTiming(batch_2.get())->self_time_nsec);
-  EXPECT_DOUBLE_EQ(20, model_timing.GetTiming(batch_3.get())->self_time_nsec);
-  EXPECT_DOUBLE_EQ(
-      10, model_timing.GetTiming(parallel_map_1.get())->self_time_nsec);
-  EXPECT_DOUBLE_EQ(
-      10, model_timing.GetTiming(parallel_map_2.get())->self_time_nsec);
-
-  EXPECT_DOUBLE_EQ(10, model_timing.GetTiming(batch_1.get())->total_time_nsec);
-  EXPECT_DOUBLE_EQ(
-      20, model_timing.GetTiming(parallel_interleave.get())->total_time_nsec);
-  EXPECT_DOUBLE_EQ(20, model_timing.GetTiming(batch_2.get())->total_time_nsec);
-  EXPECT_DOUBLE_EQ(20, model_timing.GetTiming(batch_3.get())->total_time_nsec);
-  EXPECT_DOUBLE_EQ(
-      10, model_timing.GetTiming(parallel_map_1.get())->total_time_nsec);
-  EXPECT_DOUBLE_EQ(
-      10, model_timing.GetTiming(parallel_map_2.get())->total_time_nsec);
+  EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/1)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/2)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/3)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/4)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(20, GetNodeTiming(/*node_id=*/5)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/6)->total_time_nsec);
+  EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/7)->total_time_nsec);
 }
 
 TEST(ModelTest, ModelMetrics) {