// blob: 03b6ef096719bb30601cef7f8231e37200e4590e [file] [log] [blame]
//
// Copyright 2018 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "test/core/util/test_lb_policies.h"
#include <string>
#include <grpc/support/log.h>
#include "src/core/ext/filters/client_channel/lb_policy.h"
#include "src/core/ext/filters/client_channel/lb_policy_registry.h"
#include "src/core/lib/address_utils/parse_address.h"
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/channel/channelz.h"
#include "src/core/lib/debug/trace.h"
#include "src/core/lib/gprpp/memory.h"
#include "src/core/lib/gprpp/orphanable.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/iomgr/closure.h"
#include "src/core/lib/iomgr/combiner.h"
#include "src/core/lib/iomgr/error.h"
#include "src/core/lib/iomgr/pollset_set.h"
#include "src/core/lib/json/json.h"
#include "src/core/lib/json/json_util.h"
#include "src/core/lib/transport/connectivity_state.h"
namespace grpc_core {
namespace {
//
// ForwardingLoadBalancingPolicy
//
// A minimal forwarding class to avoid implementing a standalone test LB.
class ForwardingLoadBalancingPolicy : public LoadBalancingPolicy {
public:
ForwardingLoadBalancingPolicy(
std::unique_ptr<ChannelControlHelper> delegating_helper, Args args,
const char* delegate_policy_name, intptr_t initial_refcount = 1)
: LoadBalancingPolicy(std::move(args), initial_refcount) {
Args delegate_args;
delegate_args.work_serializer = work_serializer();
delegate_args.channel_control_helper = std::move(delegating_helper);
delegate_args.args = args.args;
delegate_ = LoadBalancingPolicyRegistry::CreateLoadBalancingPolicy(
delegate_policy_name, std::move(delegate_args));
grpc_pollset_set_add_pollset_set(delegate_->interested_parties(),
interested_parties());
}
~ForwardingLoadBalancingPolicy() override = default;
void UpdateLocked(UpdateArgs args) override {
delegate_->UpdateLocked(std::move(args));
}
void ExitIdleLocked() override { delegate_->ExitIdleLocked(); }
void ResetBackoffLocked() override { delegate_->ResetBackoffLocked(); }
private:
void ShutdownLocked() override { delegate_.reset(); }
OrphanablePtr<LoadBalancingPolicy> delegate_;
};
//
// TestPickArgsLb
//
// Name returned by name() on the TestPickArgsLb policy, its config, and its
// factory.
constexpr char kTestPickArgsLbPolicyName[] = "test_pick_args_lb";
// LB policy that reports the arguments of every pick to a test callback and
// then lets the delegate policy's picker perform the actual pick.
class TestPickArgsLb : public ForwardingLoadBalancingPolicy {
 public:
  // args: construction args, forwarded to ForwardingLoadBalancingPolicy.
  // cb: invoked from Picker::Pick() with the path and initial metadata seen.
  // delegate_policy_name: registry name of the policy that does the picking.
  //
  // NOTE(review): initial_refcount is 2; presumably the extra ref is the one
  // held by the Helper's RefCountedPtr to this object -- confirm against
  // RefCountedPtr's raw-pointer constructor.
  TestPickArgsLb(Args args, TestPickArgsCallback cb,
                 const char* delegate_policy_name)
      : ForwardingLoadBalancingPolicy(
            // cb is a by-value parameter that is not used again, so it is
            // moved into the helper rather than copied, matching the other
            // test policies in this file.
            absl::make_unique<Helper>(RefCountedPtr<TestPickArgsLb>(this),
                                      std::move(cb)),
            std::move(args), delegate_policy_name,
            /*initial_refcount=*/2) {}

  ~TestPickArgsLb() override = default;

  const char* name() const override { return kTestPickArgsLbPolicyName; }

 private:
  // Wraps the delegate's picker: reports the pick args, then delegates.
  class Picker : public SubchannelPicker {
   public:
    Picker(std::unique_ptr<SubchannelPicker> delegate_picker,
           TestPickArgsCallback cb)
        : delegate_picker_(std::move(delegate_picker)), cb_(std::move(cb)) {}

