Merge "Packet: Shard pybind11 Python binding generation for faster compilation" am: 9bc9f7850b
am: c78ae4e586

Change-Id: I9f956df3f301f8a2c2a2feb00406dc70987b3e92
diff --git a/gd/Android.bp b/gd/Android.bp
index 19eaec3..6aaf2df 100644
--- a/gd/Android.bp
+++ b/gd/Android.bp
@@ -363,7 +363,7 @@
     tools: [
         "bluetooth_packetgen",
     ],
-    cmd: "$(location bluetooth_packetgen) --include=system/bt/gd --out=$(genDir) $(in)",
+    cmd: "$(location bluetooth_packetgen) --include=system/bt/gd --out=$(genDir) --num_shards=5 $(in)",
     srcs: [
         "hci/hci_packets.pdl",
         "l2cap/l2cap_packets.pdl",
@@ -371,8 +371,23 @@
     ],
     out: [
         "hci/hci_packets_python3.cc",
+        "hci/hci_packets_python3_shard_0.cc",
+        "hci/hci_packets_python3_shard_1.cc",
+        "hci/hci_packets_python3_shard_2.cc",
+        "hci/hci_packets_python3_shard_3.cc",
+        "hci/hci_packets_python3_shard_4.cc",
         "l2cap/l2cap_packets_python3.cc",
+        "l2cap/l2cap_packets_python3_shard_0.cc",
+        "l2cap/l2cap_packets_python3_shard_1.cc",
+        "l2cap/l2cap_packets_python3_shard_2.cc",
+        "l2cap/l2cap_packets_python3_shard_3.cc",
+        "l2cap/l2cap_packets_python3_shard_4.cc",
         "security/smp_packets_python3.cc",
+        "security/smp_packets_python3_shard_0.cc",
+        "security/smp_packets_python3_shard_1.cc",
+        "security/smp_packets_python3_shard_2.cc",
+        "security/smp_packets_python3_shard_3.cc",
+        "security/smp_packets_python3_shard_4.cc",
     ],
 }
 
diff --git a/gd/hci/address.h b/gd/hci/address.h
index a8f515c..3bc507f 100644
--- a/gd/hci/address.h
+++ b/gd/hci/address.h
@@ -18,6 +18,7 @@
 
 #pragma once
 
+#include <cstring>
 #include <string>
 
 namespace bluetooth {
diff --git a/gd/packet/iterator.h b/gd/packet/iterator.h
index f15e561..8d927a0 100644
--- a/gd/packet/iterator.h
+++ b/gd/packet/iterator.h
@@ -18,6 +18,7 @@
 
 #include <cstdint>
 #include <forward_list>
+#include <memory>
 
 #include "packet/view.h"
 
diff --git a/gd/packet/parser/main.cc b/gd/packet/parser/main.cc
index 05475b6..757523d 100644
--- a/gd/packet/parser/main.cc
+++ b/gd/packet/parser/main.cc
@@ -230,9 +230,16 @@
   return true;
 }
 
