// blob: a20bfd7ffe7422f72bd9f501a6f30c5c27a48fb8 [file] [log] [blame]
/*
* Copyright 2019 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tuningfork_internal.h"
#include <sstream>
#include <string>
#define LOG_TAG "TuningFork:FPDownload"
#include "Log.h"
#include "tuningfork_utils.h"
#include "web.h"
#include "../../third_party/json11/json11.hpp"
#include "modp_b64.h"
namespace tuningfork {
using namespace json11;
// RPC name handed to Request::Send when asking for tuning parameters.
constexpr char kRpcName[] = ":generateTuningParameters";
// Inclusive bounds of the response codes that are treated as success.
constexpr int kSuccessCodeMin = 200;
constexpr int kSuccessCodeMax = 299;
// Add a base64 serialization of 'params' as 'field_name' in 'obj'.
//
// params:     serialized bytes to encode.
// obj:        JSON object that receives the encoded string.
// field_name: key under which the encoded string is stored.
//
// If the encoder reports an error (returns -1), 'obj' is left untouched.
static void add_params(const ProtobufSerialization& params, Json::object& obj,
                       const std::string& field_name) {
    const size_t len = params.size();
    // Size the buffer with the length the encoder asks for.
    std::string dest(modp_b64_encode_len(len), '\0');
    // Write through &dest[0]: modifying the array returned by c_str() is
    // undefined behaviour, whereas operator[] yields a mutable reference.
    const size_t encoded_len = static_cast<size_t>(
        modp_b64_encode(&dest[0],
                        reinterpret_cast<const char*>(params.data()),
                        len));
    // The error sentinel is -1; compare in the unsigned domain explicitly
    // instead of relying on an implicit signed/unsigned conversion.
    if (encoded_len != static_cast<size_t>(-1)) {
        // Trim the buffer down to the number of characters actually produced.
        dest.resize(encoded_len);
        obj[field_name] = dest;
    }
}
static std::string RequestJson(const ExtraUploadInfo& request_info,
const ProtobufSerialization* training_mode_params) {
Json::object request_obj = Json::object {
{"name", json_utils::GetResourceName(request_info)},
{"device_spec", json_utils::DeviceSpecJson(request_info)}};
if (training_mode_params != nullptr) {
add_params(*training_mode_params, request_obj, "serialized_training_tuning_parameters");
}
Json request = request_obj;
auto result = request.dump();
ALOGV("Request body: %s", result.c_str());
return result;
}
// Parse the JSON text in 'response' and extract the fidelity parameters and the
// experiment id.
//
// Recognised layout (each level may be absent):
//   { "parameters": { "experimentId": <string>,
//                     "serializedFidelityParameters": <base64 string> } }
//
// response:      response body to parse.
// fps:           receives the base64-decoded fidelity parameters; cleared when
//                "parameters" or "serializedFidelityParameters" is missing.
// experiment_id: receives "experimentId"; cleared when "parameters" or
//                "experimentId" is missing.
//
// Returns TFERROR_OK on success. Returns TFERROR_NO_FIDELITY_PARAMS when the
// text is not valid JSON, when a present field has the wrong JSON type, or when
// the base64 payload cannot be decoded.
static TFErrorCode DecodeResponse(const std::string& response, std::vector<uint8_t>& fps,
                                  std::string& experiment_id) {
    ALOGI("Response: %s", response.c_str());
    std::string err;
    const Json jresponse = Json::parse(response, err);
    if (!err.empty()) {
        ALOGE("Parsing error: %s", err.c_str());
        return TFERROR_NO_FIDELITY_PARAMS;
    }
    ALOGV("Response, deserialized: %s", jresponse.dump().c_str());
    if (!jresponse.is_object()) {
        ALOGE("Response not object");
        return TFERROR_NO_FIDELITY_PARAMS;
    }

    const auto& outer = jresponse.object_items();
    const auto iparameters = outer.find("parameters");
    if (iparameters == outer.end()) {
        ALOGW("No parameters: assuming they are empty");
        fps.clear();
        experiment_id.clear();
        return TFERROR_OK;
    }

    const auto& params = iparameters->second;
    if (!params.is_object()) {
        ALOGE("parameters not object");
        return TFERROR_NO_FIDELITY_PARAMS;
    }
    const auto& inner = params.object_items();

    const auto iexperiment_id = inner.find("experimentId");
    if (iexperiment_id == inner.end()) {
        ALOGW("No experimentId: assuming it is empty");
        experiment_id.clear();
    } else {
        if (!iexperiment_id->second.is_string()) {
            ALOGE("experimentId is not a string");
            return TFERROR_NO_FIDELITY_PARAMS;
        }
        experiment_id = iexperiment_id->second.string_value();
    }

    const auto ifps = inner.find("serializedFidelityParameters");
    if (ifps == inner.end()) {
        ALOGW("No serializedFidelityParameters: assuming empty");
        fps.clear();
        return TFERROR_OK;
    }
    if (!ifps->second.is_string()) {
        ALOGE("serializedFidelityParameters is not a string");
        return TFERROR_NO_FIDELITY_PARAMS;
    }

    // Bind by reference so the base64 text is not copied.
    const std::string& sfps = ifps->second.string_value();
    // modp_b64_decode_len is only an upper bound on the decoded size, so the
    // buffer is over-allocated here and trimmed once the real size is known.
    fps.resize(modp_b64_decode_len(sfps.length()));
    const size_t decoded_len = static_cast<size_t>(
        modp_b64_decode(reinterpret_cast<char*>(fps.data()), sfps.c_str(), sfps.length()));
    if (decoded_len == static_cast<size_t>(-1)) {
        ALOGE("Can't decode base 64 FPs");
        // Do not hand back a buffer holding partially decoded data.
        fps.clear();
        return TFERROR_NO_FIDELITY_PARAMS;
    }
    // Drop the slack bytes; without this the caller would see trailing padding
    // appended to the serialized parameters.
    fps.resize(decoded_len);
    return TFERROR_OK;
}
// Send the tuning-parameters request and decode the reply.
//
// request:           used to obtain the upload info and to send the RPC.
// training_mode_fps: optional parameters forwarded to RequestJson; may be null.
// fps:               receives the decoded fidelity parameters.
// experiment_id:     receives the decoded experiment id.
//
// Returns the error from Request::Send if sending fails,
// TFERROR_NO_FIDELITY_PARAMS if the response code lies outside
// [kSuccessCodeMin, kSuccessCodeMax], and otherwise the result of
// DecodeResponse.
static TFErrorCode DownloadFidelityParams(Request& request,
                                          const ProtobufSerialization* training_mode_fps,
                                          ProtobufSerialization& fps,
                                          std::string& experiment_id) {
    int status_code;
    std::string reply;
    const TFErrorCode send_result =
        request.Send(kRpcName, RequestJson(request.Info(), training_mode_fps),
                     status_code, reply);
    if (send_result != TFERROR_OK) {
        return send_result;
    }

    const bool out_of_range =
        status_code < kSuccessCodeMin || status_code > kSuccessCodeMax;
    if (out_of_range) {
        return TFERROR_NO_FIDELITY_PARAMS;
    }
    return DecodeResponse(reply, fps, experiment_id);
}
// Fetch fidelity parameters by delegating to DownloadFidelityParams.
//
// All arguments are forwarded unchanged, and the error code it produces is
// returned as-is.
TFErrorCode ParamsLoader::GetFidelityParams(Request& request,
                                            const ProtobufSerialization* training_mode_fps,
                                            ProtobufSerialization& fidelity_params,
                                            std::string& experiment_id) {
    const TFErrorCode status =
        DownloadFidelityParams(request, training_mode_fps, fidelity_params, experiment_id);
    return status;
}
} // namespace tuningfork