    PickResult Pick(PickArgs args) override {
      // Report args seen.
      PickArgsSeen args_seen;
      args_seen.path = std::string(args.path);
      args_seen.metadata = args.initial_metadata->TestOnlyCopyToVector();
      cb_(args_seen);
      // Do pick.
      return delegate_picker_->Pick(args);
    }

   private:
    std::unique_ptr<SubchannelPicker> delegate_picker_;
    TestPickArgsCallback cb_;
  };

  // Helper given to the delegate policy. Every call is forwarded to the
  // parent's own helper; UpdateState() additionally wraps the delegate's
  // picker in a Picker carrying a copy of the callback.
  class Helper : public ChannelControlHelper {
   public:
    Helper(RefCountedPtr<TestPickArgsLb> parent, TestPickArgsCallback cb)
        : parent_(std::move(parent)), cb_(std::move(cb)) {}

    RefCountedPtr<SubchannelInterface> CreateSubchannel(
        ServerAddress address, const grpc_channel_args& args) override {
      return parent_->channel_control_helper()->CreateSubchannel(
          std::move(address), args);
    }

    void UpdateState(grpc_connectivity_state state, const absl::Status& status,
                     std::unique_ptr<SubchannelPicker> picker) override {
      parent_->channel_control_helper()->UpdateState(
          state, status, absl::make_unique<Picker>(std::move(picker), cb_));
    }

    void RequestReresolution() override {
      parent_->channel_control_helper()->RequestReresolution();
    }

    absl::string_view GetAuthority() override {
      return parent_->channel_control_helper()->GetAuthority();
    }

    void AddTraceEvent(TraceSeverity severity,
                       absl::string_view message) override {
      parent_->channel_control_helper()->AddTraceEvent(severity, message);
    }

   private:
    RefCountedPtr<TestPickArgsLb> parent_;
    TestPickArgsCallback cb_;
  };
};
// Config type for the test_pick_args_lb policy. It has no fields; it only
// identifies the policy by name.
class TestPickArgsLbConfig : public LoadBalancingPolicy::Config {
 public:
  const char* name() const override {
    return kTestPickArgsLbPolicyName;
  }
};
// Factory producing TestPickArgsLb policies. Each created policy receives a
// copy of the callback and the delegate policy name given at construction.
class TestPickArgsLbFactory : public LoadBalancingPolicyFactory {
 public:
  explicit TestPickArgsLbFactory(TestPickArgsCallback cb,
                                 const char* delegate_policy_name)
      : cb_(std::move(cb)), delegate_policy_name_(delegate_policy_name) {}

  OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
      LoadBalancingPolicy::Args args) const override {
    OrphanablePtr<LoadBalancingPolicy> policy = MakeOrphanable<TestPickArgsLb>(
        std::move(args), cb_, delegate_policy_name_);
    return policy;
  }

  const char* name() const override {
    return kTestPickArgsLbPolicyName;
  }

  // The JSON input is ignored; a fresh, empty config is always returned and
  // the error out-parameter is never written.
  RefCountedPtr<LoadBalancingPolicy::Config> ParseLoadBalancingConfig(
      const Json& /*json*/, grpc_error_handle* /*error*/) const override {
    return MakeRefCounted<TestPickArgsLbConfig>();
  }

