ustream-ssl: add support for using a fd instead of ustream as backing
authorFelix Fietkau <nbd@nbd.name>
Fri, 19 Apr 2024 14:43:35 +0000 (16:43 +0200)
committerFelix Fietkau <nbd@nbd.name>
Fri, 19 Apr 2024 14:47:10 +0000 (16:47 +0200)
This improves performance by avoiding double buffering

Signed-off-by: Felix Fietkau <nbd@nbd.name>
ustream-internal.h
ustream-io-openssl.c
ustream-io-wolfssl.c
ustream-mbedtls.c
ustream-mbedtls.h
ustream-openssl.c
ustream-openssl.h
ustream-ssl.c
ustream-ssl.h

index 50e105f0ddb6f90c5e6305ffb40f19c7fabb5720..4eec9cd054c6b70bb39dd3e06501beeface2ddf1 100644 (file)
@@ -34,7 +34,7 @@ enum ssl_conn_status {
        U_SSL_RETRY = -3,
 };
 
-void ustream_set_io(struct ustream_ssl_ctx *ctx, void *ssl, struct ustream *s);
+void ustream_set_io(struct ustream_ssl *us);
 struct ustream_ssl_ctx *__ustream_ssl_context_new(bool server);
 int __ustream_ssl_add_ca_crt_file(struct ustream_ssl_ctx *ctx, const char *file);
 int __ustream_ssl_set_crt_file(struct ustream_ssl_ctx *ctx, const char *file);
@@ -46,5 +46,6 @@ void __ustream_ssl_context_free(struct ustream_ssl_ctx *ctx);
 enum ssl_conn_status __ustream_ssl_connect(struct ustream_ssl *us);
 int __ustream_ssl_read(struct ustream_ssl *us, char *buf, int len);
 int __ustream_ssl_write(struct ustream_ssl *us, const char *buf, int len);
+void __ustream_ssl_session_free(struct ustream_ssl *us);
 
 #endif
index 7045bb660a36102f022b7c112084333e6918e0aa..4ca77deecafabdba3e193e87a75a169474108664 100644 (file)
@@ -137,8 +137,23 @@ static BIO *ustream_bio_new(struct ustream *s)
        return bio;
 }
 
-__hidden void ustream_set_io(struct ustream_ssl_ctx *ctx, void *ssl, struct ustream *conn)
+static BIO *fd_bio_new(int fd)
 {
-       BIO *bio = ustream_bio_new(conn);
-       SSL_set_bio(ssl, bio, bio);
+       BIO *bio = BIO_new(BIO_s_socket());
+
+       BIO_set_fd(bio, fd, BIO_NOCLOSE);
+
+       return bio;
+}
+
+__hidden void ustream_set_io(struct ustream_ssl *us)
+{
+       BIO *bio;
+
+       if (us->conn)
+               bio = ustream_bio_new(us->conn);
+       else
+               bio = fd_bio_new(us->fd.fd);
+
+       SSL_set_bio(us->ssl, bio, bio);
 }
index 4ff85d34e33331535f25645121d90b11078e20df..0a97edcb481eefb086fe7d370af9f024ee4eff82 100644 (file)
@@ -65,10 +65,15 @@ static int io_send_cb(SSL* ssl, char *buf, int sz, void *ctx)
        return s_ustream_write(buf, sz, ctx);
 }
 
-__hidden void ustream_set_io(struct ustream_ssl_ctx *ctx, void *ssl, struct ustream *conn)
+__hidden void ustream_set_io(struct ustream_ssl *us)
 {
-       wolfSSL_SSLSetIORecv(ssl, io_recv_cb);
-       wolfSSL_SSLSetIOSend(ssl, io_send_cb);
-       wolfSSL_SetIOReadCtx(ssl, conn);
-       wolfSSL_SetIOWriteCtx(ssl, conn);
+       if (!us->conn) {
+               wolfSSL_set_fd(us->ssl, us->fd.fd);
+               return;
+       }
+
+       wolfSSL_SSLSetIORecv(us->ssl, io_recv_cb);
+       wolfSSL_SSLSetIOSend(us->ssl, io_send_cb);
+       wolfSSL_SetIOReadCtx(us->ssl, us->conn);
+       wolfSSL_SetIOWriteCtx(us->ssl, us->conn);
 }
