| // Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0. |
| |
| use std::io::{Error, ErrorKind, Read}; |
| use std::path::Path; |
| use std::{fs, io, process::Command}; |
| |
| use derive_new::new; |
| use prost::Message; |
| use prost_build::{protoc, protoc_include, Config, Method, Service, ServiceGenerator}; |
| use prost_types::FileDescriptorSet; |
| |
| use crate::util::{fq_grpc, to_snake_case, MethodType}; |
| |
| /// Returns the names of all packages compiled. |
| pub fn compile_protos<P>(protos: &[P], includes: &[P], out_dir: &str) -> io::Result<Vec<String>> |
| where |
| P: AsRef<Path>, |
| { |
| let mut prost_config = Config::new(); |
| prost_config.service_generator(Box::new(Generator)); |
| prost_config.out_dir(out_dir); |
| |
| // Create a file descriptor set for the protocol files. |
| let tmp = tempfile::Builder::new().prefix("prost-build").tempdir()?; |
| let descriptor_set = tmp.path().join("prost-descriptor-set"); |
| |
| let mut cmd = Command::new(protoc()); |
| cmd.arg("--include_imports") |
| .arg("--include_source_info") |
| .arg("-o") |
| .arg(&descriptor_set); |
| |
| for include in includes { |
| cmd.arg("-I").arg(include.as_ref()); |
| } |
| |
| // Set the protoc include after the user includes in case the user wants to |
| // override one of the built-in .protos. |
| cmd.arg("-I").arg(protoc_include()); |
| |
| for proto in protos { |
| cmd.arg(proto.as_ref()); |
| } |
| |
| let output = cmd.output()?; |
| if !output.status.success() { |
| return Err(Error::new( |
| ErrorKind::Other, |
| format!("protoc failed: {}", String::from_utf8_lossy(&output.stderr)), |
| )); |
| } |
| |
| let mut buf = Vec::new(); |
| fs::File::open(descriptor_set)?.read_to_end(&mut buf)?; |
| let descriptor_set = FileDescriptorSet::decode(buf.as_slice())?; |
| |
| // Get the package names from the descriptor set. |
| let mut packages: Vec<_> = descriptor_set |
| .file |
| .iter() |
| .filter_map(|f| f.package.clone()) |
| .collect(); |
| packages.sort(); |
| packages.dedup(); |
| |
| // FIXME(https://github.com/danburkert/prost/pull/155) |
| // Unfortunately we have to forget the above work and use `compile_protos` to |
| // actually generate the Rust code. |
| prost_config.compile_protos(protos, includes)?; |
| |
| Ok(packages) |
| } |
| |
| struct Generator; |
| |
| impl ServiceGenerator for Generator { |
| fn generate(&mut self, service: Service, buf: &mut String) { |
| generate_methods(&service, buf); |
| generate_client(&service, buf); |
| generate_server(&service, buf); |
| } |
| } |
| |
| fn generate_methods(service: &Service, buf: &mut String) { |
| let service_path = if service.package.is_empty() { |
| format!("/{}", service.proto_name) |
| } else { |
| format!("/{}.{}", service.package, service.proto_name) |
| }; |
| |
| for method in &service.methods { |
| generate_method(&service.name, &service_path, method, buf); |
| } |
| } |
| |
| fn const_method_name(service_name: &str, method: &Method) -> String { |
| format!( |
| "METHOD_{}_{}", |
| to_snake_case(service_name).to_uppercase(), |
| method.name.to_uppercase() |
| ) |
| } |
| |
| fn generate_method(service_name: &str, service_path: &str, method: &Method, buf: &mut String) { |
| let name = const_method_name(service_name, method); |
| let ty = format!( |
| "{}<{}, {}>", |
| fq_grpc("Method"), |
| method.input_type, |
| method.output_type |
| ); |
| |
| buf.push_str("const "); |
| buf.push_str(&name); |
| buf.push_str(": "); |
| buf.push_str(&ty); |
| buf.push_str(" = "); |
| generate_method_body(service_path, method, buf); |
| } |
| |
| fn generate_method_body(service_path: &str, method: &Method, buf: &mut String) { |
| let ty = fq_grpc(&MethodType::from_method(method).to_string()); |
| let pr_mar = format!( |
| "{} {{ ser: {}, de: {} }}", |
| fq_grpc("Marshaller"), |
| fq_grpc("pr_ser"), |
| fq_grpc("pr_de") |
| ); |
| |
| buf.push_str(&fq_grpc("Method")); |
| buf.push('{'); |
| generate_field_init("ty", &ty, buf); |
| generate_field_init( |
| "name", |
| &format!("\"{}/{}\"", service_path, method.proto_name), |
| buf, |
| ); |
| generate_field_init("req_mar", &pr_mar, buf); |
| generate_field_init("resp_mar", &pr_mar, buf); |
| buf.push_str("};\n"); |
| } |
| |
| // TODO share this code with protobuf codegen |
| impl MethodType { |
| fn from_method(method: &Method) -> MethodType { |
| match (method.client_streaming, method.server_streaming) { |
| (false, false) => MethodType::Unary, |
| (true, false) => MethodType::ClientStreaming, |
| (false, true) => MethodType::ServerStreaming, |
| (true, true) => MethodType::Duplex, |
| } |
| } |
| } |
| |
/// Appends a single struct-literal field initializer, `<name>: <value>, `, to `buf`.
fn generate_field_init(name: &str, value: &str, buf: &mut String) {
    for piece in &[name, ": ", value, ", "] {
        buf.push_str(piece);
    }
}
| |
| fn generate_client(service: &Service, buf: &mut String) { |
| let client_name = format!("{}Client", service.name); |
| buf.push_str("#[derive(Clone)]\n"); |
| buf.push_str("pub struct "); |
| buf.push_str(&client_name); |
| buf.push_str(" { client: ::grpcio::Client }\n"); |
| |
| buf.push_str("impl "); |
| buf.push_str(&client_name); |
| buf.push_str(" {\n"); |
| generate_ctor(&client_name, buf); |
| generate_client_methods(service, buf); |
| generate_spawn(buf); |
| buf.push_str("}\n") |
| } |
| |
/// Emits the client's `new` constructor, which wraps a `::grpcio::Channel` in a
/// `::grpcio::Client`.
fn generate_ctor(client_name: &str, buf: &mut String) {
    *buf += "pub fn new(channel: ::grpcio::Channel) -> Self { ";
    *buf += client_name;
    *buf += " { client: ::grpcio::Client::new(channel) }";
    *buf += "}\n";
}
| |
| fn generate_client_methods(service: &Service, buf: &mut String) { |
| for method in &service.methods { |
| generate_client_method(&service.name, method, buf); |
| } |
| } |
| |
| fn generate_client_method(service_name: &str, method: &Method, buf: &mut String) { |
| let name = &format!( |
| "METHOD_{}_{}", |
| to_snake_case(service_name).to_uppercase(), |
| method.name.to_uppercase() |
| ); |
| match MethodType::from_method(method) { |
| MethodType::Unary => { |
| ClientMethod::new( |
| &method.name, |
| true, |
| Some(&method.input_type), |
| false, |
| vec![&method.output_type], |
| "unary_call", |
| name, |
| ) |
| .generate(buf); |
| ClientMethod::new( |
| &method.name, |
| false, |
| Some(&method.input_type), |
| false, |
| vec![&method.output_type], |
| "unary_call", |
| name, |
| ) |
| .generate(buf); |
| ClientMethod::new( |
| &method.name, |
| true, |
| Some(&method.input_type), |
| true, |
| vec![&format!( |
| "{}<{}>", |
| fq_grpc("ClientUnaryReceiver"), |
| method.output_type |
| )], |
| "unary_call", |
| name, |
| ) |
| .generate(buf); |
| ClientMethod::new( |
| &method.name, |
| false, |
| Some(&method.input_type), |
| true, |
| vec![&format!( |
| "{}<{}>", |
| fq_grpc("ClientUnaryReceiver"), |
| method.output_type |
| )], |
| "unary_call", |
| name, |
| ) |
| .generate(buf); |
| } |
| MethodType::ClientStreaming => { |
| ClientMethod::new( |
| &method.name, |
| true, |
| None, |
| false, |
| vec![ |
| &format!("{}<{}>", fq_grpc("ClientCStreamSender"), method.input_type), |
| &format!( |
| "{}<{}>", |
| fq_grpc("ClientCStreamReceiver"), |
| method.output_type |
| ), |
| ], |
| "client_streaming", |
| name, |
| ) |
| .generate(buf); |
| ClientMethod::new( |
| &method.name, |
| false, |
| None, |
| false, |
| vec![ |
| &format!("{}<{}>", fq_grpc("ClientCStreamSender"), method.input_type), |
| &format!( |
| "{}<{}>", |
| fq_grpc("ClientCStreamReceiver"), |
| method.output_type |
| ), |
| ], |
| "client_streaming", |
| name, |
| ) |
| .generate(buf); |
| } |
| MethodType::ServerStreaming => { |
| ClientMethod::new( |
| &method.name, |
| true, |
| Some(&method.input_type), |
| false, |
| vec![&format!( |
| "{}<{}>", |
| fq_grpc("ClientSStreamReceiver"), |
| method.output_type |
| )], |
| "server_streaming", |
| name, |
| ) |
| .generate(buf); |
| ClientMethod::new( |
| &method.name, |
| false, |
| Some(&method.input_type), |
| false, |
| vec![&format!( |
| "{}<{}>", |
| fq_grpc("ClientSStreamReceiver"), |
| method.output_type |
| )], |
| "server_streaming", |
| name, |
| ) |
| .generate(buf); |
| } |
| MethodType::Duplex => { |
| ClientMethod::new( |
| &method.name, |
| true, |
| None, |
| false, |
| vec![ |
| &format!("{}<{}>", fq_grpc("ClientDuplexSender"), method.input_type), |
| &format!( |
| "{}<{}>", |
| fq_grpc("ClientDuplexReceiver"), |
| method.output_type |
| ), |
| ], |
| "duplex_streaming", |
| name, |
| ) |
| .generate(buf); |
| ClientMethod::new( |
| &method.name, |
| false, |
| None, |
| false, |
| vec![ |
| &format!("{}<{}>", fq_grpc("ClientDuplexSender"), method.input_type), |
| &format!( |
| "{}<{}>", |
| fq_grpc("ClientDuplexReceiver"), |
| method.output_type |
| ), |
| ], |
| "duplex_streaming", |
| name, |
| ) |
| .generate(buf); |
| } |
| } |
| } |
| |
| #[derive(new)] |
| struct ClientMethod<'a> { |
| method_name: &'a str, |
| opt: bool, |
| request: Option<&'a str>, |
| r#async: bool, |
| result_types: Vec<&'a str>, |
| inner_method_name: &'a str, |
| data_name: &'a str, |
| } |
| |
| impl<'a> ClientMethod<'a> { |
| fn generate(&self, buf: &mut String) { |
| buf.push_str("pub fn "); |
| |
| buf.push_str(self.method_name); |
| if self.r#async { |
| buf.push_str("_async"); |
| } |
| if self.opt { |
| buf.push_str("_opt"); |
| } |
| |
| buf.push_str("(&self"); |
| if let Some(req) = self.request { |
| buf.push_str(", req: &"); |
| buf.push_str(req); |
| } |
| if self.opt { |
| buf.push_str(", opt: "); |
| buf.push_str(&fq_grpc("CallOption")); |
| } |
| buf.push_str(") -> "); |
| |
| buf.push_str(&fq_grpc("Result")); |
| buf.push('<'); |
| if self.result_types.len() != 1 { |
| buf.push('('); |
| } |
| for rt in &self.result_types { |
| buf.push_str(rt); |
| buf.push(','); |
| } |
| if self.result_types.len() != 1 { |
| buf.push(')'); |
| } |
| buf.push_str("> { "); |
| if self.opt { |
| self.generate_inner_body(buf); |
| } else { |
| self.generate_opt_body(buf); |
| } |
| buf.push_str(" }\n"); |
| } |
| |
| // Method delegates to the `_opt` version of the method. |
| fn generate_opt_body(&self, buf: &mut String) { |
| buf.push_str("self."); |
| buf.push_str(self.method_name); |
| if self.r#async { |
| buf.push_str("_async"); |
| } |
| buf.push_str("_opt("); |
| if self.request.is_some() { |
| buf.push_str("req, "); |
| } |
| buf.push_str(&fq_grpc("CallOption::default()")); |
| buf.push(')'); |
| } |
| |
| // Method delegates to the inner client. |
| fn generate_inner_body(&self, buf: &mut String) { |
| buf.push_str("self.client."); |
| buf.push_str(self.inner_method_name); |
| if self.r#async { |
| buf.push_str("_async"); |
| } |
| buf.push_str("(&"); |
| buf.push_str(self.data_name); |
| if self.request.is_some() { |
| buf.push_str(", req"); |
| } |
| buf.push_str(", opt)"); |
| } |
| } |
| |
/// Emits the client's `spawn` method, which forwards a future to the inner
/// client's `spawn`.
fn generate_spawn(buf: &mut String) {
    const SPAWN_FN: &str = concat!(
        "pub fn spawn<F>(&self, f: F) ",
        "where F: ::futures::Future<Output = ()> + Send + 'static {",
        "self.client.spawn(f)",
        "}\n",
    );
    buf.push_str(SPAWN_FN);
}
| |
| fn generate_server(service: &Service, buf: &mut String) { |
| buf.push_str("pub trait "); |
| buf.push_str(&service.name); |
| buf.push_str(" {\n"); |
| generate_server_methods(service, buf); |
| buf.push_str("}\n"); |
| |
| buf.push_str("pub fn create_"); |
| buf.push_str(&to_snake_case(&service.name)); |
| buf.push_str("<S: "); |
| buf.push_str(&service.name); |
| buf.push_str(" + Send + Clone + 'static>(s: S) -> "); |
| buf.push_str(&fq_grpc("Service")); |
| buf.push_str(" {\n"); |
| buf.push_str("let mut builder = ::grpcio::ServiceBuilder::new();\n"); |
| |
| for method in &service.methods[0..service.methods.len() - 1] { |
| buf.push_str("let mut instance = s.clone();\n"); |
| generate_method_bind(&service.name, method, buf); |
| } |
| |
| buf.push_str("let mut instance = s;\n"); |
| generate_method_bind( |
| &service.name, |
| &service.methods[service.methods.len() - 1], |
| buf, |
| ); |
| |
| buf.push_str("builder.build()\n"); |
| buf.push_str("}\n"); |
| } |
| |
| fn generate_server_methods(service: &Service, buf: &mut String) { |
| for method in &service.methods { |
| let method_type = MethodType::from_method(method); |
| let request_arg = match method_type { |
| MethodType::Unary | MethodType::ServerStreaming => { |
| format!("req: {}", method.input_type) |
| } |
| MethodType::ClientStreaming | MethodType::Duplex => format!( |
| "stream: {}<{}>", |
| fq_grpc("RequestStream"), |
| method.input_type |
| ), |
| }; |
| let response_type = match method_type { |
| MethodType::Unary => "UnarySink", |
| MethodType::ClientStreaming => "ClientStreamingSink", |
| MethodType::ServerStreaming => "ServerStreamingSink", |
| MethodType::Duplex => "DuplexSink", |
| }; |
| generate_server_method(method, &request_arg, response_type, buf); |
| } |
| } |
| |
| fn generate_server_method( |
| method: &Method, |
| request_arg: &str, |
| response_type: &str, |
| buf: &mut String, |
| ) { |
| buf.push_str("fn "); |
| buf.push_str(&method.name); |
| buf.push_str("(&mut self, ctx: "); |
| buf.push_str(&fq_grpc("RpcContext")); |
| buf.push_str(", "); |
| buf.push_str(request_arg); |
| buf.push_str(", sink: "); |
| buf.push_str(&fq_grpc(response_type)); |
| buf.push('<'); |
| buf.push_str(&method.output_type); |
| buf.push('>'); |
| buf.push_str(");\n"); |
| } |
| |
| fn generate_method_bind(service_name: &str, method: &Method, buf: &mut String) { |
| let add_name = match MethodType::from_method(method) { |
| MethodType::Unary => "add_unary_handler", |
| MethodType::ClientStreaming => "add_client_streaming_handler", |
| MethodType::ServerStreaming => "add_server_streaming_handler", |
| MethodType::Duplex => "add_duplex_streaming_handler", |
| }; |
| |
| buf.push_str("builder = builder."); |
| buf.push_str(add_name); |
| buf.push_str("(&"); |
| buf.push_str(&const_method_name(service_name, method)); |
| buf.push_str(", move |ctx, req, resp| instance."); |
| buf.push_str(&method.name); |
| buf.push_str("(ctx, req, resp));\n"); |
| } |