| /* |
| * |
| * Copyright 2018 gRPC authors. |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| * |
| */ |
| |
| #include <grpc/support/port_platform.h> |
| |
| #include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h" |
| |
| #include <stdio.h> |
| #include <stdlib.h> |
| #include <string.h> |
| |
| #include "upb/upb.hpp" |
| |
| #include <grpc/support/alloc.h> |
| #include <grpc/support/log.h> |
| #include <grpc/support/string_util.h> |
| #include <grpc/support/sync.h> |
| #include <grpc/support/thd_id.h> |
| |
| #include "src/core/lib/gprpp/sync.h" |
| #include "src/core/lib/gprpp/thd.h" |
| #include "src/core/lib/iomgr/closure.h" |
| #include "src/core/lib/slice/slice_internal.h" |
| #include "src/core/lib/surface/channel.h" |
| #include "src/core/tsi/alts/frame_protector/alts_frame_protector.h" |
| #include "src/core/tsi/alts/handshaker/alts_handshaker_client.h" |
| #include "src/core/tsi/alts/handshaker/alts_shared_resource.h" |
| #include "src/core/tsi/alts/handshaker/alts_tsi_utils.h" |
| #include "src/core/tsi/alts/zero_copy_frame_protector/alts_zero_copy_grpc_protector.h" |
| |
| /* Main struct for ALTS TSI handshaker. */ |
| struct alts_tsi_handshaker { |
| tsi_handshaker base; |
| grpc_slice target_name; |
| bool is_client; |
| bool has_sent_start_message; |
| bool has_created_handshaker_client; |
| char* handshaker_service_url; |
| grpc_pollset_set* interested_parties; |
| grpc_alts_credentials_options* options; |
| alts_handshaker_client_vtable* client_vtable_for_testing; |
| grpc_channel* channel; |
| bool use_dedicated_cq; |
| // mu synchronizes all fields below. Note these are the |
| // only fields that can be concurrently accessed (due to |
| // potential concurrency of tsi_handshaker_shutdown and |
| // tsi_handshaker_next). |
| gpr_mu mu; |
| alts_handshaker_client* client; |
| // shutdown effectively follows base.handshake_shutdown, |
| // but is synchronized by the mutex of this object. |
| bool shutdown; |
| // Maximum frame size used by frame protector. |
| size_t max_frame_size; |
| }; |
| |
| /* Main struct for ALTS TSI handshaker result. */ |
| typedef struct alts_tsi_handshaker_result { |
| tsi_handshaker_result base; |
| char* peer_identity; |
| char* key_data; |
| unsigned char* unused_bytes; |
| size_t unused_bytes_size; |
| grpc_slice rpc_versions; |
| bool is_client; |
| grpc_slice serialized_context; |
| // Peer's maximum frame size. |
| size_t max_frame_size; |
| } alts_tsi_handshaker_result; |
| |
| static tsi_result handshaker_result_extract_peer( |
| const tsi_handshaker_result* self, tsi_peer* peer) { |
| if (self == nullptr || peer == nullptr) { |
| gpr_log(GPR_ERROR, "Invalid argument to handshaker_result_extract_peer()"); |
| return TSI_INVALID_ARGUMENT; |
| } |
| alts_tsi_handshaker_result* result = |
| reinterpret_cast<alts_tsi_handshaker_result*>( |
| const_cast<tsi_handshaker_result*>(self)); |
| GPR_ASSERT(kTsiAltsNumOfPeerProperties == 5); |
| tsi_result ok = tsi_construct_peer(kTsiAltsNumOfPeerProperties, peer); |
| int index = 0; |
| if (ok != TSI_OK) { |
| gpr_log(GPR_ERROR, "Failed to construct tsi peer"); |
| return ok; |
| } |
| GPR_ASSERT(&peer->properties[index] != nullptr); |
| ok = tsi_construct_string_peer_property_from_cstring( |
| TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_ALTS_CERTIFICATE_TYPE, |
| &peer->properties[index]); |
| if (ok != TSI_OK) { |
| tsi_peer_destruct(peer); |
| gpr_log(GPR_ERROR, "Failed to set tsi peer property"); |
| return ok; |
| } |
| index++; |
| GPR_ASSERT(&peer->properties[index] != nullptr); |
| ok = tsi_construct_string_peer_property_from_cstring( |
| TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, result->peer_identity, |
| &peer->properties[index]); |
| if (ok != TSI_OK) { |
| tsi_peer_destruct(peer); |
| gpr_log(GPR_ERROR, "Failed to set tsi peer property"); |
| } |
| index++; |
| GPR_ASSERT(&peer->properties[index] != nullptr); |
| ok = tsi_construct_string_peer_property( |
| TSI_ALTS_RPC_VERSIONS, |
| reinterpret_cast<char*>(GRPC_SLICE_START_PTR(result->rpc_versions)), |
| GRPC_SLICE_LENGTH(result->rpc_versions), &peer->properties[index]); |
| if (ok != TSI_OK) { |
| tsi_peer_destruct(peer); |
| gpr_log(GPR_ERROR, "Failed to set tsi peer property"); |
| } |
| index++; |
| GPR_ASSERT(&peer->properties[index] != nullptr); |
| ok = tsi_construct_string_peer_property( |
| TSI_ALTS_CONTEXT, |
| reinterpret_cast<char*>(GRPC_SLICE_START_PTR(result->serialized_context)), |
| GRPC_SLICE_LENGTH(result->serialized_context), &peer->properties[index]); |
| if (ok != TSI_OK) { |
| tsi_peer_destruct(peer); |
| gpr_log(GPR_ERROR, "Failed to set tsi peer property"); |
| } |
| index++; |
| GPR_ASSERT(&peer->properties[index] != nullptr); |
| ok = tsi_construct_string_peer_property_from_cstring( |
| TSI_SECURITY_LEVEL_PEER_PROPERTY, |
| tsi_security_level_to_string(TSI_PRIVACY_AND_INTEGRITY), |
| &peer->properties[index]); |
| if (ok != TSI_OK) { |
| tsi_peer_destruct(peer); |
| gpr_log(GPR_ERROR, "Failed to set tsi peer property"); |
| } |
| GPR_ASSERT(++index == kTsiAltsNumOfPeerProperties); |
| return ok; |
| } |
| |
| static tsi_result handshaker_result_create_zero_copy_grpc_protector( |
| const tsi_handshaker_result* self, size_t* max_output_protected_frame_size, |
| tsi_zero_copy_grpc_protector** protector) { |
| if (self == nullptr || protector == nullptr) { |
| gpr_log(GPR_ERROR, |
| "Invalid arguments to create_zero_copy_grpc_protector()"); |
| return TSI_INVALID_ARGUMENT; |
| } |
| alts_tsi_handshaker_result* result = |
| reinterpret_cast<alts_tsi_handshaker_result*>( |
| const_cast<tsi_handshaker_result*>(self)); |
| |
| // In case the peer does not send max frame size (e.g. peer is gRPC Go or |
| // peer uses an old binary), the negotiated frame size is set to |
| // kTsiAltsMinFrameSize (ignoring max_output_protected_frame_size value if |
| // present). Otherwise, it is based on peer and user specified max frame |
| // size (if present). |
| size_t max_frame_size = kTsiAltsMinFrameSize; |
| if (result->max_frame_size) { |
| size_t peer_max_frame_size = result->max_frame_size; |
| max_frame_size = std::min<size_t>(peer_max_frame_size, |
| max_output_protected_frame_size == nullptr |
| ? kTsiAltsMaxFrameSize |
| : *max_output_protected_frame_size); |
| max_frame_size = std::max<size_t>(max_frame_size, kTsiAltsMinFrameSize); |
| } |
| max_output_protected_frame_size = &max_frame_size; |
| gpr_log(GPR_DEBUG, |
| "After Frame Size Negotiation, maximum frame size used by frame " |
| "protector equals %zu", |
| *max_output_protected_frame_size); |
| tsi_result ok = alts_zero_copy_grpc_protector_create( |
| reinterpret_cast<const uint8_t*>(result->key_data), |
| kAltsAes128GcmRekeyKeyLength, /*is_rekey=*/true, result->is_client, |
| /*is_integrity_only=*/false, /*enable_extra_copy=*/false, |
| max_output_protected_frame_size, protector); |
| if (ok != TSI_OK) { |
| gpr_log(GPR_ERROR, "Failed to create zero-copy grpc protector"); |
| } |
| return ok; |
| } |
| |
| static tsi_result handshaker_result_create_frame_protector( |
| const tsi_handshaker_result* self, size_t* max_output_protected_frame_size, |
| tsi_frame_protector** protector) { |
| if (self == nullptr || protector == nullptr) { |
| gpr_log(GPR_ERROR, |
| "Invalid arguments to handshaker_result_create_frame_protector()"); |
| return TSI_INVALID_ARGUMENT; |
| } |
| alts_tsi_handshaker_result* result = |
| reinterpret_cast<alts_tsi_handshaker_result*>( |
| const_cast<tsi_handshaker_result*>(self)); |
| tsi_result ok = alts_create_frame_protector( |
| reinterpret_cast<const uint8_t*>(result->key_data), |
| kAltsAes128GcmRekeyKeyLength, result->is_client, /*is_rekey=*/true, |
| max_output_protected_frame_size, protector); |
| if (ok != TSI_OK) { |
| gpr_log(GPR_ERROR, "Failed to create frame protector"); |
| } |
| return ok; |
| } |
| |
| static tsi_result handshaker_result_get_unused_bytes( |
| const tsi_handshaker_result* self, const unsigned char** bytes, |
| size_t* bytes_size) { |
| if (self == nullptr || bytes == nullptr || bytes_size == nullptr) { |
| gpr_log(GPR_ERROR, |
| "Invalid arguments to handshaker_result_get_unused_bytes()"); |
| return TSI_INVALID_ARGUMENT; |
| } |
| alts_tsi_handshaker_result* result = |
| reinterpret_cast<alts_tsi_handshaker_result*>( |
| const_cast<tsi_handshaker_result*>(self)); |
| *bytes = result->unused_bytes; |
| *bytes_size = result->unused_bytes_size; |
| return TSI_OK; |
| } |
| |
| static void handshaker_result_destroy(tsi_handshaker_result* self) { |
| if (self == nullptr) { |
| return; |
| } |
| alts_tsi_handshaker_result* result = |
| reinterpret_cast<alts_tsi_handshaker_result*>( |
| const_cast<tsi_handshaker_result*>(self)); |
| gpr_free(result->peer_identity); |
| gpr_free(result->key_data); |
| gpr_free(result->unused_bytes); |
| grpc_slice_unref_internal(result->rpc_versions); |
| grpc_slice_unref_internal(result->serialized_context); |
| gpr_free(result); |
| } |
| |
| static const tsi_handshaker_result_vtable result_vtable = { |
| handshaker_result_extract_peer, |
| handshaker_result_create_zero_copy_grpc_protector, |
| handshaker_result_create_frame_protector, |
| handshaker_result_get_unused_bytes, handshaker_result_destroy}; |
| |
| tsi_result alts_tsi_handshaker_result_create(grpc_gcp_HandshakerResp* resp, |
| bool is_client, |
| tsi_handshaker_result** self) { |
| if (self == nullptr || resp == nullptr) { |
| gpr_log(GPR_ERROR, "Invalid arguments to create_handshaker_result()"); |
| return TSI_INVALID_ARGUMENT; |
| } |
| const grpc_gcp_HandshakerResult* hresult = |
| grpc_gcp_HandshakerResp_result(resp); |
| const grpc_gcp_Identity* identity = |
| grpc_gcp_HandshakerResult_peer_identity(hresult); |
| if (identity == nullptr) { |
| gpr_log(GPR_ERROR, "Invalid identity"); |
| return TSI_FAILED_PRECONDITION; |
| } |
| upb_strview peer_service_account = |
| grpc_gcp_Identity_service_account(identity); |
| if (peer_service_account.size == 0) { |
| gpr_log(GPR_ERROR, "Invalid peer service account"); |
| return TSI_FAILED_PRECONDITION; |
| } |
| upb_strview key_data = grpc_gcp_HandshakerResult_key_data(hresult); |
| if (key_data.size < kAltsAes128GcmRekeyKeyLength) { |
| gpr_log(GPR_ERROR, "Bad key length"); |
| return TSI_FAILED_PRECONDITION; |
| } |
| const grpc_gcp_RpcProtocolVersions* peer_rpc_version = |
| grpc_gcp_HandshakerResult_peer_rpc_versions(hresult); |
| if (peer_rpc_version == nullptr) { |
| gpr_log(GPR_ERROR, "Peer does not set RPC protocol versions."); |
| return TSI_FAILED_PRECONDITION; |
| } |
| upb_strview application_protocol = |
| grpc_gcp_HandshakerResult_application_protocol(hresult); |
| if (application_protocol.size == 0) { |
| gpr_log(GPR_ERROR, "Invalid application protocol"); |
| return TSI_FAILED_PRECONDITION; |
| } |
| upb_strview record_protocol = |
| grpc_gcp_HandshakerResult_record_protocol(hresult); |
| if (record_protocol.size == 0) { |
| gpr_log(GPR_ERROR, "Invalid record protocol"); |
| return TSI_FAILED_PRECONDITION; |
| } |
| const grpc_gcp_Identity* local_identity = |
| grpc_gcp_HandshakerResult_local_identity(hresult); |
| if (local_identity == nullptr) { |
| gpr_log(GPR_ERROR, "Invalid local identity"); |
| return TSI_FAILED_PRECONDITION; |
| } |
| upb_strview local_service_account = |
| grpc_gcp_Identity_service_account(local_identity); |
| // We don't check if local service account is empty here |
| // because local identity could be empty in certain situations. |
| alts_tsi_handshaker_result* result = |
| static_cast<alts_tsi_handshaker_result*>(gpr_zalloc(sizeof(*result))); |
| result->key_data = |
| static_cast<char*>(gpr_zalloc(kAltsAes128GcmRekeyKeyLength)); |
| memcpy(result->key_data, key_data.data, kAltsAes128GcmRekeyKeyLength); |
| result->peer_identity = |
| static_cast<char*>(gpr_zalloc(peer_service_account.size + 1)); |
| memcpy(result->peer_identity, peer_service_account.data, |
| peer_service_account.size); |
| result->max_frame_size = grpc_gcp_HandshakerResult_max_frame_size(hresult); |
| upb::Arena rpc_versions_arena; |
| bool serialized = grpc_gcp_rpc_protocol_versions_encode( |
| peer_rpc_version, rpc_versions_arena.ptr(), &result->rpc_versions); |
| if (!serialized) { |
| gpr_log(GPR_ERROR, "Failed to serialize peer's RPC protocol versions."); |
| return TSI_FAILED_PRECONDITION; |
| } |
| upb::Arena context_arena; |
| grpc_gcp_AltsContext* context = grpc_gcp_AltsContext_new(context_arena.ptr()); |
| grpc_gcp_AltsContext_set_application_protocol(context, application_protocol); |
| grpc_gcp_AltsContext_set_record_protocol(context, record_protocol); |
| // ALTS currently only supports the security level of 2, |
| // which is "grpc_gcp_INTEGRITY_AND_PRIVACY". |
| grpc_gcp_AltsContext_set_security_level(context, 2); |
| grpc_gcp_AltsContext_set_peer_service_account(context, peer_service_account); |
| grpc_gcp_AltsContext_set_local_service_account(context, |
| local_service_account); |
| grpc_gcp_AltsContext_set_peer_rpc_versions( |
| context, const_cast<grpc_gcp_RpcProtocolVersions*>(peer_rpc_version)); |
| grpc_gcp_Identity* peer_identity = const_cast<grpc_gcp_Identity*>(identity); |
| if (peer_identity == nullptr) { |
| gpr_log(GPR_ERROR, "Null peer identity in ALTS context."); |
| return TSI_FAILED_PRECONDITION; |
| } |
| if (grpc_gcp_Identity_has_attributes(identity)) { |
| size_t iter = UPB_MAP_BEGIN; |
| grpc_gcp_Identity_AttributesEntry* peer_attributes_entry = |
| grpc_gcp_Identity_attributes_nextmutable(peer_identity, &iter); |
| while (peer_attributes_entry != nullptr) { |
| upb_strview key = grpc_gcp_Identity_AttributesEntry_key( |
| const_cast<grpc_gcp_Identity_AttributesEntry*>( |
| peer_attributes_entry)); |
| upb_strview val = grpc_gcp_Identity_AttributesEntry_value( |
| const_cast<grpc_gcp_Identity_AttributesEntry*>( |
| peer_attributes_entry)); |
| grpc_gcp_AltsContext_peer_attributes_set(context, key, val, |
| context_arena.ptr()); |
| peer_attributes_entry = |
| grpc_gcp_Identity_attributes_nextmutable(peer_identity, &iter); |
| } |
| } |
| size_t serialized_ctx_length; |
| char* serialized_ctx = grpc_gcp_AltsContext_serialize( |
| context, context_arena.ptr(), &serialized_ctx_length); |
| if (serialized_ctx == nullptr) { |
| gpr_log(GPR_ERROR, "Failed to serialize peer's ALTS context."); |
| return TSI_FAILED_PRECONDITION; |
| } |
| result->serialized_context = |
| grpc_slice_from_copied_buffer(serialized_ctx, serialized_ctx_length); |
| result->is_client = is_client; |
| result->base.vtable = &result_vtable; |
| *self = &result->base; |
| return TSI_OK; |
| } |
| |
| /* gRPC provided callback used when gRPC thread model is applied. */ |
| static void on_handshaker_service_resp_recv(void* arg, grpc_error* error) { |
| alts_handshaker_client* client = static_cast<alts_handshaker_client*>(arg); |
| if (client == nullptr) { |
| gpr_log(GPR_ERROR, "ALTS handshaker client is nullptr"); |
| return; |
| } |
| bool success = true; |
| if (error != GRPC_ERROR_NONE) { |
| gpr_log(GPR_ERROR, |
| "ALTS handshaker on_handshaker_service_resp_recv error: %s", |
| grpc_error_string(error)); |
| success = false; |
| } |
| alts_handshaker_client_handle_response(client, success); |
| } |
| |
| /* gRPC provided callback used when dedicatd CQ and thread are used. |
| * It serves to safely bring the control back to application. */ |
| static void on_handshaker_service_resp_recv_dedicated(void* arg, |
| grpc_error* /*error*/) { |
| alts_shared_resource_dedicated* resource = |
| grpc_alts_get_shared_resource_dedicated(); |
| grpc_cq_end_op(resource->cq, arg, GRPC_ERROR_NONE, |
| [](void* /*done_arg*/, grpc_cq_completion* /*storage*/) {}, |
| nullptr, &resource->storage); |
| } |
| |
| /* Returns TSI_OK if and only if no error is encountered. */ |
| static tsi_result alts_tsi_handshaker_continue_handshaker_next( |
| alts_tsi_handshaker* handshaker, const unsigned char* received_bytes, |
| size_t received_bytes_size, tsi_handshaker_on_next_done_cb cb, |
| void* user_data) { |
| if (!handshaker->has_created_handshaker_client) { |
| if (handshaker->channel == nullptr) { |
| grpc_alts_shared_resource_dedicated_start( |
| handshaker->handshaker_service_url); |
| handshaker->interested_parties = |
| grpc_alts_get_shared_resource_dedicated()->interested_parties; |
| GPR_ASSERT(handshaker->interested_parties != nullptr); |
| } |
| grpc_iomgr_cb_func grpc_cb = handshaker->channel == nullptr |
| ? on_handshaker_service_resp_recv_dedicated |
| : on_handshaker_service_resp_recv; |
| grpc_channel* channel = |
| handshaker->channel == nullptr |
| ? grpc_alts_get_shared_resource_dedicated()->channel |
| : handshaker->channel; |
| alts_handshaker_client* client = alts_grpc_handshaker_client_create( |
| handshaker, channel, handshaker->handshaker_service_url, |
| handshaker->interested_parties, handshaker->options, |
| handshaker->target_name, grpc_cb, cb, user_data, |
| handshaker->client_vtable_for_testing, handshaker->is_client, |
| handshaker->max_frame_size); |
| if (client == nullptr) { |
| gpr_log(GPR_ERROR, "Failed to create ALTS handshaker client"); |
| return TSI_FAILED_PRECONDITION; |
| } |
| { |
| grpc_core::MutexLock lock(&handshaker->mu); |
| GPR_ASSERT(handshaker->client == nullptr); |
| handshaker->client = client; |
| if (handshaker->shutdown) { |
| gpr_log(GPR_ERROR, "TSI handshake shutdown"); |
| return TSI_HANDSHAKE_SHUTDOWN; |
| } |
| } |
| handshaker->has_created_handshaker_client = true; |
| } |
| if (handshaker->channel == nullptr && |
| handshaker->client_vtable_for_testing == nullptr) { |
| GPR_ASSERT(grpc_cq_begin_op(grpc_alts_get_shared_resource_dedicated()->cq, |
| handshaker->client)); |
| } |
| grpc_slice slice = (received_bytes == nullptr || received_bytes_size == 0) |
| ? grpc_empty_slice() |
| : grpc_slice_from_copied_buffer( |
| reinterpret_cast<const char*>(received_bytes), |
| received_bytes_size); |
| tsi_result ok = TSI_OK; |
| if (!handshaker->has_sent_start_message) { |
| handshaker->has_sent_start_message = true; |
| ok = handshaker->is_client |
| ? alts_handshaker_client_start_client(handshaker->client) |
| : alts_handshaker_client_start_server(handshaker->client, &slice); |
| // It's unsafe for the current thread to access any state in handshaker |
| // at this point, since alts_handshaker_client_start_client/server |
| // have potentially just started an op batch on the handshake call. |
| // The completion callback for that batch is unsynchronized and so |
| // can invoke the TSI next API callback from any thread, at which point |
| // there is nothing taking ownership of this handshaker to prevent it |
| // from being destroyed. |
| } else { |
| ok = alts_handshaker_client_next(handshaker->client, &slice); |
| } |
| grpc_slice_unref_internal(slice); |
| return ok; |
| } |
| |
| struct alts_tsi_handshaker_continue_handshaker_next_args { |
| alts_tsi_handshaker* handshaker; |
| std::unique_ptr<unsigned char> received_bytes; |
| size_t received_bytes_size; |
| tsi_handshaker_on_next_done_cb cb; |
| void* user_data; |
| grpc_closure closure; |
| }; |
| |
| static void alts_tsi_handshaker_create_channel(void* arg, |
| grpc_error* /* unused_error */) { |
| alts_tsi_handshaker_continue_handshaker_next_args* next_args = |
| static_cast<alts_tsi_handshaker_continue_handshaker_next_args*>(arg); |
| alts_tsi_handshaker* handshaker = next_args->handshaker; |
| GPR_ASSERT(handshaker->channel == nullptr); |
| handshaker->channel = grpc_insecure_channel_create( |
| next_args->handshaker->handshaker_service_url, nullptr, nullptr); |
| tsi_result continue_next_result = |
| alts_tsi_handshaker_continue_handshaker_next( |
| handshaker, next_args->received_bytes.get(), |
| next_args->received_bytes_size, next_args->cb, next_args->user_data); |
| if (continue_next_result != TSI_OK) { |
| next_args->cb(continue_next_result, next_args->user_data, nullptr, 0, |
| nullptr); |
| } |
| delete next_args; |
| } |
| |
| static tsi_result handshaker_next( |
| tsi_handshaker* self, const unsigned char* received_bytes, |
| size_t received_bytes_size, const unsigned char** /*bytes_to_send*/, |
| size_t* /*bytes_to_send_size*/, tsi_handshaker_result** /*result*/, |
| tsi_handshaker_on_next_done_cb cb, void* user_data) { |
| if (self == nullptr || cb == nullptr) { |
| gpr_log(GPR_ERROR, "Invalid arguments to handshaker_next()"); |
| return TSI_INVALID_ARGUMENT; |
| } |
| alts_tsi_handshaker* handshaker = |
| reinterpret_cast<alts_tsi_handshaker*>(self); |
| { |
| grpc_core::MutexLock lock(&handshaker->mu); |
| if (handshaker->shutdown) { |
| gpr_log(GPR_ERROR, "TSI handshake shutdown"); |
| return TSI_HANDSHAKE_SHUTDOWN; |
| } |
| } |
| if (handshaker->channel == nullptr && !handshaker->use_dedicated_cq) { |
| alts_tsi_handshaker_continue_handshaker_next_args* args = |
| new alts_tsi_handshaker_continue_handshaker_next_args(); |
| args->handshaker = handshaker; |
| args->received_bytes = nullptr; |
| args->received_bytes_size = received_bytes_size; |
| if (received_bytes_size > 0) { |
| args->received_bytes = std::unique_ptr<unsigned char>( |
| static_cast<unsigned char*>(gpr_zalloc(received_bytes_size))); |
| memcpy(args->received_bytes.get(), received_bytes, received_bytes_size); |
| } |
| args->cb = cb; |
| args->user_data = user_data; |
| GRPC_CLOSURE_INIT(&args->closure, alts_tsi_handshaker_create_channel, args, |
| grpc_schedule_on_exec_ctx); |
| // We continue this handshaker_next call at the bottom of the ExecCtx just |
| // so that we can invoke grpc_channel_create at the bottom of the call |
| // stack. Doing so avoids potential lock cycles between g_init_mu and other |
| // mutexes within core that might be held on the current call stack |
| // (note that g_init_mu gets acquired during channel creation). |
| grpc_core::ExecCtx::Run(DEBUG_LOCATION, &args->closure, GRPC_ERROR_NONE); |
| } else { |
| tsi_result ok = alts_tsi_handshaker_continue_handshaker_next( |
| handshaker, received_bytes, received_bytes_size, cb, user_data); |
| if (ok != TSI_OK) { |
| gpr_log(GPR_ERROR, "Failed to schedule ALTS handshaker requests"); |
| return ok; |
| } |
| } |
| return TSI_ASYNC; |
| } |
| |
| /* |
| * This API will be invoked by a non-gRPC application, and an ExecCtx needs |
| * to be explicitly created in order to invoke ALTS handshaker client API's |
| * that assumes the caller is inside gRPC core. |
| */ |
| static tsi_result handshaker_next_dedicated( |
| tsi_handshaker* self, const unsigned char* received_bytes, |
| size_t received_bytes_size, const unsigned char** bytes_to_send, |
| size_t* bytes_to_send_size, tsi_handshaker_result** result, |
| tsi_handshaker_on_next_done_cb cb, void* user_data) { |
| grpc_core::ExecCtx exec_ctx; |
| return handshaker_next(self, received_bytes, received_bytes_size, |
| bytes_to_send, bytes_to_send_size, result, cb, |
| user_data); |
| } |
| |
| static void handshaker_shutdown(tsi_handshaker* self) { |
| GPR_ASSERT(self != nullptr); |
| alts_tsi_handshaker* handshaker = |
| reinterpret_cast<alts_tsi_handshaker*>(self); |
| grpc_core::MutexLock lock(&handshaker->mu); |
| if (handshaker->shutdown) { |
| return; |
| } |
| if (handshaker->client != nullptr) { |
| alts_handshaker_client_shutdown(handshaker->client); |
| } |
| handshaker->shutdown = true; |
| } |
| |
| static void handshaker_destroy(tsi_handshaker* self) { |
| if (self == nullptr) { |
| return; |
| } |
| alts_tsi_handshaker* handshaker = |
| reinterpret_cast<alts_tsi_handshaker*>(self); |
| alts_handshaker_client_destroy(handshaker->client); |
| grpc_slice_unref_internal(handshaker->target_name); |
| grpc_alts_credentials_options_destroy(handshaker->options); |
| if (handshaker->channel != nullptr) { |
| grpc_channel_destroy_internal(handshaker->channel); |
| } |
| gpr_free(handshaker->handshaker_service_url); |
| gpr_mu_destroy(&handshaker->mu); |
| gpr_free(handshaker); |
| } |
| |
| static const tsi_handshaker_vtable handshaker_vtable = { |
| nullptr, nullptr, |
| nullptr, nullptr, |
| nullptr, handshaker_destroy, |
| handshaker_next, handshaker_shutdown}; |
| |
| static const tsi_handshaker_vtable handshaker_vtable_dedicated = { |
| nullptr, |
| nullptr, |
| nullptr, |
| nullptr, |
| nullptr, |
| handshaker_destroy, |
| handshaker_next_dedicated, |
| handshaker_shutdown}; |
| |
| bool alts_tsi_handshaker_has_shutdown(alts_tsi_handshaker* handshaker) { |
| GPR_ASSERT(handshaker != nullptr); |
| grpc_core::MutexLock lock(&handshaker->mu); |
| return handshaker->shutdown; |
| } |
| |
| tsi_result alts_tsi_handshaker_create( |
| const grpc_alts_credentials_options* options, const char* target_name, |
| const char* handshaker_service_url, bool is_client, |
| grpc_pollset_set* interested_parties, tsi_handshaker** self, |
| size_t user_specified_max_frame_size) { |
| if (handshaker_service_url == nullptr || self == nullptr || |
| options == nullptr || (is_client && target_name == nullptr)) { |
| gpr_log(GPR_ERROR, "Invalid arguments to alts_tsi_handshaker_create()"); |
| return TSI_INVALID_ARGUMENT; |
| } |
| alts_tsi_handshaker* handshaker = |
| static_cast<alts_tsi_handshaker*>(gpr_zalloc(sizeof(*handshaker))); |
| gpr_mu_init(&handshaker->mu); |
| handshaker->use_dedicated_cq = interested_parties == nullptr; |
| handshaker->client = nullptr; |
| handshaker->is_client = is_client; |
| handshaker->has_sent_start_message = false; |
| handshaker->target_name = target_name == nullptr |
| ? grpc_empty_slice() |
| : grpc_slice_from_static_string(target_name); |
| handshaker->interested_parties = interested_parties; |
| handshaker->has_created_handshaker_client = false; |
| handshaker->handshaker_service_url = gpr_strdup(handshaker_service_url); |
| handshaker->options = grpc_alts_credentials_options_copy(options); |
| handshaker->max_frame_size = user_specified_max_frame_size != 0 |
| ? user_specified_max_frame_size |
| : kTsiAltsMaxFrameSize; |
| handshaker->base.vtable = handshaker->use_dedicated_cq |
| ? &handshaker_vtable_dedicated |
| : &handshaker_vtable; |
| *self = &handshaker->base; |
| return TSI_OK; |
| } |
| |
| void alts_tsi_handshaker_result_set_unused_bytes(tsi_handshaker_result* self, |
| grpc_slice* recv_bytes, |
| size_t bytes_consumed) { |
| GPR_ASSERT(recv_bytes != nullptr && self != nullptr); |
| if (GRPC_SLICE_LENGTH(*recv_bytes) == bytes_consumed) { |
| return; |
| } |
| alts_tsi_handshaker_result* result = |
| reinterpret_cast<alts_tsi_handshaker_result*>(self); |
| result->unused_bytes_size = GRPC_SLICE_LENGTH(*recv_bytes) - bytes_consumed; |
| result->unused_bytes = |
| static_cast<unsigned char*>(gpr_zalloc(result->unused_bytes_size)); |
| memcpy(result->unused_bytes, |
| GRPC_SLICE_START_PTR(*recv_bytes) + bytes_consumed, |
| result->unused_bytes_size); |
| } |
| |
| namespace grpc_core { |
| namespace internal { |
| |
| bool alts_tsi_handshaker_get_has_sent_start_message_for_testing( |
| alts_tsi_handshaker* handshaker) { |
| GPR_ASSERT(handshaker != nullptr); |
| return handshaker->has_sent_start_message; |
| } |
| |
| void alts_tsi_handshaker_set_client_vtable_for_testing( |
| alts_tsi_handshaker* handshaker, alts_handshaker_client_vtable* vtable) { |
| GPR_ASSERT(handshaker != nullptr); |
| handshaker->client_vtable_for_testing = vtable; |
| } |
| |
| bool alts_tsi_handshaker_get_is_client_for_testing( |
| alts_tsi_handshaker* handshaker) { |
| GPR_ASSERT(handshaker != nullptr); |
| return handshaker->is_client; |
| } |
| |
| alts_handshaker_client* alts_tsi_handshaker_get_client_for_testing( |
| alts_tsi_handshaker* handshaker) { |
| return handshaker->client; |
| } |
| |
| } // namespace internal |
| } // namespace grpc_core |