 private:
  TestPickArgsCallback cb_;
  const char* delegate_policy_name_;
};
//
// InterceptRecvTrailingMetadataLoadBalancingPolicy
//
// Name returned by name() on the trailing-metadata-intercepting policy, its
// config, and its factory.
constexpr char kInterceptRecvTrailingMetadataLbPolicyName[] =
"intercept_trailing_metadata_lb";
// LB policy that delegates to "pick_first" and, for every completed pick,
// arranges for a test callback to be told about the call's trailing metadata
// and backend metric data.
class InterceptRecvTrailingMetadataLoadBalancingPolicy
    : public ForwardingLoadBalancingPolicy {
 public:
  // cb is handed to the Helper, which copies it into every Picker it creates.
  InterceptRecvTrailingMetadataLoadBalancingPolicy(
      Args args, InterceptRecvTrailingMetadataCallback cb)
      : ForwardingLoadBalancingPolicy(
            absl::make_unique<Helper>(
                RefCountedPtr<InterceptRecvTrailingMetadataLoadBalancingPolicy>(
                    this),
                std::move(cb)),
            std::move(args),
            /*delegate_policy_name=*/"pick_first",
            /*initial_refcount=*/2) {}

  ~InterceptRecvTrailingMetadataLoadBalancingPolicy() override = default;

  const char* name() const override {
    return kInterceptRecvTrailingMetadataLbPolicyName;
  }

 private:
  // Picker wrapper: runs the delegate's pick and, when the result is a
  // PickResult::Complete, attaches a TrailingMetadataHandler to it.
  class Picker : public SubchannelPicker {
   public:
    Picker(std::unique_ptr<SubchannelPicker> delegate_picker,
           InterceptRecvTrailingMetadataCallback cb)
        : delegate_picker_(std::move(delegate_picker)), cb_(std::move(cb)) {}

    PickResult Pick(PickArgs args) override {
      // Let the wrapped picker decide first.
      PickResult result = delegate_picker_->Pick(args);
      // Only completed picks get a trailing-metadata hook.
      auto* complete_pick = absl::get_if<PickResult::Complete>(&result.result);
      const bool is_complete = complete_pick != nullptr;
      if (is_complete) {
        // The handler is constructed in storage obtained from the call state.
        void* storage =
            args.call_state->Alloc(sizeof(TrailingMetadataHandler));
        new (storage) TrailingMetadataHandler(complete_pick, cb_);
      }
      return result;
    }

   private:
    std::unique_ptr<SubchannelPicker> delegate_picker_;
    InterceptRecvTrailingMetadataCallback cb_;
  };

  // Helper given to the delegate policy. All calls go straight to the
  // parent's helper, except that UpdateState() wraps the picker in a Picker.
  class Helper : public ChannelControlHelper {
   public:
    Helper(
        RefCountedPtr<InterceptRecvTrailingMetadataLoadBalancingPolicy> parent,
        InterceptRecvTrailingMetadataCallback cb)
        : parent_(std::move(parent)), cb_(std::move(cb)) {}

    RefCountedPtr<SubchannelInterface> CreateSubchannel(
        ServerAddress address, const grpc_channel_args& args) override {
      auto* parent_helper = parent_->channel_control_helper();
      return parent_helper->CreateSubchannel(std::move(address), args);
    }

    void UpdateState(grpc_connectivity_state state, const absl::Status& status,
                     std::unique_ptr<SubchannelPicker> picker) override {
      auto* parent_helper = parent_->channel_control_helper();
      parent_helper->UpdateState(
          state, status, absl::make_unique<Picker>(std::move(picker), cb_));
    }

    void RequestReresolution() override {
      auto* parent_helper = parent_->channel_control_helper();
      parent_helper->RequestReresolution();
    }

    absl::string_view GetAuthority() override {
      auto* parent_helper = parent_->channel_control_helper();
      return parent_helper->GetAuthority();
    }

    void AddTraceEvent(TraceSeverity severity,
                       absl::string_view message) override {
      auto* parent_helper = parent_->channel_control_helper();
      parent_helper->AddTraceEvent(severity, message);
    }

   private:
    RefCountedPtr<InterceptRecvTrailingMetadataLoadBalancingPolicy> parent_;
    InterceptRecvTrailingMetadataCallback cb_;
  };

  // Installs itself as the recv_trailing_metadata_ready callback of a
  // completed pick. When that callback fires it reports what it saw to cb_
  // and then runs its own destructor.
  class TrailingMetadataHandler {
   public:
    TrailingMetadataHandler(PickResult::Complete* result,
                            InterceptRecvTrailingMetadataCallback cb)
        : cb_(std::move(cb)) {
      // The lambda captures `this`, so this object must stay alive until the
      // callback has run.
      result->recv_trailing_metadata_ready =
          [this](absl::Status /*status*/, MetadataInterface* metadata,
                 CallState* call_state) {
            RecordRecvTrailingMetadata(metadata, call_state);
          };
    }

   private:
    void RecordRecvTrailingMetadata(MetadataInterface* recv_trailing_metadata,
                                    CallState* call_state) {
      TrailingMetadataArgsSeen seen;
      seen.backend_metric_data = call_state->GetBackendMetricData();
      GPR_ASSERT(recv_trailing_metadata != nullptr);
      seen.metadata = recv_trailing_metadata->TestOnlyCopyToVector();
      cb_(seen);
      // Explicit destructor call: the object was created with placement new,
      // so only the destructor is run here; no storage is released.
      this->~TrailingMetadataHandler();
    }

    InterceptRecvTrailingMetadataCallback cb_;
  };
};
// Config type for the intercept_trailing_metadata_lb policy. It has no
// fields; it only identifies the policy by name.
class InterceptTrailingConfig : public LoadBalancingPolicy::Config {
 public:
  const char* name() const override {
    return kInterceptRecvTrailingMetadataLbPolicyName;
  }
};
// Factory producing InterceptRecvTrailingMetadataLoadBalancingPolicy
// instances, each receiving a copy of the callback given at construction.
class InterceptTrailingFactory : public LoadBalancingPolicyFactory {
 public:
  explicit InterceptTrailingFactory(InterceptRecvTrailingMetadataCallback cb)
      : cb_(std::move(cb)) {}

  OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
      LoadBalancingPolicy::Args args) const override {
    OrphanablePtr<LoadBalancingPolicy> policy =
        MakeOrphanable<InterceptRecvTrailingMetadataLoadBalancingPolicy>(
            std::move(args), cb_);
    return policy;
  }

  const char* name() const override {
    return kInterceptRecvTrailingMetadataLbPolicyName;
  }

  // The JSON input is ignored; a fresh, empty config is always returned and
  // the error out-parameter is never written.
  RefCountedPtr<LoadBalancingPolicy::Config> ParseLoadBalancingConfig(
      const Json& /*json*/, grpc_error_handle* /*error*/) const override {
    return MakeRefCounted<InterceptTrailingConfig>();
  }

 private:
  InterceptRecvTrailingMetadataCallback cb_;
};
//
// AddressTestLoadBalancingPolicy
//
// Name returned by name() on the address-test policy, its config, and its
// factory.
constexpr char kAddressTestLbPolicyName[] = "address_test_lb";
// LB policy that delegates to "pick_first" and reports every address for
// which the delegate asks to create a subchannel to a test callback.
class AddressTestLoadBalancingPolicy : public ForwardingLoadBalancingPolicy {
 public:
  // cb is moved into the Helper and invoked from Helper::CreateSubchannel().
  AddressTestLoadBalancingPolicy(Args args, AddressTestCallback cb)
      : ForwardingLoadBalancingPolicy(
            absl::make_unique<Helper>(
                RefCountedPtr<AddressTestLoadBalancingPolicy>(this),
                std::move(cb)),
            std::move(args),
            /*delegate_policy_name=*/"pick_first",
            /*initial_refcount=*/2) {}

  ~AddressTestLoadBalancingPolicy() override = default;

  const char* name() const override {
    return kAddressTestLbPolicyName;
  }

 private:
  // Helper given to the delegate policy. Every call is forwarded to the
  // parent's helper; CreateSubchannel() first reports the address to cb_.
  class Helper : public ChannelControlHelper {
   public:
    Helper(RefCountedPtr<AddressTestLoadBalancingPolicy> parent,
           AddressTestCallback cb)
        : parent_(std::move(parent)), cb_(std::move(cb)) {}

    RefCountedPtr<SubchannelInterface> CreateSubchannel(
        ServerAddress address, const grpc_channel_args& args) override {
      // Report the address before it is moved into the real helper.
      cb_(address);
      auto* parent_helper = parent_->channel_control_helper();
      return parent_helper->CreateSubchannel(std::move(address), args);
    }

    void UpdateState(grpc_connectivity_state state, const absl::Status& status,
                     std::unique_ptr<SubchannelPicker> picker) override {
      auto* parent_helper = parent_->channel_control_helper();
      parent_helper->UpdateState(state, status, std::move(picker));
    }

    void RequestReresolution() override {
      auto* parent_helper = parent_->channel_control_helper();
      parent_helper->RequestReresolution();
    }

    absl::string_view GetAuthority() override {
      auto* parent_helper = parent_->channel_control_helper();
      return parent_helper->GetAuthority();
    }

    void AddTraceEvent(TraceSeverity severity,
                       absl::string_view message) override {
      auto* parent_helper = parent_->channel_control_helper();
      parent_helper->AddTraceEvent(severity, message);
    }

