Support importing a dependency tree of protos.
diff --git a/tools/distrib/python/grpcio_tools/grpc_tools/BUILD b/tools/distrib/python/grpcio_tools/grpc_tools/BUILD
index 396d24b..c11a022 100644
--- a/tools/distrib/python/grpcio_tools/grpc_tools/BUILD
+++ b/tools/distrib/python/grpcio_tools/grpc_tools/BUILD
@@ -43,6 +43,11 @@
name = "protoc_test",
srcs = ["protoc_test.py"],
deps = ["//tools/distrib/python/grpcio_tools/grpc_tools:grpc_tools"],
- data = ["simple.proto"],
+ data = [
+ "simple.proto",
+ "simpler.proto",
+ "simplest.proto",
+ "complicated.proto",
+ ],
python_version = "PY3",
)
diff --git a/tools/distrib/python/grpcio_tools/grpc_tools/_protoc_compiler.pyx b/tools/distrib/python/grpcio_tools/grpc_tools/_protoc_compiler.pyx
index 9f107e4..c82129b 100644
--- a/tools/distrib/python/grpcio_tools/grpc_tools/_protoc_compiler.pyx
+++ b/tools/distrib/python/grpcio_tools/grpc_tools/_protoc_compiler.pyx
@@ -13,8 +13,8 @@
# limitations under the License.
from libc cimport stdlib
-from libcpp.map cimport map
from libcpp.vector cimport vector
+from libcpp.utility cimport pair
from libcpp.string cimport string
from cython.operator cimport dereference
@@ -35,8 +35,8 @@
string message
int protoc_main(int argc, char *argv[])
- int protoc_get_protos(char* protobuf_path, char* include_path, map[string, string]* files_out, vector[cProtocError]* errors, vector[cProtocWarning]* wrnings) except +
- int protoc_get_services(char* protobuf_path, char* include_path, map[string, string]* files_out, vector[cProtocError]* errors, vector[cProtocWarning]* wrnings) except +
+ int protoc_get_protos(char* protobuf_path, char* include_path, vector[pair[string, string]]* files_out, vector[cProtocError]* errors, vector[cProtocWarning]* wrnings) except +
+ int protoc_get_services(char* protobuf_path, char* include_path, vector[pair[string, string]]* files_out, vector[cProtocError]* errors, vector[cProtocWarning]* wrnings) except +
def run_main(list args not None):
cdef char **argv = <char **>stdlib.malloc(len(args)*sizeof(char *))
@@ -88,7 +88,7 @@
raise Exception("An unknown error occurred while compiling {}".format(protobuf_path))
def get_protos(bytes protobuf_path, bytes include_path):
- cdef map[string, string] files
+ cdef vector[pair[string, string]] files
cdef vector[cProtocError] errors
# NOTE: Abbreviated name used to shadowing of the module name.
cdef vector[cProtocWarning] wrnings
@@ -97,7 +97,7 @@
return files
def get_services(bytes protobuf_path, bytes include_path):
- cdef map[string, string] files
+ cdef vector[pair[string, string]] files
cdef vector[cProtocError] errors
# NOTE: Abbreviated name used to shadowing of the module name.
cdef vector[cProtocWarning] wrnings
diff --git a/tools/distrib/python/grpcio_tools/grpc_tools/complicated.proto b/tools/distrib/python/grpcio_tools/grpc_tools/complicated.proto
new file mode 100644
index 0000000..c6c500d
--- /dev/null
+++ b/tools/distrib/python/grpcio_tools/grpc_tools/complicated.proto
@@ -0,0 +1,12 @@
+syntax = "proto3";
+
+package grpc_tools.complicated;
+
+import "grpc_tools/simplest.proto";
+
+message ComplicatedMessage {
+ bool yes = 1;
+ bool no = 2;
+ bool why = 3;
+ grpc_tools.simplest.SimplestMessage simplest_message = 4;
+};
diff --git a/tools/distrib/python/grpcio_tools/grpc_tools/main.cc b/tools/distrib/python/grpcio_tools/grpc_tools/main.cc
index 6506341..82fbf7a 100644
--- a/tools/distrib/python/grpcio_tools/grpc_tools/main.cc
+++ b/tools/distrib/python/grpcio_tools/grpc_tools/main.cc
@@ -26,11 +26,15 @@
#include <google/protobuf/descriptor.h>
// TODO: Clang format.
+#include <algorithm>
#include <vector>
#include <map>
#include <string>
#include <tuple>
+// TODO: Remove.
+#include <iostream>
+
int protoc_main(int argc, char* argv[]) {
google::protobuf::compiler::CommandLineInterface cli;
cli.AllowPlugins("protoc-");
@@ -58,14 +62,13 @@
class GeneratorContextImpl : public ::google::protobuf::compiler::GeneratorContext {
public:
GeneratorContextImpl(const std::vector<const ::google::protobuf::FileDescriptor*>& parsed_files,
- std::map<std::string, std::string>* files_out) :
+ std::vector<std::pair<std::string, std::string>>* files_out) :
files_(files_out),
parsed_files_(parsed_files) {}
::google::protobuf::io::ZeroCopyOutputStream* Open(const std::string& filename) {
- // TODO(rbellevi): Learn not to dream impossible dreams. :(
- auto [iter, _] = files_->emplace(filename, "");
- return new ::google::protobuf::io::StringOutputStream(&(iter->second));
+ files_->emplace_back(filename, "");
+ return new ::google::protobuf::io::StringOutputStream(&(files_->back().second));
}
// NOTE: Equivalent to Open, since all files start out empty.
@@ -84,7 +87,7 @@
}
private:
- std::map<std::string, std::string>* files_;
+ std::vector<std::pair<std::string, std::string>>* files_;
const std::vector<const ::google::protobuf::FileDescriptor*>& parsed_files_;
};
@@ -113,11 +116,25 @@
} // end namespace detail
+static void calculate_transitive_closure(const ::google::protobuf::FileDescriptor* descriptor,
+ std::vector<const ::google::protobuf::FileDescriptor*>* transitive_closure)
+{
+ for (int i = 0; i < descriptor->dependency_count(); ++i) {
+ const ::google::protobuf::FileDescriptor* dependency = descriptor->dependency(i);
+ // NOTE: Probably want an O(1) lookup method for very large transitive
+ // closures.
+ if (std::find(transitive_closure->begin(), transitive_closure->end(), dependency) == transitive_closure->end()) {
+ calculate_transitive_closure(dependency, transitive_closure);
+ }
+ }
+ transitive_closure->push_back(descriptor);
+}
+
// TODO: Handle multiple include paths.
static int generate_code(::google::protobuf::compiler::CodeGenerator* code_generator,
char* protobuf_path,
char* include_path,
- std::map<std::string, std::string>* files_out,
+ std::vector<std::pair<std::string, std::string>>* files_out,
std::vector<ProtocError>* errors,
std::vector<ProtocWarning>* warnings)
{
@@ -129,16 +146,21 @@
if (parsed_file == nullptr) {
return 1;
}
- detail::GeneratorContextImpl generator_context({parsed_file}, files_out);
+ // TODO: Figure out if the dependency list is flat or recursive.
+ // TODO: Ensure there's a topological ordering here.
+ std::vector<const ::google::protobuf::FileDescriptor*> transitive_closure;
+ calculate_transitive_closure(parsed_file, &transitive_closure);
+ detail::GeneratorContextImpl generator_context(transitive_closure, files_out);
std::string error;
- ::google::protobuf::compiler::python::Generator python_generator;
- python_generator.Generate(parsed_file, "", &generator_context, &error);
+ for (const auto descriptor : transitive_closure) {
+ code_generator->Generate(descriptor, "", &generator_context, &error);
+ }
return 0;
}
int protoc_get_protos(char* protobuf_path,
char* include_path,
- std::map<std::string, std::string>* files_out,
+ std::vector<std::pair<std::string, std::string>>* files_out,
std::vector<ProtocError>* errors,
std::vector<ProtocWarning>* warnings)
{
@@ -148,7 +170,7 @@
int protoc_get_services(char* protobuf_path,
char* include_path,
- std::map<std::string, std::string>* files_out,
+ std::vector<std::pair<std::string, std::string>>* files_out,
std::vector<ProtocError>* errors,
std::vector<ProtocWarning>* warnings)
{
diff --git a/tools/distrib/python/grpcio_tools/grpc_tools/main.h b/tools/distrib/python/grpcio_tools/grpc_tools/main.h
index b5211ab..33473de 100644
--- a/tools/distrib/python/grpcio_tools/grpc_tools/main.h
+++ b/tools/distrib/python/grpcio_tools/grpc_tools/main.h
@@ -13,6 +13,11 @@
// limitations under the License.
+#include <vector>
+#include <string>
+#include <utility>
+
+
// We declare `protoc_main` here since we want access to it from Cython as an
// extern but *without* triggering a dllimport declspec when on Windows.
int protoc_main(int argc, char *argv[]);
@@ -33,14 +38,15 @@
typedef ProtocError ProtocWarning;
+// TODO: Create Alias for files_out type?
int protoc_get_protos(char* protobuf_path,
char* include_path,
- std::map<std::string, std::string>* files_out,
+ std::vector<std::pair<std::string, std::string>>* files_out,
std::vector<ProtocError>* errors,
std::vector<ProtocWarning>* warnings);
int protoc_get_services(char* protobuf_path,
char* include_path,
- std::map<std::string, std::string>* files_out,
+ std::vector<std::pair<std::string, std::string>>* files_out,
std::vector<ProtocError>* errors,
std::vector<ProtocWarning>* warnings);
diff --git a/tools/distrib/python/grpcio_tools/grpc_tools/protoc.py b/tools/distrib/python/grpcio_tools/grpc_tools/protoc.py
index ae7e962..8de10c3 100644
--- a/tools/distrib/python/grpcio_tools/grpc_tools/protoc.py
+++ b/tools/distrib/python/grpcio_tools/grpc_tools/protoc.py
@@ -36,27 +36,27 @@
def _import_modules_from_files(files):
modules = []
- # TODO: Ensure pointer equality between two invocations of this function.
- for filename, code in six.iteritems(files):
- print("Filename {}".format(filename))
+ for filename, code in files:
base_name = os.path.basename(filename.decode('ascii'))
proto_name, _ = os.path.splitext(base_name)
anchor_package = ".".join(os.path.normpath(os.path.dirname(filename.decode('ascii'))).split(os.sep))
module_name = "{}.{}".format(anchor_package, proto_name)
- module = imp.new_module(module_name)
- six.exec_(code, module.__dict__)
- modules.append(module)
- print("Inserting module {}".format(module_name))
- sys.modules[module_name] = module
+ if module_name not in sys.modules:
+ module = imp.new_module(module_name)
+ six.exec_(code, module.__dict__)
+ sys.modules[module_name] = module
+ modules.append(module)
+ else:
+ modules.append(sys.modules[module_name])
return tuple(modules)
def get_protos(protobuf_path, include_path):
files = _protoc_compiler.get_protos(protobuf_path.encode('ascii'), include_path.encode('ascii'))
- return _import_modules_from_files(files)
+ return _import_modules_from_files(files)[-1]
def get_services(protobuf_path, include_path):
files = _protoc_compiler.get_services(protobuf_path.encode('ascii'), include_path.encode('ascii'))
- return _import_modules_from_files(files)
+ return _import_modules_from_files(files)[-1]
if __name__ == '__main__':
diff --git a/tools/distrib/python/grpcio_tools/grpc_tools/protoc_test.py b/tools/distrib/python/grpcio_tools/grpc_tools/protoc_test.py
index 4214479..f5b0e9f 100644
--- a/tools/distrib/python/grpcio_tools/grpc_tools/protoc_test.py
+++ b/tools/distrib/python/grpcio_tools/grpc_tools/protoc_test.py
@@ -20,13 +20,47 @@
# def test_stdout_pollution(self):
# pass
- def test_protoc_in_memory(self):
+ # def test_protoc_in_memory(self):
+ # from grpc_tools import protoc
+ # proto_path = "tools/distrib/python/grpcio_tools/"
+ # protos = protoc.get_protos("grpc_tools/simple.proto", proto_path)
+ # print(protos.SimpleMessageRequest)
+ # services = protoc.get_services("grpc_tools/simple.proto", proto_path)
+ # print(services.SimpleMessageServiceServicer)
+ # complicated_protos = protoc.get_protos("grpc_tools/complicated.proto", proto_path)
+ # print(complicated_protos.ComplicatedMessage)
+ # print(dir(complicated_protos.grpc__tools_dot_simplest__pb2))
+ # print(dir(protos.grpc__tools_dot_simpler__pb2.grpc__tools_dot_simplest__pb2))
+ # print("simplest is simplest: {}".format(complicated_protos.grpc__tools_dot_simplest__pb2.SimplestMessage is protos.grpc__tools_dot_simpler__pb2.grpc__tools_dot_simplest__pb2.SimplesMessage))
+
+ # TODO: Test error messages.
+
+ # TODO: These test cases have to run in different processes to be truly
+ # independent of one another.
+
+ def test_import_protos(self):
from grpc_tools import protoc
proto_path = "tools/distrib/python/grpcio_tools/"
- protos, = protoc.get_protos("grpc_tools/simple.proto", proto_path)
- print(protos.SimpleMessageRequest)
- services, = protoc.get_services("grpc_tools/simple.proto", proto_path)
- print("Services: {}".format(dir(services)))
+ protos = protoc.get_protos("grpc_tools/simple.proto", proto_path)
+ self.assertIsNotNone(protos.SimpleMessage)
+
+ def test_import_services(self):
+ from grpc_tools import protoc
+ proto_path = "tools/distrib/python/grpcio_tools/"
+ # TODO: Should we make this step optional if you only want to import
+ # services?
+ protos = protoc.get_protos("grpc_tools/simple.proto", proto_path)
+ services = protoc.get_services("grpc_tools/simple.proto", proto_path)
+ self.assertIsNotNone(services.SimpleMessageServiceStub)
+
+ def test_proto_module_imported_once(self):
+ from grpc_tools import protoc
+ proto_path = "tools/distrib/python/grpcio_tools/"
+ protos = protoc.get_protos("grpc_tools/simple.proto", proto_path)
+ services = protoc.get_services("grpc_tools/simple.proto", proto_path)
+ complicated_protos = protoc.get_protos("grpc_tools/complicated.proto", proto_path)
+ self.assertIs(complicated_protos.grpc__tools_dot_simplest__pb2.SimplestMessage,
+ protos.grpc__tools_dot_simpler__pb2.grpc__tools_dot_simplest__pb2.SimplestMessage)
if __name__ == '__main__':
diff --git a/tools/distrib/python/grpcio_tools/grpc_tools/simple.proto b/tools/distrib/python/grpcio_tools/grpc_tools/simple.proto
index c4c2141..2f7d56f 100644
--- a/tools/distrib/python/grpcio_tools/grpc_tools/simple.proto
+++ b/tools/distrib/python/grpcio_tools/grpc_tools/simple.proto
@@ -1,11 +1,17 @@
+// TODO: Move these proto files to a dedicated directory.
syntax = "proto3";
+package grpc_tools.simple;
+
+import "grpc_tools/simpler.proto";
+
message SimpleMessage {
string msg = 1;
oneof personal_or_business {
bool personal = 2;
bool business = 3;
};
+ grpc_tools.simpler.SimplerMessage simpler_message = 4;
};
message SimpleMessageRequest {
diff --git a/tools/distrib/python/grpcio_tools/grpc_tools/simpler.proto b/tools/distrib/python/grpcio_tools/grpc_tools/simpler.proto
new file mode 100644
index 0000000..9e0943a
--- /dev/null
+++ b/tools/distrib/python/grpcio_tools/grpc_tools/simpler.proto
@@ -0,0 +1,11 @@
+syntax = "proto3";
+
+package grpc_tools.simpler;
+
+import "grpc_tools/simplest.proto";
+
+message SimplerMessage {
+ int64 do_i_even_exist = 1;
+ grpc_tools.simplest.SimplestMessage simplest_message = 2;
+};
+
diff --git a/tools/distrib/python/grpcio_tools/grpc_tools/simplest.proto b/tools/distrib/python/grpcio_tools/grpc_tools/simplest.proto
new file mode 100644
index 0000000..6056c4a
--- /dev/null
+++ b/tools/distrib/python/grpcio_tools/grpc_tools/simplest.proto
@@ -0,0 +1,8 @@
+syntax = "proto3";
+
+package grpc_tools.simplest;
+
+message SimplestMessage {
+ int64 i_definitely_dont_exist = 1;
+};
+