Skip to content

Commit 255bc5c

Browse files
ctillercopybara-github
authored andcommittedAug 6, 2024
[security-handshaker] Simplify refcounting (grpc#37345)
Make the refcounting in the class a little less manual Closes grpc#37345 COPYBARA_INTEGRATE_REVIEW=grpc#37345 from ctiller:things-that-make-you-go-hmmm 67f23b6 PiperOrigin-RevId: 660022927
1 parent 3de09c5 commit 255bc5c

File tree

1 file changed

+60
-71
lines changed

1 file changed

+60
-71
lines changed
 

‎src/core/handshaker/security/security_handshaker.cc

+60-71
Original file line numberDiff line numberDiff line change
@@ -88,27 +88,27 @@ class SecurityHandshaker : public Handshaker {
8888

8989
private:
9090
grpc_error_handle DoHandshakerNextLocked(const unsigned char* bytes_received,
91-
size_t bytes_received_size);
91+
size_t bytes_received_size)
92+
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
9293

9394
grpc_error_handle OnHandshakeNextDoneLocked(
9495
tsi_result result, const unsigned char* bytes_to_send,
95-
size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result);
96-
void HandshakeFailedLocked(absl::Status error);
96+
size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result)
97+
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
98+
void HandshakeFailedLocked(absl::Status error)
99+
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
97100
void Finish(absl::Status status);
98101

99102
void OnHandshakeDataReceivedFromPeerFn(absl::Status error);
100103
void OnHandshakeDataSentToPeerFn(absl::Status error);
101-
static void OnHandshakeDataReceivedFromPeerFnScheduler(
102-
void* arg, grpc_error_handle error);
103-
static void OnHandshakeDataSentToPeerFnScheduler(void* arg,
104-
grpc_error_handle error);
104+
void OnHandshakeDataReceivedFromPeerFnScheduler(grpc_error_handle error);
105+
void OnHandshakeDataSentToPeerFnScheduler(grpc_error_handle error);
105106
static void OnHandshakeNextDoneGrpcWrapper(
106107
tsi_result result, void* user_data, const unsigned char* bytes_to_send,
107108
size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result);
108-
static void OnPeerCheckedFn(void* arg, grpc_error_handle error);
109-
void OnPeerCheckedInner(grpc_error_handle error);
109+
void OnPeerCheckedFn(grpc_error_handle error);
110110
size_t MoveReadBufferIntoHandshakeBuffer();
111-
grpc_error_handle CheckPeerLocked();
111+
grpc_error_handle CheckPeerLocked() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
112112

113113
// State set at creation time.
114114
tsi_handshaker* handshaker_;
@@ -125,13 +125,11 @@ class SecurityHandshaker : public Handshaker {
125125
size_t handshake_buffer_size_;
126126
unsigned char* handshake_buffer_;
127127
SliceBuffer outgoing_;
128-
grpc_closure on_handshake_data_sent_to_peer_;
129-
grpc_closure on_handshake_data_received_from_peer_;
130-
grpc_closure on_peer_checked_;
131128
RefCountedPtr<grpc_auth_context> auth_context_;
132129
tsi_handshaker_result* handshaker_result_ = nullptr;
133130
size_t max_frame_size_ = 0;
134131
std::string tsi_handshake_error_;
132+
grpc_closure* on_peer_checked_ ABSL_GUARDED_BY(mu_) = nullptr;
135133
};
136134

