pex: add support for figuring out the external data port via STUN servers
authorFelix Fietkau <nbd@nbd.name>
Fri, 16 Sep 2022 09:00:15 +0000 (11:00 +0200)
committerFelix Fietkau <nbd@nbd.name>
Fri, 16 Sep 2022 16:56:53 +0000 (18:56 +0200)
When establishing a direct connection on the auth/PEX port via DHT, both sides
need to know the externally mapped data port number in order to establish a
wireguard connection.
If there is an existing data connection, the port can be queried via PEX
over the tunnel. If that is not available, an external public server is needed
in order to poke a hole in the NAT. The easiest way to do this is to use
STUN, since there are a lot of public servers available.

The servers can be configured via the network data, based on the assumption,
that an auth exchange with network data update can be done directly

Signed-off-by: Felix Fietkau <nbd@nbd.name>
12 files changed:
CMakeLists.txt
network.c
network.h
pex-msg.h
pex-stun.c [new file with mode: 0644]
pex.c
pex.h
stun.c [new file with mode: 0644]
stun.h [new file with mode: 0644]
utils.h
wg-linux.c
wg-user.c

index 806e3bf4256ba02b6ffcfb05c336f17f1f168bd9..ec31711dc7ee108c85eb69f677310e969ddf353d 100644 (file)
@@ -4,7 +4,7 @@ PROJECT(unetd C)
 
 
 SET(SOURCES
-       main.c network.c host.c service.c pex.c
+       main.c network.c host.c service.c pex.c pex-stun.c
        wg.c wg-user.c
 )
 
@@ -43,7 +43,7 @@ ELSE()
   SET(ubus "")
 ENDIF()
 
-ADD_LIBRARY(unet SHARED curve25519.c siphash.c sha512.c fprime.c f25519.c ed25519.c edsign.c auth-data.c chacha20.c pex-msg.c utils.c)
+ADD_LIBRARY(unet SHARED curve25519.c siphash.c sha512.c fprime.c f25519.c ed25519.c edsign.c auth-data.c chacha20.c pex-msg.c utils.c stun.c)
 TARGET_LINK_LIBRARIES(unet ubox)
 
 ADD_EXECUTABLE(unetd ${SOURCES})
index 44defdc56b750f98fcea634e841c9a4dc2444257..4a17af62af0a8bfd4d100eebfb388848838e3644 100644 (file)
--- a/network.c
+++ b/network.c
@@ -32,6 +32,7 @@ enum {
        NETCONF_ATTR_PORT,
        NETCONF_ATTR_PEX_PORT,
        NETCONF_ATTR_KEEPALIVE,
+       NETCONF_ATTR_STUN_SERVERS,
        __NETCONF_ATTR_MAX
 };
 