index 6b8e1c0fc0c541fb3e28e9406ffadc6ef4d40fad..361ff9939fa93696dc97e38b620406917506123e 100644 (file)
@@ -85,9 +85,32 @@ static int s_ustream_write(void *ctx, const unsigned char *buf, size_t len)
        return ret;
 }
 
-__hidden void ustream_set_io(struct ustream_ssl_ctx *ctx, void *ssl, struct ustream *conn)
+static int s_fd_read(void *ctx, unsigned char *buf, size_t len)
 {
-       mbedtls_ssl_set_bio(ssl, conn, s_ustream_write, s_ustream_read, NULL);
+       struct uloop_fd *ufd = ctx;
+       mbedtls_net_context net = {
+               .fd = ufd->fd
+       };
+
+       return mbedtls_net_recv(&net, buf, len);
+}
+
+static int s_fd_write(void *ctx, const unsigned char *buf, size_t len)
+{
+       struct uloop_fd *ufd = ctx;
+       mbedtls_net_context net = {
+               .fd = ufd->fd
+       };
+
+       return mbedtls_net_send(&net, buf, len);
+}
+
+__hidden void ustream_set_io(struct ustream_ssl *us)
+{
+       if (us->conn)
+               mbedtls_ssl_set_bio(us->ssl, us->conn, s_ustream_write, s_ustream_read, NULL);
+       else
+               mbedtls_ssl_set_bio(us->ssl, &us->fd, s_fd_write, s_fd_read, NULL);
 }
 
 static int _random(void *ctx, unsigned char *out, size_t len)
@@ -553,8 +576,8 @@ __hidden void *__ustream_ssl_session_new(struct ustream_ssl_ctx *ctx)
        return ssl;
 }
 
-__hidden void __ustream_ssl_session_free(void *ssl)
+__hidden void __ustream_ssl_session_free(struct ustream_ssl *us)
 {
-       mbedtls_ssl_free(ssl);
-       free(ssl);
+       mbedtls_ssl_free(us->ssl);
+       free(us->ssl);
 }
index 31df680d2e1d9915bf57b4de54675afda0619d3e..281b9195abc9e37034c234e5a24e64ef04ef553f 100644 (file)
@@ -64,7 +64,6 @@ static inline void __ustream_ssl_update_peer_cn(struct ustream_ssl *us)
        mbedtls_ssl_set_hostname(us->ssl, us->peer_cn);
 }
 
-void __ustream_ssl_session_free(void *ssl);
 void *__ustream_ssl_session_new(struct ustream_ssl_ctx *ctx);
 
 #endif
index 3d576be1cf9397adf812bc03f1c87fa4096f9678..b080081c172ffc14786ef61b8e5c1e197b4e3003 100644 (file)
@@ -245,13 +245,18 @@ __hidden void __ustream_ssl_context_free(struct ustream_ssl_ctx *ctx)
        free(ctx);
 }
 