+// Pick the output shard for the symbol_count-th of symbol_total symbols; *out_files must be non-empty
+std::ofstream& get_out_file(size_t symbol_count, size_t symbol_total, std::vector<std::ofstream>* out_files) {
+  auto symbols_per_shard = std::max<size_t>(symbol_total / out_files->size(), 1);  // avoid dividing by 0
+  auto file_index = std::min(symbol_count / symbols_per_shard, out_files->size() - 1);
+  return out_files->at(file_index);
+}
+
 bool generate_pybind11_sources_one_file(const Declarations& decls, const std::filesystem::path& input_file,
                                         const std::filesystem::path& include_dir, const std::filesystem::path& out_dir,
-                                        const std::string& root_namespace) {
+                                        const std::string& root_namespace, size_t num_shards) {
   auto gen_relative_path = input_file.lexically_relative(include_dir).parent_path();
 
   auto input_filename = input_file.filename().string().substr(0, input_file.filename().string().find(".pdl"));
@@ -241,85 +248,132 @@
   std::filesystem::create_directories(gen_path);
 
   auto gen_relative_header = gen_relative_path / (input_filename + ".h");
-  auto gen_file = gen_path / (input_filename + "_python3.cc");
-
-  std::ofstream out_file;
-  out_file.open(gen_file);
-  if (!out_file.is_open()) {
-    std::cerr << "can't open " << gen_file << std::endl;
-    return false;
-  }
-
-  out_file << "#include <pybind11/pybind11.h>\n";
-  out_file << "#include <pybind11/stl.h>\n";
-  out_file << "\n\n";
-  out_file << "#include " << gen_relative_header << "\n";
-  out_file << "\n\n";
 
   std::vector<std::string> namespace_list;
   parse_namespace(root_namespace, gen_relative_path, &namespace_list);
-  generate_namespace_open(namespace_list, out_file);
-  out_file << "\n\n";
 
-  for (const auto& c : decls.type_defs_queue_) {
-    if (c.second->GetDefinitionType() == TypeDef::Type::CUSTOM ||
-        c.second->GetDefinitionType() == TypeDef::Type::CHECKSUM) {
-      const auto* custom_def = dynamic_cast<const CustomFieldDef*>(c.second);
-      custom_def->GenUsing(out_file);
+  std::vector<std::ofstream> out_file_shards(num_shards);
+  for (size_t i = 0; i < out_file_shards.size(); i++) {
+    auto filename = gen_path / (input_filename + "_python3_shard_" + std::to_string(i) + ".cc");
+    auto& out_file = out_file_shards[i];
+    out_file.open(filename);
+    if (!out_file.is_open()) {
+      std::cerr << "can't open " << filename << std::endl;
+      return false;
+    }
+    out_file << "#include <pybind11/pybind11.h>\n";
+    out_file << "#include <pybind11/stl.h>\n";
+    out_file << "\n\n";
+    out_file << "#include " << gen_relative_header << "\n";
+    out_file << "\n\n";
+
+    generate_namespace_open(namespace_list, out_file);
+    out_file << "\n\n";
+
+    for (const auto& c : decls.type_defs_queue_) {
+      if (c.second->GetDefinitionType() == TypeDef::Type::CUSTOM ||
+          c.second->GetDefinitionType() == TypeDef::Type::CHECKSUM) {
+        const auto* custom_def = dynamic_cast<const CustomFieldDef*>(c.second);
+        custom_def->GenUsing(out_file);
+      }
+    }
+    out_file << "\n\n";
+
+    out_file << "using ::bluetooth::packet::BasePacketBuilder;";
+    out_file << "using ::bluetooth::packet::BitInserter;";
+    out_file << "using ::bluetooth::packet::CustomTypeChecker;";
+    out_file << "using ::bluetooth::packet::Iterator;";
+    out_file << "using ::bluetooth::packet::kLittleEndian;";
+    out_file << "using ::bluetooth::packet::PacketBuilder;";
+    out_file << "using ::bluetooth::packet::BaseStruct;";
+    out_file << "using ::bluetooth::packet::PacketStruct;";
+    out_file << "using ::bluetooth::packet::PacketView;";
+    out_file << "using ::bluetooth::packet::parser::ChecksumTypeChecker;";
+    out_file << "\n\n";
+
+    out_file << "namespace py = pybind11;\n\n";
+
+    out_file << "void define_" << input_filename << "_submodule_shard_" << std::to_string(i) << "(py::module& m) {\n\n";
+  }
+  size_t symbol_total = 0;
+  // Only count types that will be generated
+  for (const auto& e : decls.type_defs_queue_) {
+    if (e.second->GetDefinitionType() == TypeDef::Type::ENUM) {
+      symbol_total++;
+    } else if (e.second->GetDefinitionType() == TypeDef::Type::STRUCT) {
+      symbol_total++;
     }
   }
-  out_file << "\n\n";
-
-  out_file << "using ::bluetooth::packet::BasePacketBuilder;";
-  out_file << "using ::bluetooth::packet::BitInserter;";
-  out_file << "using ::bluetooth::packet::CustomTypeChecker;";
-  out_file << "using ::bluetooth::packet::Iterator;";
-  out_file << "using ::bluetooth::packet::kLittleEndian;";
-  out_file << "using ::bluetooth::packet::PacketBuilder;";
-  out_file << "using ::bluetooth::packet::BaseStruct;";
-  out_file << "using ::bluetooth::packet::PacketStruct;";
-  out_file << "using ::bluetooth::packet::PacketView;";
-  out_file << "using ::bluetooth::packet::parser::ChecksumTypeChecker;";
-  out_file << "\n\n";
-
-  out_file << "namespace py = pybind11;\n\n";
-
-  out_file << "void define_" << input_filename << "_submodule(py::module& parent) {\n\n";
-  out_file << "py::module m = parent.def_submodule(\"" << input_filename << "\", \"A submodule of " << input_filename
-           << "\");\n\n";
+  // View and builder are counted separately
+  symbol_total += decls.packet_defs_queue_.size() * 2;
+  size_t symbol_count = 0;
 
   for (const auto& e : decls.type_defs_queue_) {
     if (e.second->GetDefinitionType() == TypeDef::Type::ENUM) {
       const auto* enum_def = dynamic_cast<const EnumDef*>(e.second);
       EnumGen gen(*enum_def);
+      auto& out_file = get_out_file(symbol_count, symbol_total, &out_file_shards);
       gen.GenDefinitionPybind11(out_file);
       out_file << "\n\n";
+      symbol_count++;
     }
   }
 
   for (const auto& s : decls.type_defs_queue_) {
     if (s.second->GetDefinitionType() == TypeDef::Type::STRUCT) {
       const auto* struct_def = dynamic_cast<const StructDef*>(s.second);
+      auto& out_file = get_out_file(symbol_count, symbol_total, &out_file_shards);
       struct_def->GenDefinitionPybind11(out_file);
       out_file << "\n";
+      symbol_count++;
     }
   }
 
   for (const auto& packet_def : decls.packet_defs_queue_) {
+    auto& out_file = get_out_file(symbol_count, symbol_total, &out_file_shards);
     packet_def.second.GenParserDefinitionPybind11(out_file);
     out_file << "\n\n";
+    symbol_count++;
   }
 
   for (const auto& p : decls.packet_defs_queue_) {
+    auto& out_file = get_out_file(symbol_count, symbol_total, &out_file_shards);
     p.second.GenBuilderDefinitionPybind11(out_file);
     out_file << "\n\n";
+    symbol_count++;
   }
 
-  out_file << "}\n\n";
+  for (auto& out_file : out_file_shards) {
+    out_file << "}\n\n";
+    generate_namespace_close(namespace_list, out_file);
+  }
 
-  generate_namespace_close(namespace_list, out_file);
+  auto gen_file_main = gen_path / (input_filename + "_python3.cc");
+  std::ofstream out_file_main;
+  out_file_main.open(gen_file_main);
+  if (!out_file_main.is_open()) {
+    std::cerr << "can't open " << gen_file_main << std::endl;
+    return false;
+  }
+  out_file_main << "#include <pybind11/pybind11.h>\n";
+  generate_namespace_open(namespace_list, out_file_main);
 
-  out_file.close();
+  out_file_main << "namespace py = pybind11;\n\n";
+
+  for (size_t i = 0; i < out_file_shards.size(); i++) {
+    out_file_main << "void define_" << input_filename << "_submodule_shard_" << std::to_string(i)
+                  << "(py::module& m);\n";
+  }
+
+  out_file_main << "void define_" << input_filename << "_submodule(py::module& parent) {\n\n";
+  out_file_main << "py::module m = parent.def_submodule(\"" << input_filename << "\", \"A submodule of "
+                << input_filename << "\");\n\n";
+  for (size_t i = 0; i < out_file_shards.size(); i++) {
+    out_file_main << "define_" << input_filename << "_submodule_shard_" << std::to_string(i) << "(m);\n";
+  }
+  out_file_main << "}\n\n";
+
+  generate_namespace_close(namespace_list, out_file_main);
 
   return true;
 }
@@ -335,10 +389,13 @@
   std::filesystem::path out_dir;
   std::filesystem::path include_dir;
   std::string root_namespace = "bluetooth";
+  // Number of shards per output pybind11 cc file
+  size_t num_shards = 1;
   std::queue<std::filesystem::path> input_files;
   const std::string arg_out = "--out=";
   const std::string arg_include = "--include=";
   const std::string arg_namespace = "--root_namespace=";
+  const std::string arg_num_shards = "--num_shards=";
 
   for (int i = 1; i < argc; i++) {
     std::string arg = argv[i];
@@ -348,12 +405,15 @@
       include_dir = std::filesystem::current_path() / std::filesystem::path(arg.substr(arg_include.size()));
     } else if (arg.find(arg_namespace) == 0) {
       root_namespace = arg.substr(arg_namespace.size());
+    } else if (arg.find(arg_num_shards) == 0) {
+      num_shards = std::stoul(arg.substr(arg_num_shards.size()));
     } else {
       input_files.emplace(std::filesystem::current_path() / std::filesystem::path(arg));
     }
   }
-  if (out_dir == std::filesystem::path() || include_dir == std::filesystem::path()) {
-    std::cerr << "Usage: bt-packetgen --out=OUT --include=INCLUDE --root=NAMESPACE input_files..." << std::endl;
+  if (out_dir == std::filesystem::path() || include_dir == std::filesystem::path() || num_shards == 0) {
+    std::cerr << "Usage: bt-packetgen --out=OUT --include=INCLUDE --root_namespace=NAMESPACE --num_shards=NUM_SHARDS "
+              << "input_files..." << std::endl;
     return 1;
   }
 
@@ -367,7 +427,8 @@
       std::cerr << "Didn't generate cpp headers for " << input_files.front() << std::endl;
       return 3;
     }
-    if (!generate_pybind11_sources_one_file(declarations, input_files.front(), include_dir, out_dir, root_namespace)) {
+    if (!generate_pybind11_sources_one_file(declarations, input_files.front(), include_dir, out_dir, root_namespace,
+                                            num_shards)) {
       std::cerr << "Didn't generate pybind11 sources for " << input_files.front() << std::endl;
       return 4;
     }
diff --git a/gd/packet/view.h b/gd/packet/view.h
index 4f3b508..3b8b679 100644
--- a/gd/packet/view.h
+++ b/gd/packet/view.h
@@ -17,6 +17,7 @@
 #pragma once
 
 #include <cstdint>
+#include <memory>
 #include <vector>
 
 namespace bluetooth {