blob: 04dfa897b11b509af606ff6e532fc01e73d824b3 [file] [log] [blame]
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/compiler/plugin.h>
#include <google/protobuf/compiler/code_generator.h>
#include <google/protobuf/io/printer.h>
#include <google/protobuf/io/zero_copy_stream.h>
#include <google/protobuf/stubs/strutil.h>
#include "nugget/protobuf/options.pb.h"
using ::google::protobuf::FileDescriptor;
using ::google::protobuf::JoinStrings;
using ::google::protobuf::MethodDescriptor;
using ::google::protobuf::ServiceDescriptor;
using ::google::protobuf::Split;
using ::google::protobuf::SplitStringUsing;
using ::google::protobuf::StripSuffixString;
using ::google::protobuf::compiler::CodeGenerator;
using ::google::protobuf::compiler::OutputDirectory;
using ::google::protobuf::io::Printer;
using ::google::protobuf::io::ZeroCopyOutputStream;
using ::nugget::protobuf::app_id;
using ::nugget::protobuf::request_buffer_size;
using ::nugget::protobuf::response_buffer_size;
namespace {
// Checks that |service| carries every nugget.protobuf service option the
// generator relies on (app_id, request_buffer_size, response_buffer_size).
// Returns an empty string when all are present; otherwise returns a message
// naming the first missing option and the service.
std::string validateServiceOptions(const ServiceDescriptor& service) {
  const auto& options = service.options();
  std::string problem;
  if (!options.HasExtension(app_id)) {
    problem = "nugget.protobuf.app_id is not defined for service ";
  } else if (!options.HasExtension(request_buffer_size)) {
    problem = "nugget.protobuf.request_buffer_size is not defined for service ";
  } else if (!options.HasExtension(response_buffer_size)) {
    problem = "nugget.protobuf.response_buffer_size is not defined for service ";
  } else {
    return "";
  }
  return problem + service.name();
}
// Returns the enclosing scopes of |descriptor|. It splits full_name() on '.'
// and drops the final component, which is the descriptor's own name. For a
// top-level service or message this is the proto package. For a nested
// message it also contains the names of the enclosing messages.
template <typename Descriptor>
std::vector<std::string> Packages(const Descriptor& descriptor) {
  std::vector<std::string> namespaces;
  SplitStringUsing(descriptor.full_name(), ".", &namespaces);
  // pop_back() on an empty vector is undefined behavior, so guard against a
  // full name that splits into no components at all.
  if (!namespaces.empty()) {
    namespaces.pop_back();  // drop the descriptor's own name; keep the scopes
  }
  return namespaces;
}
// Returns the globally qualified C++ name of |descriptor|. The enclosing
// scopes from Packages() are joined with "::", and the result is prefixed
// with "::" and ends with the descriptor's own name.
template <typename Descriptor>
std::string FullyQualifiedIdentifier(const Descriptor& descriptor) {
  std::string qualified = "::";
  const auto scopes = Packages(descriptor);
  if (!scopes.empty()) {
    std::string joined_scopes;
    JoinStrings(scopes, "::", &joined_scopes);
    qualified += joined_scopes;
    qualified += "::";
  }
  qualified += descriptor.name();
  return qualified;
}
template <typename Descriptor>
std::string FullyQualifiedHeader(const Descriptor& descriptor) {
const auto packages = Packages(descriptor);
const auto file = Split(descriptor.file()->name(), "/").back();
const auto header = StripSuffixString(file, ".proto") + ".pb.h";
if (packages.empty()) {
return header;
} else {
std::string package_path;
JoinStrings(packages, "/", &package_path);
return package_path + "/" + header;
}
}
// Prints one "namespace X {" line per enclosing scope of |descriptor|,
// outermost first. CloseNamespaces() emits the matching closing lines.
template <typename Descriptor>
void OpenNamespaces(Printer& printer, const Descriptor& descriptor) {
  std::map<std::string, std::string> vars;
  for (const std::string& scope : Packages(descriptor)) {
    vars["namespace"] = scope;
    printer.Print(vars, R"(
namespace $namespace$ {)");
  }
}
// Prints the closing brace for every namespace opened by OpenNamespaces(),
// innermost first, each annotated with the namespace's name.
template <typename Descriptor>
void CloseNamespaces(Printer& printer, const Descriptor& descriptor) {
  auto remaining = Packages(descriptor);
  std::map<std::string, std::string> vars;
  while (!remaining.empty()) {
    vars["namespace"] = remaining.back();
    remaining.pop_back();
    printer.Print(vars, R"(
} // namespace $namespace$)");
  }
}
void ForEachMethod(const ServiceDescriptor& service,
std::function<void(std::map<std::string, std::string>)> handler) {
for (int i = 0; i < service.method_count(); ++i) {
const MethodDescriptor& method = *service.method(i);
std::map<std::string, std::string> vars;
vars["method_id"] = std::to_string(i);
vars["method_name"] = method.name();
vars["method_input_type"] = FullyQualifiedIdentifier(*method.input_type());
vars["method_output_type"] = FullyQualifiedIdentifier(*method.output_type());
handler(vars);
}
}
// Writes Mock<Service>.client.h. The file defines a gmock struct that derives
// from the I<Service> interface declared by the generated client header, with
// one MOCK_METHOD2 per RPC.
void GenerateMockClient(Printer& printer, const ServiceDescriptor& service) {
  const std::string& service_name = service.name();
  const std::map<std::string, std::string> vars = {
      {"include_guard", "PROTOC_GENERATED_MOCK_" + service_name + "_CLIENT_H"},
      {"service_header", service_name + ".client.h"},
      {"mock_class", "Mock" + service_name},
      {"class", service_name},
  };
  // Emits the MOCK_METHOD2 line for a single RPC.
  const auto print_mock_method = [&printer](std::map<std::string, std::string> method_vars) {
    printer.Print(method_vars, R"(
MOCK_METHOD2($method_name$, uint32_t(const $method_input_type$&, $method_output_type$*));)");
  };
  // Preamble: include guard, gmock, and the real client header.
  printer.Print(vars, R"(
#ifndef $include_guard$
#define $include_guard$
#include <gmock/gmock.h>
#include <$service_header$>)");
  OpenNamespaces(printer, service);
  printer.Print(vars, R"(
struct $mock_class$ : public I$class$ {)");
  ForEachMethod(service, print_mock_method);
  printer.Print(vars, R"(
};)");
  CloseNamespaces(printer, service);
  printer.Print(vars, R"(
#endif)");
}
// Writes <Service>.client.h, which declares two classes:
//   - I<Service>: a pure virtual interface with one method per RPC, so tests
//     can substitute a mock; and
//   - <Service>: the implementation, which owns a ::nos::AppClient constructed
//     from the caller's NuggetClientInterface and APP_ID_<app_id option>.
void GenerateClientHeader(Printer& printer, const ServiceDescriptor& service) {
  const std::string& service_name = service.name();
  const std::map<std::string, std::string> vars = {
      {"include_guard", "PROTOC_GENERATED_" + service_name + "_CLIENT_H"},
      {"protobuf_header", FullyQualifiedHeader(service)},
      {"class", service_name},
      {"iface_class", "I" + service_name},
      {"app_id", "APP_ID_" + service.options().GetExtension(app_id)},
  };
  // Prints one declaration per RPC, using |declaration| as the Printer template.
  const auto print_methods = [&printer, &service](const char* declaration) {
    ForEachMethod(service, [&printer, declaration](std::map<std::string, std::string> method_vars) {
      printer.Print(method_vars, declaration);
    });
  };
  printer.Print(vars, R"(
#ifndef $include_guard$
#define $include_guard$
#include <application.h>
#include <nos/AppClient.h>
#include <nos/NuggetClientInterface.h>
#include "$protobuf_header$")");
  OpenNamespaces(printer, service);
  // Pure virtual interface to make testing easier
  printer.Print(vars, R"(
class $iface_class$ {
public:
virtual ~$iface_class$() = default;)");
  print_methods(R"(
virtual uint32_t $method_name$(const $method_input_type$&, $method_output_type$*) = 0;)");
  printer.Print(vars, R"(
};)");
  // Implementation of the interface for Nugget
  printer.Print(vars, R"(
class $class$ : public $iface_class$ {
::nos::AppClient _app;
public:
$class$(::nos::NuggetClientInterface& client) : _app{client, $app_id$} {}
~$class$() override = default;)");
  print_methods(R"(
uint32_t $method_name$(const $method_input_type$&, $method_output_type$*) override;)");
  printer.Print(vars, R"(
};)");
  CloseNamespaces(printer, service);
  printer.Print(vars, R"(
#endif)");
}
// Writes <Service>.client.cpp: the definitions of the client methods declared
// in <Service>.client.h.
//
// Each generated method:
//   - serializes |request|, returning APP_ERROR_TOO_MUCH if it is larger than
//     the service's request_buffer_size option;
//   - when |response| is non-null, sizes the reply buffer to the
//     response_buffer_size option;
//   - calls _app.Call() with the method's index as the function id;
//   - on APP_SUCCESS, parses the reply into |response|.
// It returns APP_ERROR_RPC if serialization or parsing fails. Otherwise it
// returns the status from _app.Call().
void GenerateClientSource(Printer& printer, const ServiceDescriptor& service) {
  std::map<std::string, std::string> vars;
  vars["generated_header"] = service.name() + ".client.h";
  vars["class"] = service.name();
  const uint32_t max_request_size = service.options().GetExtension(request_buffer_size);
  const uint32_t max_response_size = service.options().GetExtension(response_buffer_size);
  vars["max_request_size"] = std::to_string(max_request_size);
  vars["max_response_size"] = std::to_string(max_response_size);
  // The generated method bodies use size_t, uint8_t/uint32_t and std::vector
  // directly. Include their headers instead of relying on them arriving
  // transitively through the client header.
  printer.Print(vars, R"(
#include <$generated_header$>
#include <cstddef>
#include <cstdint>
#include <vector>
#include <application.h>)");
  OpenNamespaces(printer, service);
  // Methods
  ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) {
    // Merge in the service-wide variables (class name, buffer limits).
    methodVars.insert(vars.begin(), vars.end());
    printer.Print(methodVars, R"(
uint32_t $class$::$method_name$(const $method_input_type$& request, $method_output_type$* response) {
const size_t request_size = request.ByteSize();
if (request_size > $max_request_size$) {
return APP_ERROR_TOO_MUCH;
}
std::vector<uint8_t> buffer(request_size);
if (!request.SerializeToArray(buffer.data(), buffer.size())) {
return APP_ERROR_RPC;
}
std::vector<uint8_t> responseBuffer;
if (response != nullptr) {
responseBuffer.resize($max_response_size$);
}
const uint32_t appStatus = _app.Call($method_id$, buffer,
(response != nullptr) ? &responseBuffer : nullptr);
if (appStatus == APP_SUCCESS && response != nullptr) {
if (!response->ParseFromArray(responseBuffer.data(), responseBuffer.size())) {
return APP_ERROR_RPC;
}
}
return appStatus;
})");
  });
  CloseNamespaces(printer, service);
}
// Generator for C++ Nugget service client
class CppNuggetServiceClientGenerator : public CodeGenerator {
public:
CppNuggetServiceClientGenerator() = default;
~CppNuggetServiceClientGenerator() override = default;
bool Generate(const FileDescriptor* file,
const std::string& parameter,
OutputDirectory* output_directory,
std::string* error) const override {
for (int i = 0; i < file->service_count(); ++i) {
const auto& service = *file->service(i);
*error = validateServiceOptions(service);
if (!error->empty()) {
return false;
}
if (parameter == "mock") {
std::unique_ptr<ZeroCopyOutputStream> output{
output_directory->Open("Mock" + service.name() + ".client.h")};
Printer printer(output.get(), '$');
GenerateMockClient(printer, service);
} else if (parameter == "header") {
std::unique_ptr<ZeroCopyOutputStream> output{
output_directory->Open(service.name() + ".client.h")};
Printer printer(output.get(), '$');
GenerateClientHeader(printer, service);
} else if (parameter == "source") {
std::unique_ptr<ZeroCopyOutputStream> output{
output_directory->Open(service.name() + ".client.cpp")};
Printer printer(output.get(), '$');
GenerateClientSource(printer, service);
} else {
*error = "Illegal parameter: must be mock|header|source";
return false;
}
}
return true;
}
private:
GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CppNuggetServiceClientGenerator);
};
} // namespace
int main(int argc, char* argv[]) {
GOOGLE_PROTOBUF_VERIFY_VERSION;
CppNuggetServiceClientGenerator generator;
return google::protobuf::compiler::PluginMain(argc, argv, &generator);
}