   private:
    RefCountedPtr<AddressTestLoadBalancingPolicy> parent_;
    AddressTestCallback cb_;
  };
};
// Config type for the address_test_lb policy. It has no fields; it only
// identifies the policy by name.
class AddressTestConfig : public LoadBalancingPolicy::Config {
 public:
  const char* name() const override {
    return kAddressTestLbPolicyName;
  }
};
// Factory producing AddressTestLoadBalancingPolicy instances, each receiving
// a copy of the callback given at construction.
class AddressTestFactory : public LoadBalancingPolicyFactory {
 public:
  explicit AddressTestFactory(AddressTestCallback cb) : cb_(std::move(cb)) {}

  OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
      LoadBalancingPolicy::Args args) const override {
    OrphanablePtr<LoadBalancingPolicy> policy =
        MakeOrphanable<AddressTestLoadBalancingPolicy>(std::move(args), cb_);
    return policy;
  }

  const char* name() const override {
    return kAddressTestLbPolicyName;
  }

  // The JSON input is ignored; a fresh, empty config is always returned and
  // the error out-parameter is never written.
  RefCountedPtr<LoadBalancingPolicy::Config> ParseLoadBalancingConfig(
      const Json& /*json*/, grpc_error_handle* /*error*/) const override {
    return MakeRefCounted<AddressTestConfig>();
  }

 private:
  AddressTestCallback cb_;
};
//
// FixedAddressLoadBalancingPolicy
//
// Name returned by name() on the fixed-address policy, its config, and its
// factory; also used as the prefix of that policy's log messages.
constexpr char kFixedAddressLbPolicyName[] = "fixed_address_lb";
// Config for the fixed_address_lb policy: holds the address string that
// FixedAddressLoadBalancingPolicy parses as a URI on every update.
class FixedAddressConfig : public LoadBalancingPolicy::Config {
 public:
  explicit FixedAddressConfig(std::string address)
      : address_(std::move(address)) {}

  const char* name() const override {
    return kFixedAddressLbPolicyName;
  }

  // Returns the stored address string.
  const std::string& address() const {
    return address_;
  }

 private:
  std::string address_;
};
// LB policy that delegates to "pick_first" but replaces the address list of
// every update with the single address named in its FixedAddressConfig.
class FixedAddressLoadBalancingPolicy : public ForwardingLoadBalancingPolicy {
 public:
  explicit FixedAddressLoadBalancingPolicy(Args args)
      : ForwardingLoadBalancingPolicy(
            absl::make_unique<Helper>(
                RefCountedPtr<FixedAddressLoadBalancingPolicy>(this)),
            std::move(args),
            /*delegate_policy_name=*/"pick_first",
            /*initial_refcount=*/2) {}

  ~FixedAddressLoadBalancingPolicy() override = default;

  const char* name() const override {
    return kFixedAddressLbPolicyName;
  }

  // Rewrites the update before forwarding it: the config is dropped and the
  // address list is replaced by the configured address. If the configured
  // string does not parse as a URI, an empty address list is forwarded.
  void UpdateLocked(UpdateArgs args) override {
    // args.config is treated as a FixedAddressConfig without any check.
    auto* fixed_config = static_cast<FixedAddressConfig*>(args.config.get());
    gpr_log(GPR_INFO, "%s: update URI: %s", kFixedAddressLbPolicyName,
            fixed_config->address().c_str());
    auto uri = URI::Parse(fixed_config->address());
    // fixed_config must not be used past this point: the config is released.
    args.config.reset();
    args.addresses.clear();
    if (!uri.ok()) {
      gpr_log(GPR_ERROR,
              "%s: could not parse URI (%s), using empty address list",
              kFixedAddressLbPolicyName, uri.status().ToString().c_str());
    } else {
      grpc_resolved_address address;
      GPR_ASSERT(grpc_parse_uri(*uri, &address));
      args.addresses.emplace_back(address, /*args=*/nullptr);
    }
    ForwardingLoadBalancingPolicy::UpdateLocked(std::move(args));
  }

 private:
  // Helper given to the delegate policy; forwards every call unchanged to the
  // parent's helper.
  class Helper : public ChannelControlHelper {
   public:
    explicit Helper(RefCountedPtr<FixedAddressLoadBalancingPolicy> parent)
        : parent_(std::move(parent)) {}