-void __ustream_ssl_session_free(void *ssl)
+__hidden void __ustream_ssl_session_free(struct ustream_ssl *us)
 {
-       BIO *bio = SSL_get_wbio(ssl);
-       struct bio_ctx *ctx = BIO_get_data(bio);
+       BIO *bio = SSL_get_wbio(us->ssl);
+       struct bio_ctx *ctx;
 
-       SSL_shutdown(ssl);
-       SSL_free(ssl);
+       SSL_shutdown(us->ssl);
+       SSL_free(us->ssl);
+
+       if (!us->conn)
+               return;
+
+       ctx = BIO_get_data(bio);
        if (ctx) {
                BIO_meth_free(ctx->meth);
                free(ctx);
index f547aa663f98f6f93602d47b12ede05f0cc790a6..847f5aa64bff547e2bca19d04b35e7c544904c40 100644 (file)
@@ -36,8 +36,6 @@ struct ustream_ssl_ctx {
        void *debug_cb_priv;
 };
 
-void __ustream_ssl_session_free(void *ssl);
-
 struct bio_ctx {
        BIO_METHOD *meth;
        struct ustream *stream;
index d3048ca2ded69f4ced3a0ddd778b1e6e923fdd20..b07629931525fe285f94b8c28b50d95b04fed5b8 100644 (file)
@@ -67,9 +67,8 @@ static void ustream_ssl_check_conn(struct ustream_ssl *us)
        }
 }
 
-static bool __ustream_ssl_poll(struct ustream *s)
+static bool __ustream_ssl_poll(struct ustream_ssl *us)
 {
-       struct ustream_ssl *us = container_of(s->next, struct ustream_ssl, stream);
        char *buf;
        int len, ret;
        bool more = false;
@@ -85,7 +84,8 @@ static bool __ustream_ssl_poll(struct ustream *s)
 
                ret = __ustream_ssl_read(us, buf, len);
                if (ret == U_SSL_PENDING) {
-                       ustream_poll(us->conn);
+                       if (us->conn)
+                               ustream_poll(us->conn);
                        ret = __ustream_ssl_read(us, buf, len);
                }
 
@@ -110,7 +110,9 @@ static bool __ustream_ssl_poll(struct ustream *s)
 
 static void ustream_ssl_notify_read(struct ustream *s, int bytes)
 {
-       __ustream_ssl_poll(s);
+       struct ustream_ssl *us = container_of(s->next, struct ustream_ssl, stream);
+
+       __ustream_ssl_poll(us);
 }
 
 static void ustream_ssl_notify_write(struct ustream *s, int bytes)
@@ -134,7 +136,7 @@ static int ustream_ssl_write(struct ustream *s, const char *buf, int len, bool m
        if (!us->connected || us->error)
                return 0;
 
-       if (us->conn->w.data_bytes)
+       if (us->conn && us->conn->w.data_bytes)
                return 0;
 
        return __ustream_ssl_write(us, buf, len);
@@ -143,8 +145,17 @@ static int ustream_ssl_write(struct ustream *s, const char *buf, int len, bool m
 static void ustream_ssl_set_read_blocked(struct ustream *s)
 {
        struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
+       unsigned int ev = ULOOP_WRITE | ULOOP_EDGE_TRIGGER;
+
+       if (us->conn) {
+               ustream_set_read_blocked(us->conn, !!s->read_blocked);
+               return;
+       }
+
+       if (!s->read_blocked)
+               ev |= ULOOP_READ;
 
-       ustream_set_read_blocked(us->conn, !!s->read_blocked);
+       uloop_fd_add(&us->fd, ev);
 }
 
 static void ustream_ssl_free(struct ustream *s)
@@ -156,10 +167,12 @@ static void ustream_ssl_free(struct ustream *s)
                us->conn->notify_read = NULL;
                us->conn->notify_write = NULL;
                us->conn->notify_state = NULL;
+       } else {
+               uloop_fd_delete(&us->fd);
        }
 
        uloop_timeout_cancel(&us->error_timer);
-       __ustream_ssl_session_free(us->ssl);
+       __ustream_ssl_session_free(us);
        free(us->peer_cn);
 
        us->ctx = NULL;
@@ -175,10 +188,19 @@ static void ustream_ssl_free(struct ustream *s)
 static bool ustream_ssl_poll(struct ustream *s)
 {
        struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
-       bool fd_poll;
+       bool fd_poll = false;
+
+       if (us->conn)
+               fd_poll = ustream_poll(us->conn);
+
+       return __ustream_ssl_poll(us) || fd_poll;
+}
+
+static void ustream_ssl_fd_cb(struct uloop_fd *fd, unsigned int events)
+{
+       struct ustream_ssl *us = container_of(fd, struct ustream_ssl, fd);
 
-       fd_poll = ustream_poll(us->conn);
-       return __ustream_ssl_poll(us->conn) || fd_poll;
+       __ustream_ssl_poll(us);
 }
 
 static void ustream_ssl_stream_init(struct ustream_ssl *us)
@@ -186,31 +208,31 @@ static void ustream_ssl_stream_init(struct ustream_ssl *us)
        struct ustream *conn = us->conn;
        struct ustream *s = &us->stream;
 
-       conn->notify_read = ustream_ssl_notify_read;
-       conn->notify_write = ustream_ssl_notify_write;
-       conn->notify_state = ustream_ssl_notify_state;
+       if (conn) {
+               conn->notify_read = ustream_ssl_notify_read;
+               conn->notify_write = ustream_ssl_notify_write;
+               conn->notify_state = ustream_ssl_notify_state;
+       } else {
+               us->fd.cb = ustream_ssl_fd_cb;
+               uloop_fd_add(&us->fd, ULOOP_READ | ULOOP_WRITE | ULOOP_EDGE_TRIGGER);
+       }
 
+       s->set_read_blocked = ustream_ssl_set_read_blocked;
        s->free = ustream_ssl_free;
        s->write = ustream_ssl_write;
        s->poll = ustream_ssl_poll;
-       s->set_read_blocked = ustream_ssl_set_read_blocked;
        ustream_init_defaults(s);
 }
 
-static int _ustream_ssl_init(struct ustream_ssl *us, struct ustream *conn, struct ustream_ssl_ctx *ctx, bool server)
+static int _ustream_ssl_init_common(struct ustream_ssl *us)
 {
        us->error_timer.cb = ustream_ssl_error_cb;
-       us->server = server;
-       us->conn = conn;
-       us->ctx = ctx;
 
        us->ssl = __ustream_ssl_session_new(us->ctx);
        if (!us->ssl)
                return -ENOMEM;
 
-       conn->r.max_buffers = 4;
-       conn->next = &us->stream;
-       ustream_set_io(ctx, us->ssl, conn);
+       ustream_set_io(us);
        ustream_ssl_stream_init(us);
 
        if (us->server_name)
@@ -221,6 +243,27 @@ static int _ustream_ssl_init(struct ustream_ssl *us, struct ustream *conn, struc
        return 0;
 }
 
+static int _ustream_ssl_init_fd(struct ustream_ssl *us, int fd, struct ustream_ssl_ctx *ctx, bool server)
+{
+       us->server = server;
+       us->ctx = ctx;
+       us->fd.fd = fd;
+
+       return _ustream_ssl_init_common(us);
+}
+
+static int _ustream_ssl_init(struct ustream_ssl *us, struct ustream *conn, struct ustream_ssl_ctx *ctx, bool server)
+{
+       us->server = server;
+       us->ctx = ctx;
+
+       us->conn = conn;
+       conn->r.max_buffers = 4;
+       conn->next = &us->stream;
+
+       return _ustream_ssl_init_common(us);
+}
+
 static int _ustream_ssl_set_peer_cn(struct ustream_ssl *us, const char *name)
 {
        us->peer_cn = strdup(name);
@@ -239,5 +282,6 @@ const struct ustream_ssl_ops ustream_ssl_ops = {
        .context_set_debug = __ustream_ssl_set_debug,
        .context_free = __ustream_ssl_context_free,
        .init = _ustream_ssl_init,
+       .init_fd = _ustream_ssl_init_fd,
        .set_peer_cn = _ustream_ssl_set_peer_cn,
 };
index b1115c6451a8e1a6e5f6dcc958027d561786f9dc..fe545f48961661e7034442c8f6c700ded9bd3d66 100644 (file)
@@ -25,6 +25,7 @@ struct ustream_ssl {
        struct ustream stream;
        struct ustream *conn;
        struct uloop_timeout error_timer;
+       struct uloop_fd fd;
 
        void (*notify_connected)(struct ustream_ssl *us);
        void (*notify_error)(struct ustream_ssl *us, int error, const char *str);
@@ -56,6 +57,7 @@ struct ustream_ssl_ops {
        int (*context_add_ca_crt_file)(struct ustream_ssl_ctx *ctx, const char *file);
        void (*context_free)(struct ustream_ssl_ctx *ctx);
 
+       int (*init_fd)(struct ustream_ssl *us, int fd, struct ustream_ssl_ctx *ctx, bool server);
        int (*init)(struct ustream_ssl *us, struct ustream *conn, struct ustream_ssl_ctx *ctx, bool server);
        int (*set_peer_cn)(struct ustream_ssl *conn, const char *name);