| #!/usr/bin/env python3 |
| import sys |
| |
| from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest, CodeGeneratorResponse |
| |
| |
def eprint(*args, **kwargs):
    """Print to standard error instead of standard output.

    Positional and keyword arguments are forwarded to ``print`` unchanged.
    ``file`` is already supplied here, so callers must not pass it themselves.
    """
    stream = sys.stderr
    print(*args, file=stream, **kwargs)
| |
| |
# Read the whole of stdin as bytes and decode it into the CodeGeneratorRequest
# that the rest of this script works from.
_request_bytes = sys.stdin.buffer.read()
request = CodeGeneratorRequest.FromString(_request_bytes)
| |
def has_type(proto_file, type_name):
    """Return True if ``proto_file`` declares a top-level message named ``type_name``.

    Only the entries of ``proto_file.message_type`` are inspected; nested
    messages and enums are not considered.
    """
    # Compare the names directly. ``any(filter(...))`` would instead test the
    # truthiness of the matching objects, not whether a match exists.
    return any(message.name == type_name for message in proto_file.message_type)
| |
def _declares_message(proto_file, parts):
    """Return True if ``proto_file`` declares the message path ``parts``.

    The first component is looked up among the file's top-level messages and
    each following component among the ``nested_type`` of the previous match.
    """
    scope = proto_file.message_type
    for part in parts:
        match = next((message for message in scope if message.name == part), None)
        if match is None:
            return False
        scope = match.nested_type
    return True


def import_type(imports, type, proto_files=None):
    """Resolve a fully qualified message name to an expression usable in generated code.

    Args:
        imports: set of import statements; the statement needed for the
            returned expression is added to it.
        type: fully qualified message name with a leading dot, for example
            ``.pkg.Message`` or ``.pkg.Outer.Inner``. (The name shadows the
            builtin but is kept so existing callers keep working.)
        proto_files: file descriptors to search; defaults to
            ``request.proto_file``.

    Returns:
        ``<alias>.<Message>`` where ``<alias>`` is the name under which the
        ``*_pb2`` module of the declaring file is imported.

    Raises:
        ValueError: if no file in ``proto_files`` declares the message.
    """
    if proto_files is None:
        proto_files = request.proto_file

    full_name = type[1:] if type.startswith('.') else type

    # The boundary between package and message name cannot be read off the
    # string alone (``a.B.C`` may be package ``a`` + nested message ``B.C``),
    # so try each file's package as the prefix and check the remainder.
    for proto_file in proto_files:
        package = proto_file.package
        if not package:
            message_path = full_name
        elif full_name.startswith(package + '.'):
            message_path = full_name[len(package) + 1:]
        else:
            continue
        if _declares_message(proto_file, message_path.split('.')):
            break
    else:
        raise ValueError(f'message type {type!r} is not declared in any known proto file')

    # Strip only the trailing extension; a plain replace would also rewrite
    # '.proto' occurring inside a directory name.
    proto_name = proto_file.name
    if proto_name.endswith('.proto'):
        proto_name = proto_name[:-len('.proto')]
    python_path = proto_name.replace('/', '.')

    as_name = python_path.replace('.', '_dot_') + '__pb2'
    module_path, _, module_stem = python_path.rpartition('.')
    module_name = module_stem + '_pb2'

    if module_path:
        imports.add(f'from {module_path} import {module_name} as {as_name}')
    else:
        # A proto file at the root of the tree has no containing package.
        imports.add(f'import {module_name} as {as_name}')
    return f'{as_name}.{message_path}'
| |
| |
def generate_method(imports, file, service, method):
    """Generate the source lines of one client stub method.

    Args:
        imports: set of import statements, extended with the imports needed
            for the method's request and response types.
        file: descriptor of the proto file declaring ``service``.
        service: descriptor of the service declaring ``method``.
        method: descriptor of the RPC method.

    Returns:
        A list of source lines, indented relative to the ``def`` line, so the
        caller can place them at any nesting level.

    The emitted method calls ``self.channel.<in>_<out>(...)`` where ``<in>``
    and ``<out>`` are ``unary`` or ``stream`` depending on the method's
    streaming flags. For a client-streaming method the stub takes a request
    iterator and forwards ``**kwargs`` to the call. Otherwise ``**kwargs``
    are used to build the request message and only ``wait_for_ready`` is
    forwarded to the call.
    """
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'

    input_type = import_type(imports, method.input_type)
    output_type = import_type(imports, method.output_type)

    # A file without a package declares its services at the root, so the
    # fully qualified service name must not start with a dot.
    if file.package:
        service_path = f'{file.package}.{service.name}'
    else:
        service_path = service.name

    # Only the signature and the final invocation differ between the
    # client-streaming and the unary-request flavour.
    if method.client_streaming:
        signature = 'self, iterator, **kwargs'
        call_args = 'iterator, **kwargs'
    else:
        signature = 'self, wait_for_ready=None, **kwargs'
        call_args = f'{input_type}(**kwargs), wait_for_ready=wait_for_ready'

    return [
        f'def {method.name}({signature}):',
        f'    return self.channel.{input_mode}_{output_mode}(',
        f"        '/{service_path}/{method.name}',",
        f'        request_serializer={input_type}.SerializeToString,',
        f'        response_deserializer={output_type}.FromString',
        f'    )({call_args})',
    ]
| |
| |
def generate_service(imports, file, service):
    """Generate the source lines of the client stub class for ``service``.

    Args:
        imports: set of import statements, passed on to ``generate_method``
            which extends it.
        file: descriptor of the proto file declaring ``service``.
        service: service descriptor; one stub method is emitted per entry of
            ``service.method``.

    Returns:
        A list of source lines: the class header, an ``__init__`` storing the
        channel, each method preceded by a blank line, and a trailing empty
        line.
    """
    indent = '    '
    lines = [
        f'class {service.name}:',
        f'{indent}def __init__(self, channel):',
        # The body must sit one level deeper than its ``def`` or the
        # generated module does not compile.
        f'{indent}{indent}self.channel = channel',
    ]
    for method in service.method:
        lines.append('')
        # Method lines arrive indented relative to their own ``def``; shift
        # them one level to place them inside the class body.
        lines.extend(
            indent + line
            for line in generate_method(imports, file, service, method)
        )
    lines.append('')
    return lines
| |
| |
# Build one `<name>_grpc.py` output file per proto file protoc asked for, then
# write the serialized CodeGeneratorResponse to stdout.
files = []

# Index the descriptors once instead of scanning the list for every file.
proto_files_by_name = {proto_file.name: proto_file for proto_file in request.proto_file}

for file_name in request.file_to_generate:
    file = proto_files_by_name[file_name]

    # Filled in as a side effect of generating the services below.
    imports = set()

    services = '\n'.join([
        line
        for service in file.service
        for line in generate_service(imports, file, service)
    ])

    # Replace only the trailing extension; a plain str.replace would also
    # rewrite '.proto' occurring inside a directory name.
    if file_name.endswith('.proto'):
        output_stem = file_name[:-len('.proto')]
    else:
        output_stem = file_name

    files.append(CodeGeneratorResponse.File(
        name=output_stem + '_grpc.py',
        # Sets of strings iterate in an order that varies between runs;
        # sorting keeps the generated file reproducible.
        content='\n'.join(sorted(imports)) + '\n\n' + services,
    ))

response = CodeGeneratorResponse(file=files)

sys.stdout.buffer.write(response.SerializeToString())