@@ -40,6 +41,7 @@ static const struct blobmsg_policy netconf_policy[__NETCONF_ATTR_MAX] = {
        [NETCONF_ATTR_PORT] = { "port", BLOBMSG_TYPE_INT32 },
        [NETCONF_ATTR_PEX_PORT] = { "peer-exchange-port", BLOBMSG_TYPE_INT32 },
        [NETCONF_ATTR_KEEPALIVE] = { "keepalive", BLOBMSG_TYPE_INT32 },
+       [NETCONF_ATTR_STUN_SERVERS] = { "stun-servers", BLOBMSG_TYPE_ARRAY },
 };
 
 const struct blobmsg_policy network_policy[__NETWORK_ATTR_MAX] = {
@@ -61,6 +63,15 @@ const struct blobmsg_policy network_policy[__NETWORK_ATTR_MAX] = {
 AVL_TREE(networks, avl_strcmp, false, NULL);
 static struct blob_buf b;
 
+static void network_load_stun_servers(struct network *net, struct blob_attr *data)
+{
+       struct blob_attr *cur;
+       int rem;
+
+       blobmsg_for_each_attr(cur, data, rem)
+               network_stun_server_add(net, blobmsg_get_string(cur));
+}
+
 static void network_load_config_data(struct network *net, struct blob_attr *data)
 {
        struct blob_attr *tb[__NETCONF_ATTR_MAX];
@@ -95,6 +106,10 @@ static void network_load_config_data(struct network *net, struct blob_attr *data
                net->net_config.keepalive = blobmsg_get_u32(cur);
        else
                net->net_config.keepalive = 0;
+
+       if ((cur = tb[NETCONF_ATTR_STUN_SERVERS]) != NULL &&
+           blobmsg_check_array(cur, BLOBMSG_TYPE_STRING) > 0)
+               network_load_stun_servers(net, cur);
 }
 
 static int network_load_data(struct network *net, struct blob_attr *data)
@@ -398,6 +413,7 @@ static void network_reload(struct uloop_timeout *t)
 
        memset(&net->net_config, 0, sizeof(net->net_config));
 
+       network_stun_free(net);
        network_pex_close(net);
        network_services_free(net);
        network_hosts_update_start(net);
@@ -424,6 +440,7 @@ static void network_reload(struct uloop_timeout *t)
        unetd_write_hosts();
        network_do_update(net, true);
        network_pex_open(net);
+       network_stun_start(net);
        unetd_ubus_notify(net);
 }
 
@@ -469,6 +486,7 @@ static void network_teardown(struct network *net)
        uloop_timeout_cancel(&net->connect_timer);
        uloop_timeout_cancel(&net->reload_timer);
        network_do_update(net, false);
+       network_stun_free(net);
        network_pex_close(net);
        network_pex_free(net);
        network_hosts_free(net);
@@ -600,6 +618,7 @@ network_alloc(const char *name)
        avl_insert(&networks, &net->node);
 
        network_pex_init(net);
+       network_stun_init(net);
        network_hosts_init(net);
        network_services_init(net);
 
index 344c5d23e61679dc8bc760a29b63f390b9eb6206..dc53bb1a84a41879855f1597892d04301125506d 100644 (file)
--- a/network.h
+++ b/network.h
@@ -49,6 +49,7 @@ struct network {
                int port;
                int pex_port;
                bool local_host_changed;
+               struct blob_attr *stun_list;
        } net_config;
 
        void *net_data;
@@ -71,6 +72,7 @@ struct network {
        struct uloop_timeout connect_timer;
 
        struct network_pex pex;
+       struct network_stun stun;
 };
 
 enum {
index b365aeb5bcb166bb5f947d227acef879a8affb12..3ca1984063b37e190103295b5e9b35ba1c4cc179 100644 (file)
--- a/pex-msg.h
+++ b/pex-msg.h
@@ -24,6 +24,7 @@ enum pex_opcode {
        PEX_MSG_UPDATE_RESPONSE_DATA,
        PEX_MSG_UPDATE_RESPONSE_NO_DATA,
        PEX_MSG_ENDPOINT_NOTIFY,
+       PEX_MSG_ENDPOINT_PORT_NOTIFY,
 };
 
 #define PEX_ID_LEN             8
@@ -76,6 +77,10 @@ struct pex_update_response_no_data {
        uint64_t cur_version;
 };
 
+struct pex_endpoint_port_notify {
+       uint16_t port;
+};
+
 struct pex_msg_update_send_ctx {
        const uint8_t *pubkey;
        const uint8_t *auth_key;
diff --git a/pex-stun.c b/pex-stun.c
new file mode 100644 (file)
index 0000000..444612b
--- /dev/null
@@ -0,0 +1,352 @@
+#include <arpa/inet.h>
+#include <netinet/in.h>
+#include <netinet/ip.h>
+#include <netinet/udp.h>
+#include <string.h>
+#include <errno.h>
+
+#include <libubox/usock.h>
+
+#include "unetd.h"
+
+static inline int avl_stun_cmp(const void *k1, const void *k2, void *priv)
+{
+       return memcmp(k1, k2, 12);
+}
+
+static bool has_connected_peer(struct network *net, bool pex)
+{
+       struct network_peer *peer;
+
+       vlist_for_each_element(&net->peers, peer, node) {
+               if (pex && !peer->pex_port)
+                       continue;
+
+               if (peer->state.connected)
+                       return true;
+       }
+
+       return false;
+}
+
+void network_stun_server_add(struct network *net, const char *host)
+{
+       struct network_stun *stun = &net->stun;
+       struct network_stun_server *s;
+       char *name_buf;
+
+       s = calloc_a(sizeof(*s), &name_buf, strlen(host) + 1);
+       s->pending_node.key = s->req.transaction;
+       s->host = strcpy(name_buf, host);
+
+       list_add_tail(&s->list, &stun->servers);
+}
+
+static void
+network_stun_close_socket(struct network *net)
+{
+       struct network_host *local = net->net_config.local_host;
+       struct network_stun *stun = &net->stun;
+
+       if (!stun->wgport_disabled)
+               return;
+
+       D_NET(net, "close STUN socket");
+       uloop_fd_delete(&stun->socket);
+       close(stun->socket.fd);
+       wg_init_local(net, &local->peer);
+       stun->wgport_disabled = false;
+}
+
+static void
+network_stun_socket_cb(struct uloop_fd *fd, unsigned int events)
+{
+       struct network_stun *stun = container_of(fd, struct network_stun, socket);
+       struct network *net = container_of(stun, struct network, stun);
+       char buf[1024];
+       ssize_t len;
+
+       while (1) {
+               len = recv(fd->fd, buf, sizeof(buf), 0);
+               if (len < 0) {
+                       if (errno == EAGAIN)
+                               break;
+                       if (errno == EINTR)
+                               continue;
+
+                       perror("recv");
+                       network_stun_close_socket(net);
+                       return;
+               }
+
+               if (!stun_msg_is_valid(buf, len))
+                       continue;
+
+               network_stun_rx_packet(net, buf, len);
+       }
+}
+
+static void
+network_stun_open_socket(struct network *net)
+{
+       struct network_host *local = net->net_config.local_host;
+       struct network_stun *stun = &net->stun;
+       int fd;
+
+       if (stun->wgport_disabled)
+               return;
+
+       D_NET(net, "open STUN socket");
+       wg_init_local(net, NULL);
+
+       fd = usock(USOCK_SERVER | USOCK_UDP | USOCK_IPV4ONLY | USOCK_NONBLOCK,
+                  NULL, usock_port(stun->port_local));
+       if (fd < 0) {
+               wg_init_local(net, &local->peer);
+               return;
+       }
+
+       stun->socket.fd = fd;
+       uloop_fd_add(&stun->socket, ULOOP_READ);
+       stun->wgport_disabled = true;
+}
+
+static bool
+network_stun_query_next(struct network *net)
+{
+       struct network_stun *stun = &net->stun;
+       struct network_stun_server *s;
+       char addrstr[INET6_ADDRSTRLEN];
+       union network_endpoint ep;
+       uint16_t res_port = 0;
+       const void *msg;
+       ssize_t ret;
+       size_t len;
+
+       s = list_first_entry(&stun->servers, struct network_stun_server, list);
+       if (s->pending)
+               return false;
+
+       /* send next query */
+       if (network_get_endpoint(&ep, AF_INET, s->host, 0, s->seq++) < 0) {
+               D_NET(net, "lookup failed for STUN host %s", s->host);
+               goto out;
+       }
+
+       if (ep.sa.sa_family != AF_INET || !ep.in.sin_port)
+               goto out;
+
+       if (!stun->wgport_disabled && stun->auth_port_ext)
+               res_port = stun->auth_port_ext;
+
+       D_NET(net, "Send STUN query to %s, res_port=%d, wg_disabled=%d",
+             inet_ntop(ep.sa.sa_family, network_endpoint_addr(&ep, NULL),
+                       addrstr, sizeof(addrstr)), res_port, stun->wgport_disabled);
+       msg = stun_msg_request_prepare(&s->req, &len, res_port);
+       if (!msg)
+               goto out;
+
+retry:
+       s->req_auth_port = false;
+       if (stun->wgport_disabled) {
+               ret = sendto(stun->socket.fd, msg, len, 0, &ep.sa, sizeof(ep.in));
+       } else if (!stun->auth_port_ext) {
+               s->req_auth_port = true;
+               ret = sendto(pex_socket(), msg, len, 0, &ep.sa, sizeof(ep.in));
+       } else {
+               struct {
+                   struct ip ip;
+                   struct udphdr udp;
+               } packet_hdr = {};
+               union network_addr local_addr = {};
+
+               network_get_local_addr(&local_addr, &ep);
+               packet_hdr.ip = (struct ip){
+                       .ip_hl = 5,
+                       .ip_v = 4,
+                       .ip_ttl = 64,
+                       .ip_p = IPPROTO_UDP,
+                       .ip_src = local_addr.in,
+                       .ip_dst = ep.in.sin_addr,
+               };
+               packet_hdr.udp = (struct udphdr){
+                       .uh_sport = htons(stun->port_local),
+                       .uh_dport = ep.in.sin_port,
+               };
+               ep.in.sin_port = 0;
+
+               ret = sendto_rawudp(pex_raw_socket(AF_INET), &ep,
+                                   &packet_hdr, sizeof(packet_hdr),
+                                   msg, len);
+       }
+
+       if (ret < 0 && errno == EINTR)
+               goto retry;
+
+out:
+       avl_insert(&stun->pending, &s->pending_node);
+       s->pending = true;
+
+       if (!list_is_last(&s->list, &stun->servers))
+               list_move_tail(&s->list, &stun->servers);
+
+       return true;
+}
+
+static void
+network_stun_query_clear_pending(struct network *net)
+{
+       struct network_stun *stun = &net->stun;
+       struct network_stun_server *s;
+
+       list_for_each_entry(s, &stun->servers, list) {
+               if (!s->pending)
+                       continue;
+
+               avl_delete(&stun->pending, &s->pending_node);
+               s->pending = false;
+       }
+}
+
+void network_stun_rx_packet(struct network *net, const void *data, size_t len)
+{
+       struct network_stun *stun = &net->stun;
+       const struct stun_msg_hdr *hdr = data;
+       struct network_stun_server *s;
+
+       s = avl_find_element(&stun->pending, hdr->transaction, s, pending_node);
+       if (!s)
+               return;
+
+       if (!stun_msg_request_complete(&s->req, data, len))
+               return;
+
+       if (!s->req.port)
+               return;
+
+       network_stun_update_port(net, s->req_auth_port, s->req.port);
+       if (s->req_auth_port)
+               stun->state = STUN_STATE_STUN_QUERY_SEND;
+       else
+               stun->state = STUN_STATE_IDLE;
+
+       network_stun_query_clear_pending(net);
+
+       uloop_timeout_set(&stun->timer, 1);
+}
+
+static void
+network_stun_timer_cb(struct uloop_timeout *t)
+{
+       struct network_stun *stun = container_of(t, struct network_stun, timer);
+       struct network *net = container_of(stun, struct network, stun);
+       unsigned int next = 0;
+
+restart:
+       switch (stun->state) {
+       case STUN_STATE_IDLE:
+               network_stun_close_socket(net);
+               next = 15 * 60 * 1000;
+               stun->state = STUN_STATE_STUN_QUERY_SEND;
+               D_NET(net, "STUN idle");
+               break;
+       case STUN_STATE_PEX_QUERY_WAIT:
+               stun->state = STUN_STATE_STUN_QUERY_SEND;
+               fallthrough;
+       case STUN_STATE_STUN_QUERY_SEND:
+               if (network_stun_query_next(net)) {
+                       next = 50;
+                       break;
+               }
+
+               stun->state = STUN_STATE_STUN_QUERY_WAIT;
+               D_NET(net, "wait for STUN server responses");
+               next = 1000;
+               break;
+       case STUN_STATE_STUN_QUERY_WAIT:
+               D_NET(net, "timeout waiting for STUN server responses, retry=%d", stun->retry);
+               network_stun_query_clear_pending(net);
+               if (stun->retry > 0) {
+                       stun->retry--;
+                       stun->state = STUN_STATE_STUN_QUERY_SEND;
+                       goto restart;
+               }
+
+               if (!stun->port_ext && !stun->wgport_disabled) {
+                       network_stun_open_socket(net);
+                       stun->state = STUN_STATE_STUN_QUERY_SEND;
+                       stun->retry = 2;
+               } else {
+                       stun->state = STUN_STATE_IDLE;
+               }
+               goto restart;
+       }
+
+       if (next)
+               uloop_timeout_set(t, next);
+}
+
+void network_stun_update_port(struct network *net, bool auth, uint16_t val)
+{
+       struct network_stun *stun = &net->stun;
+       uint16_t *port = auth ? &stun->auth_port_ext : &stun->port_ext;
+
+       D_NET(net, "Update external %s port: %d", auth ? "auth" : "data", val);
+       *port = val;
+}
+
+void network_stun_start(struct network *net)
+{
+       struct network_host *local = net->net_config.local_host;
+       struct network_stun *stun = &net->stun;
+       unsigned int next = 1;
+
+       if (!local || list_empty(&stun->servers))
+               return;
+
+       if (local->peer.port != stun->port_local) {
+               stun->port_ext = 0;
+               stun->port_local = local->peer.port;
+       }
+
+       if (!stun->port_ext && has_connected_peer(net, true)) {
+               D_NET(net, "wait for port information from PEX");
+               stun->state = STUN_STATE_PEX_QUERY_WAIT;
+               next = 60 * 1000;
+       } else {
+               if (!stun->port_ext && !has_connected_peer(net, false))
+                       network_stun_open_socket(net);
+
+               stun->state = STUN_STATE_STUN_QUERY_SEND;
+               stun->retry = 2;
+       }
+
+       uloop_timeout_set(&stun->timer, next);
+}
+
+void network_stun_init(struct network *net)
+{
+       struct network_stun *stun = &net->stun;
+
+       stun->socket.cb = network_stun_socket_cb;
+       stun->timer.cb = network_stun_timer_cb;
+       INIT_LIST_HEAD(&stun->servers);
+       avl_init(&stun->pending, avl_stun_cmp, true, NULL);
+}
+
+void network_stun_free(struct network *net)
+{
+       struct network_stun *stun = &net->stun;
+       struct network_stun_server *s, *tmp;
+
+       uloop_timeout_cancel(&stun->timer);
+       network_stun_close_socket(net);
+
+       avl_remove_all_elements(&stun->pending, s, pending_node, tmp)
+               s->pending = false;
+
+       list_for_each_entry_safe(s, tmp, &stun->servers, list) {
+               list_del(&s->list);
+               free(s);
+       }
+}
diff --git a/pex.c b/pex.c
index c8b073104029f060d1c4b17e8aabfba397affc24..3f28f5137bcc9979e106469de17861ae0e222ba5 100644 (file)
--- a/pex.c
+++ b/pex.c
@@ -166,7 +166,7 @@ network_pex_handle_endpoint_change(struct network *net, struct network_peer *pee
 }
 
 static void
-network_pex_host_request_update(struct network *net, struct network_pex_host *host)
+network_pex_host_send_endpoint_notify(struct network *net, struct network_pex_host *host)
 {
        union {
                struct {
@@ -179,32 +179,10 @@ network_pex_host_request_update(struct network *net, struct network_pex_host *ho
                } ipv6;
        } packet = {};
        struct udphdr *udp;
-       char addrstr[INET6_ADDRSTRLEN];
        union network_endpoint dest_ep;
        union network_addr local_addr = {};
-       uint64_t version = 0;
        int len;
 
-       if (net->net_data_len)
-               version = net->net_data_version;
-
-       D("request network data from host %s",
-         inet_ntop(host->endpoint.sa.sa_family,
-                   (host->endpoint.sa.sa_family == AF_INET6 ?
-                    (const void *)&host->endpoint.in6.sin6_addr :
-                    (const void *)&host->endpoint.in.sin_addr),
-                   addrstr, sizeof(addrstr)));
-
-       if (!pex_msg_update_request_init(net->config.pubkey, net->config.key,
-                                        net->config.auth_key, &host->endpoint,
-                                        version, true))
-               return;
-
-       __pex_msg_send(-1, &host->endpoint, NULL, 0);
-
-       if (!net->net_config.local_host)
-               return;
-
        pex_msg_init_ext(net, PEX_MSG_ENDPOINT_NOTIFY, true);
 
        memcpy(&dest_ep, &host->endpoint, sizeof(dest_ep));
@@ -252,6 +230,53 @@ network_pex_host_request_update(struct network *net, struct network_pex_host *ho
                D_NET(net, "pex_msg_send_raw failed: %s", strerror(errno));
 }
 
+
+static void
+network_pex_host_send_port_notify(struct network *net, struct network_pex_host *host)
+{
+       struct pex_endpoint_port_notify *data;
+
+       if (!net->stun.port_ext)
+               return;
+
+       pex_msg_init_ext(net, PEX_MSG_ENDPOINT_PORT_NOTIFY, true);
+
+       data = pex_msg_append(sizeof(*data));
+       data->port = htons(net->stun.port_ext);
+
+       __pex_msg_send(-1, &host->endpoint, NULL, 0);
+}
+
+static void
+network_pex_host_request_update(struct network *net, struct network_pex_host *host)
+{
+       char addrstr[INET6_ADDRSTRLEN];
+       uint64_t version = 0;
+
+       if (net->net_data_len)
+               version = net->net_data_version;
+
+       D("request network data from host %s",
+         inet_ntop(host->endpoint.sa.sa_family,
+                   (host->endpoint.sa.sa_family == AF_INET6 ?
+                    (const void *)&host->endpoint.in6.sin6_addr :
+                    (const void *)&host->endpoint.in.sin_addr),
+                   addrstr, sizeof(addrstr)));
+
+       if (!pex_msg_update_request_init(net->config.pubkey, net->config.key,
+                                        net->config.auth_key, &host->endpoint,
+                                        version, true))
+               return;
+
+       __pex_msg_send(-1, &host->endpoint, NULL, 0);
+
+       if (!net->net_config.local_host)
+               return;
+
+       network_pex_host_send_port_notify(net, host);
+       network_pex_host_send_endpoint_notify(net, host);
+}
+
 static void
 network_pex_request_update_cb(struct uloop_timeout *t)
 {
@@ -300,9 +325,8 @@ network_pex_query_hosts(struct network *net)
                struct network_peer *peer = &host->peer;
                void *id;
 
-               if (host == net->net_config.local_host ||
-                   peer->state.connected ||
-                   peer->endpoint)
+               if ((net->stun.port_ext && host == net->net_config.local_host) ||
+                   peer->state.connected || peer->endpoint)
                        continue;
 
                id = pex_msg_append(PEX_ID_LEN);
@@ -434,11 +458,13 @@ network_pex_recv_peers(struct network *net, struct network_peer *peer,
                void *addr;
                int len;
 
-               cur = pex_msg_peer(net, data->peer_id);
-               if (!cur)
+               if (!memcmp(data->peer_id, &local->key, PEX_ID_LEN)) {
+                       network_stun_update_port(net, false, ntohs(data->port));
                        continue;
+               }
 
-               if (cur == peer || cur == local)
+               cur = pex_msg_peer(net, data->peer_id);
+               if (!cur || cur == peer)
                        continue;
 
                D_PEER(net, peer, "received peer address for %s",
@@ -863,6 +889,11 @@ global_pex_recv(void *msg, size_t msg_len, struct sockaddr_in6 *addr)
        void *data;
        int addr_len;
 
+       if (stun_msg_is_valid(msg, msg_len)) {
+               avl_for_each_element(&networks, net, node)
+                       network_stun_rx_packet(net, msg, msg_len);
+       }
+
        hdr = pex_rx_accept(msg, msg_len, true);
        if (!hdr)
                return;
@@ -899,6 +930,9 @@ global_pex_recv(void *msg, size_t msg_len, struct sockaddr_in6 *addr)
        case PEX_MSG_UPDATE_RESPONSE_NO_DATA:
                network_pex_recv_update_response(net, data, hdr->len, addr, hdr->opcode);
                break;
+       case PEX_MSG_ENDPOINT_PORT_NOTIFY:
+               if (hdr->len < sizeof(struct pex_endpoint_port_notify))
+                       break;
        case PEX_MSG_ENDPOINT_NOTIFY:
                peer = pex_msg_peer(net, hdr->id);
                if (!peer)
@@ -909,6 +943,11 @@ global_pex_recv(void *msg, size_t msg_len, struct sockaddr_in6 *addr)
                            buf, sizeof(buf)));
 
                memcpy(&peer->state.next_endpoint, addr, sizeof(*addr));
+               if (hdr->opcode == PEX_MSG_ENDPOINT_PORT_NOTIFY) {
+                       struct pex_endpoint_port_notify *port = data;
+
+                       peer->state.next_endpoint.in.sin_port = port->port;
+               }
                break;
        }
 }
diff --git a/pex.h b/pex.h
index d5d08878791591019ba67357e23de86cf53639dd..acaf372befc83b82aa8e04845fe6d8559e6b805b 100644 (file)
--- a/pex.h
+++ b/pex.h
@@ -5,7 +5,9 @@
 #ifndef __UNETD_PEX_H
 #define __UNETD_PEX_H
 
+#include <sys/socket.h>
 #include <libubox/uloop.h>
+#include "stun.h"
 
 struct network;
 
@@ -22,6 +24,43 @@ struct network_pex {
        struct uloop_timeout request_update_timer;
 };
 
+enum network_stun_state {
+       STUN_STATE_IDLE,
+       STUN_STATE_PEX_QUERY_WAIT,
+       STUN_STATE_STUN_QUERY_SEND,
+       STUN_STATE_STUN_QUERY_WAIT,
+};
+
+struct network_stun_server {
+       struct list_head list;
+
+       struct avl_node pending_node;
+       struct stun_request req;
+
+       const char *host;
+       uint8_t seq;
+       bool req_auth_port;
+       bool pending;
+};
+
+struct network_stun {
+       struct list_head servers;
+       struct avl_tree pending;
+
+       struct uloop_timeout timer;
+
+       enum network_stun_state state;
+       bool wgport_disabled;
+
+       uint16_t auth_port_ext;
+       uint16_t port_local;
+       uint16_t port_ext;
+
+       int retry;
+
+       struct uloop_fd socket;
+};
+
 enum pex_event {
        PEX_EV_HANDSHAKE,
        PEX_EV_ENDPOINT_CHANGE,
@@ -39,6 +78,13 @@ void network_pex_event(struct network *net, struct network_peer *peer,
 void network_pex_create_host(struct network *net, union network_endpoint *ep,
                             unsigned int timeout);
 
+void network_stun_init(struct network *net);
+void network_stun_free(struct network *net);
+void network_stun_server_add(struct network *net, const char *host);
+void network_stun_rx_packet(struct network *net, const void *data, size_t len);
+void network_stun_update_port(struct network *net, bool auth, uint16_t val);
+void network_stun_start(struct network *net);
+
 static inline bool network_pex_active(struct network_pex *pex)
 {
        return pex->fd.fd >= 0;
diff --git a/stun.c b/stun.c
new file mode 100644 (file)
index 0000000..47e8b7b
--- /dev/null
+++ b/stun.c
@@ -0,0 +1,177 @@
+// SPDX-License-Identifier: GPL-2.0-or-later
+/*
+ * Copyright (C) 2022 Felix Fietkau <nbd@nbd.name>
+ */
+#include <sys/types.h>
+#include <arpa/inet.h>
+#include <string.h>
+#include <stdio.h>
+#include "stun.h"
+
+static uint8_t tx_buf[256];
+
+bool stun_msg_is_valid(const void *data, size_t len)
+{
+       const struct stun_msg_hdr *hdr = data;
+
+       if (len <= sizeof(*hdr))
+               return false;
+
+       return hdr->magic == htonl(STUN_MAGIC);
+}
+
+static void *stun_msg_init(uint16_t type)
+{
+       struct stun_msg_hdr *hdr = (struct stun_msg_hdr *)tx_buf;
+
+       memset(hdr, 0, sizeof(*hdr));
+       hdr->msg_type = htons(type);
+       hdr->magic = htonl(STUN_MAGIC);
+
+       return hdr;
+}
+
+static void *stun_msg_add_tlv(uint16_t type, uint16_t len)
+{
+       struct stun_msg_hdr *hdr = (struct stun_msg_hdr *)tx_buf;
+       uint16_t data_len = ntohs(hdr->msg_len);
+       struct stun_msg_tlv *tlv;
+       void *data = hdr + 1;
+
+       data += data_len;
+
+       tlv = data;
+       tlv->type = htons(type);
+       tlv->len = htons(len);
+
+       if (len & 3)
+               len = (len + 3) & ~3;
+
+       data_len += sizeof(*tlv) + len;
+       hdr->msg_len = htons(data_len);
+
+       return tlv + 1;
+}
+
+static void
+stun_msg_parse_attr(const struct stun_tlv_policy *policy,
+                   const struct stun_msg_tlv **tb, int len,
+                   const struct stun_msg_tlv *tlv)
+{
+       uint16_t type;
+       int i;
+
+       type = ntohs(tlv->type);
+
+       for (i = 0; i < len; i++) {
+               if (policy[i].type != type)
+                       continue;
+
+               if (ntohs(tlv->len) < policy[i].min_len)
+                       return;
+
+               tb[i] = tlv;
+               return;
+       }
+}
+
+static void
+stun_msg_parse(const struct stun_tlv_policy *policy,
+              const struct stun_msg_tlv **tb, int len,
+              const void *data, size_t data_len)
+{
+       const struct stun_msg_hdr *hdr = data;
+       const struct stun_msg_tlv *tlv;
+       const void *end = data + data_len;
+       uint16_t cur_len;
+
+       data += sizeof(*hdr);
+       while (1) {
+               tlv = data;
+               data = tlv + 1;
+               if (data > end)
+                       break;
+
+               cur_len = ntohs(tlv->len);
+               if (data + cur_len > end)
+                       break;
+
+               stun_msg_parse_attr(policy, tb, len, tlv);
+               data += (cur_len + 3) & ~3;
+       }
+}
+
+const void *stun_msg_request_prepare(struct stun_request *req, size_t *len,
+                                    uint16_t response_port)
+{
+       struct stun_msg_hdr *hdr;
+       FILE *f;
+
+       hdr = stun_msg_init(STUN_MSGTYPE_BINDING_REQUEST);
+       if (response_port) {
+               uint16_t *tlv_port = stun_msg_add_tlv(STUN_TLV_RESPONSE_PORT, 2);
+               *tlv_port = htons(response_port);
+       }
+
+       f = fopen("/dev/urandom", "r");
+       if (!f)
+               return NULL;
+
+       if (fread(hdr->transaction, 12, 1, f) != 1)
+               return NULL;
+
+       fclose(f);
+       memcpy(req->transaction, hdr->transaction, sizeof(req->transaction));
+       req->pending = true;
+       req->port = 0;
+       *len = htons(hdr->msg_len) + sizeof(*hdr);
+
+       return hdr;
+}
+
+bool stun_msg_request_complete(struct stun_request *req, const void *data,
+                              size_t len)
+{
+       enum {
+               PARSE_ATTR_MAPPED,
+               PARSE_ATTR_XOR_MAPPED,
+               __PARSE_ATTR_MAX
+       };
+       const struct stun_msg_tlv *tb[__PARSE_ATTR_MAX];
+       static const struct stun_tlv_policy policy[__PARSE_ATTR_MAX] = {
+               [PARSE_ATTR_MAPPED] = { STUN_TLV_MAPPED_ADDRESS, 8 },
+               [PARSE_ATTR_XOR_MAPPED] = { STUN_TLV_XOR_MAPPED_ADDRESS, 8 }
+       };
+       const struct stun_msg_hdr *hdr = data;
+       const void *tlv_data;
+       uint16_t port;
+
+       if (!req->pending)
+               return false;
+
+       if (!stun_msg_is_valid(data, len))
+               return false;
+
+       if (hdr->msg_type != htons(STUN_MSGTYPE_BINDING_RESPONSE))
+               return false;
+
+       if (memcmp(hdr->transaction, req->transaction, sizeof(hdr->transaction)) != 0)
+               return false;
+
+       stun_msg_parse(policy, tb, __PARSE_ATTR_MAX, data, len);
+
+       if (tb[PARSE_ATTR_XOR_MAPPED]) {
+               tlv_data = tb[PARSE_ATTR_XOR_MAPPED] + 1;
+               tlv_data += 2;
+               port = ntohs(*(const uint16_t *)tlv_data);
+               port ^= STUN_MAGIC >> 16;
+       } else if (tb[PARSE_ATTR_MAPPED]) {
+               tlv_data = tb[PARSE_ATTR_MAPPED] + 1;
+               tlv_data += 2;
+               port = ntohs(*(const uint16_t *)tlv_data);
+       } else
+               return false;
+
+       req->port = port;
+       return true;
+}
diff --git a/stun.h b/stun.h
new file mode 100644 (file)
index 0000000..64493a8
--- /dev/null
+++ b/stun.h
@@ -0,0 +1,57 @@
+#ifndef __UNETD_STUN_H
+#define __UNETD_STUN_H
+
+#include <stdint.h>
+#include <stdbool.h>
+
+#define STUN_MSGTYPE_BINDING_REQUEST           0x0001
+#define STUN_MSGTYPE_BINDING_RESPONSE          0x0101
+#define STUN_MSGTYPE_BINDING_ERROR             0x0111
+#define STUN_MSGTYPE_BINDING_INDICATION                0x0011
+
+#define STUN_MSGTYPE_SHARED_SECRET_REQUEST     0x0002
+#define STUN_MSGTYPE_SHARED_SECRET_RESPONSE    0x0102
+#define STUN_MSGTYPE_SHARED_SECRET_ERROR       0x0112
+
+#define STUN_MAGIC                             0x2112a442
+
+enum tlv_type {
+       STUN_TLV_MAPPED_ADDRESS =               0x01,
+       STUN_TLV_RESPONSE_ADDRESS =             0x02,
+       STUN_TLV_CHANGE_REQUEST =               0x03,
+       STUN_TLV_SOURCE_ADDRESS =               0x04,
+       STUN_TLV_CHANGED_ADDRESS =              0x05,
+       STUN_TLV_XOR_MAPPED_ADDRESS =           0x20,
+       STUN_TLV_RESPONSE_PORT =                0x27,
+};
+
+struct stun_msg_hdr {
+       uint16_t msg_type;
+       uint16_t msg_len;
+       uint32_t magic;
+       uint8_t transaction[12];
+};
+
+struct stun_msg_tlv {
+       uint16_t type;
+       uint16_t len;
+};
+
+struct stun_tlv_policy {
+       uint16_t type;
+       uint16_t min_len;
+};
+
+struct stun_request {
+       uint8_t transaction[12];
+       uint16_t port;
+       bool pending;
+};
+
+bool stun_msg_is_valid(const void *data, size_t len);
+const void *stun_msg_request_prepare(struct stun_request *req, size_t *len,
+                                    uint16_t response_port);
+bool stun_msg_request_complete(struct stun_request *req, const void *data,
+                              size_t len);
+
+#endif
diff --git a/utils.h b/utils.h
index c7fc2804f41099bd17266115da461ff0d47e2709..5d7acc85637bc69aa219c1dfc1495c3f4c86cf57 100644 (file)
--- a/utils.h
+++ b/utils.h
@@ -30,11 +30,13 @@ static inline void *
 network_endpoint_addr(union network_endpoint *ep, int *addr_len)
 {
        if (ep->sa.sa_family == AF_INET6) {
-               *addr_len = sizeof(ep->in6.sin6_addr);
+               if (addr_len)
+                       *addr_len = sizeof(ep->in6.sin6_addr);
                return &ep->in6.sin6_addr;
        }
 
-       *addr_len = sizeof(ep->in.sin_addr);
+       if (addr_len)
+               *addr_len = sizeof(ep->in.sin_addr);
        return &ep->in.sin_addr;
 }
 
index a9f37b9d443b1ba54c50119b7f5131c5ca38120b..9eaa9c4ec622b2b3e2bc76093243cf28951495c2 100644 (file)
@@ -102,7 +102,7 @@ wg_linux_init_local(struct network *net, struct network_peer *peer)
        struct nl_msg *msg;
 
        msg = wg_genl_msg(net, true);
-       nla_put_u16(msg, WGDEVICE_A_LISTEN_PORT, peer->port);
+       nla_put_u16(msg, WGDEVICE_A_LISTEN_PORT, peer ? peer->port : 0);
 
        return wg_genl_call(msg);
 }
index a057a10dd6329286bc63ec4290acc9f4c6d2a080..bf057e51131b00b6dbab9951e9b044be92e2c4d1 100644 (file)
--- a/wg-user.c
+++ b/wg-user.c
@@ -274,7 +274,7 @@ wg_user_init_local(struct network *net, struct network_peer *peer)
        if (wg_req_init(&req, net, true))
                return -1;
 
-       wg_req_set_int(&req, "listen_port", peer->port);
+       wg_req_set_int(&req, "listen_port", peer ? peer->port : 0);
 
        return wg_req_done(&req);
 }