return ret;
}
-__hidden void ustream_set_io(struct ustream_ssl_ctx *ctx, void *ssl, struct ustream *conn)
+static int s_fd_read(void *ctx, unsigned char *buf, size_t len)
{
- mbedtls_ssl_set_bio(ssl, conn, s_ustream_write, s_ustream_read, NULL);
+ struct uloop_fd *ufd = ctx;
+ mbedtls_net_context net = {
+ .fd = ufd->fd
+ };
+
+ return mbedtls_net_recv(&net, buf, len);
+}
+
+static int s_fd_write(void *ctx, const unsigned char *buf, size_t len)
+{
+ struct uloop_fd *ufd = ctx;
+ mbedtls_net_context net = {
+ .fd = ufd->fd
+ };
+
+ return mbedtls_net_send(&net, buf, len);
+}
+
+__hidden void ustream_set_io(struct ustream_ssl *us)
+{
+ if (us->conn)
+ mbedtls_ssl_set_bio(us->ssl, us->conn, s_ustream_write, s_ustream_read, NULL);
+ else
+ mbedtls_ssl_set_bio(us->ssl, &us->fd, s_fd_write, s_fd_read, NULL);
}
static int _random(void *ctx, unsigned char *out, size_t len)
return ssl;
}
-__hidden void __ustream_ssl_session_free(void *ssl)
+__hidden void __ustream_ssl_session_free(struct ustream_ssl *us)
{
- mbedtls_ssl_free(ssl);
- free(ssl);
+ mbedtls_ssl_free(us->ssl);
+ free(us->ssl);
}
}
}
-static bool __ustream_ssl_poll(struct ustream *s)
+static bool __ustream_ssl_poll(struct ustream_ssl *us)
{
- struct ustream_ssl *us = container_of(s->next, struct ustream_ssl, stream);
char *buf;
int len, ret;
bool more = false;
ret = __ustream_ssl_read(us, buf, len);
if (ret == U_SSL_PENDING) {
- ustream_poll(us->conn);
+ if (us->conn)
+ ustream_poll(us->conn);
ret = __ustream_ssl_read(us, buf, len);
}
static void ustream_ssl_notify_read(struct ustream *s, int bytes)
{
- __ustream_ssl_poll(s);
+ struct ustream_ssl *us = container_of(s->next, struct ustream_ssl, stream);
+
+ __ustream_ssl_poll(us);
}
static void ustream_ssl_notify_write(struct ustream *s, int bytes)
if (!us->connected || us->error)
return 0;
- if (us->conn->w.data_bytes)
+ if (us->conn && us->conn->w.data_bytes)
return 0;
return __ustream_ssl_write(us, buf, len);
static void ustream_ssl_set_read_blocked(struct ustream *s)
{
struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
+ unsigned int ev = ULOOP_WRITE | ULOOP_EDGE_TRIGGER;
+
+ if (us->conn) {
+ ustream_set_read_blocked(us->conn, !!s->read_blocked);
+ return;
+ }
+
+ if (!s->read_blocked)
+ ev |= ULOOP_READ;
- ustream_set_read_blocked(us->conn, !!s->read_blocked);
+ uloop_fd_add(&us->fd, ev);
}
static void ustream_ssl_free(struct ustream *s)
us->conn->notify_read = NULL;
us->conn->notify_write = NULL;
us->conn->notify_state = NULL;
+ } else {
+ uloop_fd_delete(&us->fd);
}
uloop_timeout_cancel(&us->error_timer);
- __ustream_ssl_session_free(us->ssl);
+ __ustream_ssl_session_free(us);
free(us->peer_cn);
us->ctx = NULL;
static bool ustream_ssl_poll(struct ustream *s)
{
struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
- bool fd_poll;
+ bool fd_poll = false;
+
+ if (us->conn)
+ fd_poll = ustream_poll(us->conn);
+
+ return __ustream_ssl_poll(us) || fd_poll;
+}
+
+static void ustream_ssl_fd_cb(struct uloop_fd *fd, unsigned int events)
+{
+ struct ustream_ssl *us = container_of(fd, struct ustream_ssl, fd);
- fd_poll = ustream_poll(us->conn);
- return __ustream_ssl_poll(us->conn) || fd_poll;
+ __ustream_ssl_poll(us);
}
static void ustream_ssl_stream_init(struct ustream_ssl *us)
struct ustream *conn = us->conn;
struct ustream *s = &us->stream;
- conn->notify_read = ustream_ssl_notify_read;
- conn->notify_write = ustream_ssl_notify_write;
- conn->notify_state = ustream_ssl_notify_state;
+ if (conn) {
+ conn->notify_read = ustream_ssl_notify_read;
+ conn->notify_write = ustream_ssl_notify_write;
+ conn->notify_state = ustream_ssl_notify_state;
+ } else {
+ us->fd.cb = ustream_ssl_fd_cb;
+ uloop_fd_add(&us->fd, ULOOP_READ | ULOOP_WRITE | ULOOP_EDGE_TRIGGER);
+ }
+ s->set_read_blocked = ustream_ssl_set_read_blocked;
s->free = ustream_ssl_free;
s->write = ustream_ssl_write;
s->poll = ustream_ssl_poll;
- s->set_read_blocked = ustream_ssl_set_read_blocked;
ustream_init_defaults(s);
}
-static int _ustream_ssl_init(struct ustream_ssl *us, struct ustream *conn, struct ustream_ssl_ctx *ctx, bool server)
+static int _ustream_ssl_init_common(struct ustream_ssl *us)
{
us->error_timer.cb = ustream_ssl_error_cb;
- us->server = server;
- us->conn = conn;
- us->ctx = ctx;
us->ssl = __ustream_ssl_session_new(us->ctx);
if (!us->ssl)
return -ENOMEM;
- conn->r.max_buffers = 4;
- conn->next = &us->stream;
- ustream_set_io(ctx, us->ssl, conn);
+ ustream_set_io(us);
ustream_ssl_stream_init(us);
if (us->server_name)
return 0;
}
+static int _ustream_ssl_init_fd(struct ustream_ssl *us, int fd, struct ustream_ssl_ctx *ctx, bool server)
+{
+ us->server = server;
+ us->ctx = ctx;
+ us->fd.fd = fd;
+
+ return _ustream_ssl_init_common(us);
+}
+
+static int _ustream_ssl_init(struct ustream_ssl *us, struct ustream *conn, struct ustream_ssl_ctx *ctx, bool server)
+{
+ us->server = server;
+ us->ctx = ctx;
+
+ us->conn = conn;
+ conn->r.max_buffers = 4;
+ conn->next = &us->stream;
+
+ return _ustream_ssl_init_common(us);
+}
+
static int _ustream_ssl_set_peer_cn(struct ustream_ssl *us, const char *name)
{
us->peer_cn = strdup(name);
.context_set_debug = __ustream_ssl_set_debug,
.context_free = __ustream_ssl_context_free,
.init = _ustream_ssl_init,
+ .init_fd = _ustream_ssl_init_fd,
.set_peer_cn = _ustream_ssl_set_peer_cn,
};