    RefCountedPtr<SubchannelInterface> CreateSubchannel(
        ServerAddress address, const grpc_channel_args& args) override {
      auto* parent_helper = parent_->channel_control_helper();
      return parent_helper->CreateSubchannel(std::move(address), args);
    }

    void UpdateState(grpc_connectivity_state state, const absl::Status& status,
                     std::unique_ptr<SubchannelPicker> picker) override {
      auto* parent_helper = parent_->channel_control_helper();
      parent_helper->UpdateState(state, status, std::move(picker));
    }

    void RequestReresolution() override {
      auto* parent_helper = parent_->channel_control_helper();
      parent_helper->RequestReresolution();
    }

    absl::string_view GetAuthority() override {
      auto* parent_helper = parent_->channel_control_helper();
      return parent_helper->GetAuthority();
    }

    void AddTraceEvent(TraceSeverity severity,
                       absl::string_view message) override {
      auto* parent_helper = parent_->channel_control_helper();
      parent_helper->AddTraceEvent(severity, message);
    }

   private:
    RefCountedPtr<FixedAddressLoadBalancingPolicy> parent_;
  };
};
// Factory producing FixedAddressLoadBalancingPolicy instances and parsing
// their JSON config.
class FixedAddressFactory : public LoadBalancingPolicyFactory {
 public:
  FixedAddressFactory() = default;

  OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
      LoadBalancingPolicy::Args args) const override {
    OrphanablePtr<LoadBalancingPolicy> policy =
        MakeOrphanable<FixedAddressLoadBalancingPolicy>(std::move(args));
    return policy;
  }

  const char* name() const override {
    return kFixedAddressLbPolicyName;
  }

  // Reads the "address" field of the JSON object into a FixedAddressConfig.
  // If reading the field records any error, *error is set to an error that
  // aggregates them and nullptr is returned.
  RefCountedPtr<LoadBalancingPolicy::Config> ParseLoadBalancingConfig(
      const Json& json, grpc_error_handle* error) const override {
    std::vector<grpc_error_handle> error_list;
    std::string address;
    ParseJsonObjectField(json.object_value(), "address", &address,
                         &error_list);
    if (error_list.empty()) {
      return MakeRefCounted<FixedAddressConfig>(std::move(address));
    }
    *error = GRPC_ERROR_CREATE_FROM_VECTOR(
        "errors parsing fixed_address_lb config", &error_list);
    return nullptr;
  }
};
} // namespace
// Registers a TestPickArgsLbFactory with the LB policy registry. The factory
// is built from the given callback and delegate policy name.
void RegisterTestPickArgsLoadBalancingPolicy(TestPickArgsCallback cb,
                                             const char* delegate_policy_name) {
  auto factory = absl::make_unique<TestPickArgsLbFactory>(
      std::move(cb), delegate_policy_name);
  LoadBalancingPolicyRegistry::Builder::RegisterLoadBalancingPolicyFactory(
      std::move(factory));
}
// Registers an InterceptTrailingFactory, built from the given callback, with
// the LB policy registry.
void RegisterInterceptRecvTrailingMetadataLoadBalancingPolicy(
    InterceptRecvTrailingMetadataCallback cb) {
  auto factory = absl::make_unique<InterceptTrailingFactory>(std::move(cb));
  LoadBalancingPolicyRegistry::Builder::RegisterLoadBalancingPolicyFactory(
      std::move(factory));
}
// Registers an AddressTestFactory, built from the given callback, with the
// LB policy registry.
void RegisterAddressTestLoadBalancingPolicy(AddressTestCallback cb) {
  auto factory = absl::make_unique<AddressTestFactory>(std::move(cb));
  LoadBalancingPolicyRegistry::Builder::RegisterLoadBalancingPolicyFactory(
      std::move(factory));
}
// Registers a default-constructed FixedAddressFactory with the LB policy
// registry.
void RegisterFixedAddressLoadBalancingPolicy() {
  auto factory = absl::make_unique<FixedAddressFactory>();
  LoadBalancingPolicyRegistry::Builder::RegisterLoadBalancingPolicyFactory(
      std::move(factory));
}
} // namespace grpc_core