rustls: implement connect_blocking
Closes #11647
diff --git a/lib/vtls/rustls.c b/lib/vtls/rustls.c
index 70b200a..8751fd9 100644
--- a/lib/vtls/rustls.c
+++ b/lib/vtls/rustls.c
@@ -39,6 +39,7 @@
#include "select.h"
#include "strerror.h"
#include "multiif.h"
+#include "connect.h" /* for the connect timeout */
struct rustls_ssl_backend_data
{
@@ -75,14 +76,6 @@
return backend->data_pending;
}
-static CURLcode
-cr_connect(struct Curl_cfilter *cf UNUSED_PARAM,
- struct Curl_easy *data UNUSED_PARAM)
-{
- infof(data, "rustls_connect: unimplemented");
- return CURLE_SSL_CONNECT_ERROR;
-}
-
struct io_ctx {
struct Curl_cfilter *cf;
struct Curl_easy *data;
@@ -485,9 +478,20 @@
Curl_alpn_set_negotiated(cf, data, protocol, len);
}
+/* Given an established network connection, do a TLS handshake.
+ *
+ * If `blocking` is true, this function will block until the handshake is
+ * complete. Otherwise it will return as soon as I/O would block.
+ *
+ * For the non-blocking I/O case, this function will set `*done` to true
+ * once the handshake is complete. This function never reads the value of
+ * `*done*`.
+ */
static CURLcode
-cr_connect_nonblocking(struct Curl_cfilter *cf,
- struct Curl_easy *data, bool *done)
+cr_connect_common(struct Curl_cfilter *cf,
+ struct Curl_easy *data,
+ bool blocking,
+ bool *done)
{
struct ssl_connect_data *const connssl = cf->ctx;
curl_socket_t sockfd = Curl_conn_cf_get_socket(cf, data);
@@ -501,6 +505,8 @@
bool wants_write;
curl_socket_t writefd;
curl_socket_t readfd;
+ timediff_t timeout_ms;
+ timediff_t socket_check_timeout;
DEBUGASSERT(backend);
@@ -538,12 +544,29 @@
writefd = wants_write?sockfd:CURL_SOCKET_BAD;
readfd = wants_read?sockfd:CURL_SOCKET_BAD;
- what = Curl_socket_check(readfd, CURL_SOCKET_BAD, writefd, 0);
+ /* check allowed time left */
+ timeout_ms = Curl_timeleft(data, NULL, TRUE);
+
+ if(timeout_ms < 0) {
+ /* no need to continue if time already is up */
+ failf(data, "rustls: operation timed out before socket check");
+ return CURLE_OPERATION_TIMEDOUT;
+ }
+
+ socket_check_timeout = blocking?timeout_ms:0;
+
+ what = Curl_socket_check(
+ readfd, CURL_SOCKET_BAD, writefd, socket_check_timeout);
if(what < 0) {
/* fatal error */
failf(data, "select/poll on SSL socket, errno: %d", SOCKERRNO);
return CURLE_SSL_CONNECT_ERROR;
}
+ if(blocking && 0 == what) {
+ failf(data, "rustls connection timeout after %d ms",
+ socket_check_timeout);
+ return CURLE_OPERATION_TIMEDOUT;
+ }
if(0 == what) {
infof(data, "Curl_socket_check: %s would block",
wants_read&&wants_write ? "writing and reading" :
@@ -588,6 +611,21 @@
DEBUGASSERT(false);
}
+static CURLcode
+cr_connect_nonblocking(struct Curl_cfilter *cf,
+ struct Curl_easy *data, bool *done)
+{
+ return cr_connect_common(cf, data, false, done);
+}
+
+static CURLcode
+cr_connect_blocking(struct Curl_cfilter *cf UNUSED_PARAM,
+ struct Curl_easy *data UNUSED_PARAM)
+{
+ bool done; /* unused */
+ return cr_connect_common(cf, data, true, &done);
+}
+
static void cr_adjust_pollset(struct Curl_cfilter *cf,
struct Curl_easy *data,
struct easy_pollset *ps)
@@ -670,7 +708,7 @@
cr_data_pending, /* data_pending */
Curl_none_random, /* random */
Curl_none_cert_status_request, /* cert_status_request */
- cr_connect, /* connect */
+ cr_connect_blocking, /* connect */
cr_connect_nonblocking, /* connect_nonblocking */
cr_adjust_pollset, /* adjust_pollset */
cr_get_internals, /* get_internals */