137135
SecurityHandshaker::SecurityHandshaker(tsi_handshaker* handshaker,
@@ -143,10 +141,7 @@ SecurityHandshaker::SecurityHandshaker(tsi_handshaker* handshaker,
143141
handshake_buffer_(
144142
static_cast<uint8_t*>(gpr_malloc(handshake_buffer_size_))),
145143
max_frame_size_(
146-
std::max(0, args.GetInt(GRPC_ARG_TSI_MAX_FRAME_SIZE).value_or(0))) {
147-
GRPC_CLOSURE_INIT(&on_peer_checked_, &SecurityHandshaker::OnPeerCheckedFn,
148-
this, grpc_schedule_on_exec_ctx);
149-
}
144+
std::max(0, args.GetInt(GRPC_ARG_TSI_MAX_FRAME_SIZE).value_or(0))) {}
150145

151146
SecurityHandshaker::~SecurityHandshaker() {
152147
tsi_handshaker_destroy(handshaker_);
@@ -220,8 +215,9 @@ MakeChannelzSecurityFromAuthContext(grpc_auth_context* auth_context) {
220215

221216
} // namespace
222217

223-
void SecurityHandshaker::OnPeerCheckedInner(grpc_error_handle error) {
218+
void SecurityHandshaker::OnPeerCheckedFn(grpc_error_handle error) {
224219
MutexLock lock(&mu_);
220+
on_peer_checked_ = nullptr;
225221
if (!error.ok() || is_shutdown_) {
226222
HandshakeFailedLocked(error);
227223
return;
@@ -317,11 +313,6 @@ void SecurityHandshaker::OnPeerCheckedInner(grpc_error_handle error) {
317313
Finish(absl::OkStatus());
318314
}
319315

320-
void SecurityHandshaker::OnPeerCheckedFn(void* arg, grpc_error_handle error) {
321-
RefCountedPtr<SecurityHandshaker>(static_cast<SecurityHandshaker*>(arg))
322-
->OnPeerCheckedInner(error);
323-
}
324-
325316
grpc_error_handle SecurityHandshaker::CheckPeerLocked() {
326317
tsi_peer peer;
327318
tsi_result result =
@@ -330,8 +321,12 @@ grpc_error_handle SecurityHandshaker::CheckPeerLocked() {
330321
return GRPC_ERROR_CREATE(absl::StrCat("Peer extraction failed (",
331322
tsi_result_to_string(result), ")"));
332323
}
324+
on_peer_checked_ = NewClosure(
325+
[self = RefAsSubclass<SecurityHandshaker>()](absl::Status status) {
326+
self->OnPeerCheckedFn(std::move(status));
327+
});
333328
connector_->check_peer(peer, args_->endpoint.get(), args_->args,
334-
&auth_context_, &on_peer_checked_);
329+
&auth_context_, on_peer_checked_);
335330
grpc_auth_property_iterator it = grpc_auth_context_find_properties_by_name(
336331
auth_context_.get(), GRPC_TRANSPORT_SECURITY_LEVEL_PROPERTY_NAME);
337332
const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it);
@@ -356,10 +351,10 @@ grpc_error_handle SecurityHandshaker::OnHandshakeNextDoneLocked(
356351
CHECK_EQ(bytes_to_send_size, 0u);
357352
grpc_endpoint_read(
358353
args_->endpoint.get(), args_->read_buffer.c_slice_buffer(),
359-
GRPC_CLOSURE_INIT(
360-
&on_handshake_data_received_from_peer_,
361-
&SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler,
362-
this, grpc_schedule_on_exec_ctx),
354+
NewClosure([self = RefAsSubclass<SecurityHandshaker>()](
355+
absl::Status status) {
356+
self->OnHandshakeDataReceivedFromPeerFnScheduler(std::move(status));
357+
}),
363358
/*urgent=*/true, /*min_progress_size=*/1);
364359
return error;
365360
}
@@ -387,19 +382,19 @@ grpc_error_handle SecurityHandshaker::OnHandshakeNextDoneLocked(
387382
reinterpret_cast<const char*>(bytes_to_send), bytes_to_send_size));
388383
grpc_endpoint_write(
389384
args_->endpoint.get(), outgoing_.c_slice_buffer(),
390-
GRPC_CLOSURE_INIT(
391-
&on_handshake_data_sent_to_peer_,
392-
&SecurityHandshaker::OnHandshakeDataSentToPeerFnScheduler, this,
393-
grpc_schedule_on_exec_ctx),
385+
NewClosure(
386+
[self = RefAsSubclass<SecurityHandshaker>()](absl::Status status) {
387+
self->OnHandshakeDataSentToPeerFnScheduler(std::move(status));
388+
}),
394389
nullptr, /*max_frame_size=*/INT_MAX);
395390
} else if (handshaker_result == nullptr) {
396391
// There is nothing to send, but need to read from peer.
397392
grpc_endpoint_read(
398393
args_->endpoint.get(), args_->read_buffer.c_slice_buffer(),
399-
GRPC_CLOSURE_INIT(
400-
&on_handshake_data_received_from_peer_,
401-
&SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler,
402-
this, grpc_schedule_on_exec_ctx),
394+
NewClosure([self = RefAsSubclass<SecurityHandshaker>()](
395+
absl::Status status) {
396+
self->OnHandshakeDataReceivedFromPeerFnScheduler(std::move(status));
397+
}),
403398
/*urgent=*/true, /*min_progress_size=*/1);
404399
} else {
405400
// Handshake has finished, check peer and so on.
@@ -418,8 +413,6 @@ void SecurityHandshaker::OnHandshakeNextDoneGrpcWrapper(
418413
result, bytes_to_send, bytes_to_send_size, handshaker_result);
419414
if (!error.ok()) {
420415
h->HandshakeFailedLocked(std::move(error));
421-
} else {
422-
h.release(); // Avoid unref
423416
}
424417
}
425418

@@ -429,13 +422,15 @@ grpc_error_handle SecurityHandshaker::DoHandshakerNextLocked(
429422
const unsigned char* bytes_to_send = nullptr;
430423
size_t bytes_to_send_size = 0;
431424
tsi_handshaker_result* hs_result = nullptr;
425+
auto self = RefAsSubclass<SecurityHandshaker>();
432426
tsi_result result = tsi_handshaker_next(
433427
handshaker_, bytes_received, bytes_received_size, &bytes_to_send,
434-
&bytes_to_send_size, &hs_result, &OnHandshakeNextDoneGrpcWrapper, this,
435-
&tsi_handshake_error_);
428+
&bytes_to_send_size, &hs_result, &OnHandshakeNextDoneGrpcWrapper,
429+
self.get(), &tsi_handshake_error_);
436430
if (result == TSI_ASYNC) {
437-
// Handshaker operating asynchronously. Nothing else to do here;
438-
// callback will be invoked in a TSI thread.
431+
// Handshaker operating asynchronously. Callback will be invoked in a TSI
432+
// thread. We no longer own the ref held in self.
433+
self.release();
439434
return absl::OkStatus();
440435
}
441436
// Handshaker returned synchronously. Invoke callback directly in
@@ -449,18 +444,18 @@ grpc_error_handle SecurityHandshaker::DoHandshakerNextLocked(
449444
// TODO(roth): This will no longer be necessary once we migrate to the
450445
// EventEngine endpoint API.
451446
void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler(
452-
void* arg, grpc_error_handle error) {
453-
SecurityHandshaker* handshaker = static_cast<SecurityHandshaker*>(arg);
454-
handshaker->args_->event_engine->Run(
455-
[handshaker, error = std::move(error)]() mutable {
456-
ApplicationCallbackExecCtx callback_exec_ctx;
457-
ExecCtx exec_ctx;
458-
handshaker->OnHandshakeDataReceivedFromPeerFn(std::move(error));
459-
});
447+
grpc_error_handle error) {
448+
args_->event_engine->Run([self = RefAsSubclass<SecurityHandshaker>(),
449+
error = std::move(error)]() mutable {
450+
ApplicationCallbackExecCtx callback_exec_ctx;
451+
ExecCtx exec_ctx;
452+
self->OnHandshakeDataReceivedFromPeerFn(std::move(error));
453+
// Avoid destruction outside of an ExecCtx (since this is non-cancelable).
454+
self.reset();
455+
});
460456
}
461457

462458
void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn(absl::Status error) {
463-
RefCountedPtr<SecurityHandshaker> handshaker(this);
464459
MutexLock lock(&mu_);
465460
if (!error.ok() || is_shutdown_) {
466461
HandshakeFailedLocked(
@@ -473,8 +468,6 @@ void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn(absl::Status error) {
473468
error = DoHandshakerNextLocked(handshake_buffer_, bytes_received_size);
474469
if (!error.ok()) {
475470
HandshakeFailedLocked(std::move(error));
476-
} else {
477-
handshaker.release(); // Avoid unref
478471
}
479472
}
480473

@@ -483,18 +476,18 @@ void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn(absl::Status error) {
483476
// TODO(roth): This will no longer be necessary once we migrate to the
484477
// EventEngine endpoint API.
485478
void SecurityHandshaker::OnHandshakeDataSentToPeerFnScheduler(
486-
void* arg, grpc_error_handle error) {
487-
SecurityHandshaker* handshaker = static_cast<SecurityHandshaker*>(arg);
488-
handshaker->args_->event_engine->Run(
489-
[handshaker, error = std::move(error)]() mutable {
490-
ApplicationCallbackExecCtx callback_exec_ctx;
491-
ExecCtx exec_ctx;
492-
handshaker->OnHandshakeDataSentToPeerFn(std::move(error));
493-
});
479+
grpc_error_handle error) {
480+
args_->event_engine->Run([self = RefAsSubclass<SecurityHandshaker>(),
481+
error = std::move(error)]() mutable {
482+
ApplicationCallbackExecCtx callback_exec_ctx;
483+
ExecCtx exec_ctx;
484+
self->OnHandshakeDataSentToPeerFn(std::move(error));
485+
// Avoid destruction outside of an ExecCtx (since this is non-cancelable).
486+
self.reset();
487+
});
494488
}
495489

496490
void SecurityHandshaker::OnHandshakeDataSentToPeerFn(absl::Status error) {
497-
RefCountedPtr<SecurityHandshaker> handshaker(this);
498491
MutexLock lock(&mu_);
499492
if (!error.ok() || is_shutdown_) {
500493
HandshakeFailedLocked(
@@ -505,10 +498,10 @@ void SecurityHandshaker::OnHandshakeDataSentToPeerFn(absl::Status error) {
505498
if (handshaker_result_ == nullptr) {
506499
grpc_endpoint_read(
507500
args_->endpoint.get(), args_->read_buffer.c_slice_buffer(),
508-
GRPC_CLOSURE_INIT(
509-
&on_handshake_data_received_from_peer_,
510-
&SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler,
511-
this, grpc_schedule_on_exec_ctx),
501+
NewClosure([self = RefAsSubclass<SecurityHandshaker>()](
502+
absl::Status status) {
503+
self->OnHandshakeDataReceivedFromPeerFnScheduler(std::move(status));
504+
}),
512505
/*urgent=*/true, /*min_progress_size=*/1);
513506
} else {
514507
error = CheckPeerLocked();
@@ -517,7 +510,6 @@ void SecurityHandshaker::OnHandshakeDataSentToPeerFn(absl::Status error) {
517510
return;
518511
}
519512
}
520-
handshaker.release(); // Avoid unref
521513
}
522514

523515
//
@@ -528,7 +520,7 @@ void SecurityHandshaker::Shutdown(grpc_error_handle error) {
528520
MutexLock lock(&mu_);
529521
if (!is_shutdown_) {
530522
is_shutdown_ = true;
531-
connector_->cancel_check_peer(&on_peer_checked_, std::move(error));
523+
connector_->cancel_check_peer(on_peer_checked_, std::move(error));
532524
tsi_handshaker_shutdown(handshaker_);
533525
args_->endpoint.reset();
534526
}
@@ -537,7 +529,6 @@ void SecurityHandshaker::Shutdown(grpc_error_handle error) {
537529
void SecurityHandshaker::DoHandshake(
538530
HandshakerArgs* args,
539531
absl::AnyInvocable<void(absl::Status)> on_handshake_done) {
540-
auto ref = Ref();
541532
MutexLock lock(&mu_);
542533
args_ = args;
543534
on_handshake_done_ = std::move(on_handshake_done);
@@ -546,8 +537,6 @@ void SecurityHandshaker::DoHandshake(
546537
DoHandshakerNextLocked(handshake_buffer_, bytes_received_size);
547538
if (!error.ok()) {
548539
HandshakeFailedLocked(error);
549-
} else {
550-
ref.release(); // Avoid unref
551540
}
552541
}
553542

0 commit comments

Comments
 (0)
Please sign in to comment.