pex: add support for sending endpoint notification from the wg port via raw socket
authorFelix Fietkau <nbd@nbd.name>
Wed, 31 Aug 2022 11:03:39 +0000 (13:03 +0200)
committerFelix Fietkau <nbd@nbd.name>
Wed, 31 Aug 2022 11:03:41 +0000 (13:03 +0200)
This makes it possible to use the global PEX socket (used for network data updates)
to be used to receive the endpoint address in a way that works through NAT.

Signed-off-by: Felix Fietkau <nbd@nbd.name>
cli.c
pex-msg.c
pex-msg.h
pex.c

diff --git a/cli.c b/cli.c
index 43b58e3e2382ccbbdb34ce7254afa967dadc77dd..22b605a7b7abdd6c7d11fa4525faf3d478b4f821 100644 (file)
--- a/cli.c
+++ b/cli.c
@@ -158,7 +158,7 @@ pex_handle_update_request(struct sockaddr_in6 *addr, const uint8_t *id, void *da
        pex_msg_update_response_init(&ctx, empty_key, pubkey,
                                     peerpubkey, true, data, net_data, net_data_len);
        while (!done) {
-               __pex_msg_send(-1, NULL);
+               __pex_msg_send(-1, NULL, NULL, 0);
                done = !pex_msg_update_response_continue(&ctx);
        }
        sync_done = true;
@@ -287,7 +287,7 @@ static int cmd_sync(const char *endpoint, int argc, char **argv)
                return 1;
 
        req_id = req->req_id;
-       if (__pex_msg_send(-1, NULL) < 0) {
+       if (__pex_msg_send(-1, NULL, NULL, 0) < 0) {
                if (!quiet)
                        perror("send");
                return 1;
index c97cfeb496bf38656aef1e1393502fb814b9469f..43a6960493ad8b991f449cbdaf5ab0e9d3b19d14 100644 (file)
--- a/pex-msg.c
+++ b/pex-msg.c
@@ -9,6 +9,10 @@
 #include <fcntl.h>
 #include <libubox/list.h>
 #include <libubox/uloop.h>
+#include <netinet/in.h>
+#include <netinet/ip.h>
+#include <netinet/ip6.h>
+#include <netinet/udp.h>
 #include "pex-msg.h"
 #include "chacha20.h"
 #include "auth-data.h"
@@ -18,6 +22,7 @@ static FILE *pex_urandom;
 static struct uloop_fd pex_fd;
 static LIST_HEAD(requests);
 static struct uloop_timeout gc_timer;
+static int pex_raw_v4_fd = -1, pex_raw_v6_fd = -1;
 
 static pex_recv_cb_t pex_recv_cb;
 
@@ -141,27 +146,160 @@ pex_fd_cb(struct uloop_fd *fd, unsigned int events)
        }
 }
 
-int __pex_msg_send(int fd, const void *addr)
+static inline uint32_t
+csum_tcpudp_nofold(uint32_t saddr, uint32_t daddr, uint32_t len, uint8_t proto)
+{
+       uint64_t sum = 0;
+
+       sum += saddr;
+       sum += daddr;
+#if __BYTE_ORDER == __LITTLE_ENDIAN
+       sum += (proto + len) << 8;
+#else
+       sum += proto + len;
+#endif
+
+       sum = (sum & 0xffffffff) + (sum >> 32);
+       sum = (sum & 0xffffffff) + (sum >> 32);
+
+       return (uint32_t)sum;
+}
+
+static inline uint32_t csum_add(uint32_t sum, uint32_t addend)
+{
+       sum += addend;
+       return sum + (sum < addend);
+}
+
+static inline uint16_t csum_fold(uint32_t sum)
+{
+       sum = (sum & 0xffff) + (sum >> 16);
+       sum = (sum & 0xffff) + (sum >> 16);
+
+       return (uint16_t)~sum;
+}
+
+static uint32_t csum_partial(const void *buf, int len)
+{
+       const uint16_t *data = buf;
+       uint32_t sum = 0;
+
+       while (len > 1) {
+               sum += *data++;
+               len -= 2;
+       }
+
+       if (len == 1)
+#if __BYTE_ORDER == __LITTLE_ENDIAN
+               sum += *(uint8_t *)data;
+#else
+               sum += *(uint8_t *)data << 8;
+#endif
+
+       sum = (sum & 0xffff) + (sum >> 16);
+       sum = (sum & 0xffff) + (sum >> 16);
+
+       return sum;
+}
+
+static void pex_fixup_udpv4(void *hdr, size_t hdrlen, const void *data, size_t len)
+{
+       struct ip *ip = hdr;
+       struct udphdr *udp = hdr + ip->ip_hl * 4;
+       uint16_t udp_len = sizeof(*udp) + len;
+       uint32_t sum;
+
+       if ((void *)&udp[1] > hdr + hdrlen)
+               return;
+
+       udp->uh_sum = 0;
+       udp->uh_ulen = htons(udp_len);
+       sum = csum_tcpudp_nofold(*(uint32_t *)&ip->ip_src, *(uint32_t *)&ip->ip_dst,
+                                ip->ip_p, udp_len);
+       sum = csum_add(sum, csum_partial(udp, sizeof(*udp)));
+       sum = csum_add(sum, csum_partial(data, len));
+       udp->uh_sum = csum_fold(sum);
+
+       ip->ip_len = htons(hdrlen + len);
+       ip->ip_sum = 0;
+       ip->ip_sum = csum_fold(csum_partial(ip, sizeof(*ip)));
+
+#ifdef __APPLE__
+       ip->ip_len = hdrlen + len;
+#endif
+}
+
+static void pex_fixup_udpv6(void *hdr, size_t hdrlen, const void *data, size_t len)
+{
+       struct ip6_hdr *ip = hdr;
+       struct udphdr *udp = hdr + sizeof(*ip);
+       uint16_t udp_len = htons(sizeof(*udp) + len);
+
+       if ((void *)&udp[1] > hdr + hdrlen)
+               return;
+
+       ip->ip6_plen = htons(sizeof(*udp) + len);
+       udp->uh_sum = 0;
+       udp->uh_ulen = udp_len;
+       udp->uh_sum = csum_fold(csum_partial(hdr, sizeof(*ip) + sizeof(*udp)));
+
+#ifdef __APPLE__
+       ip->ip6_plen = sizeof(*udp) + len;
+#endif
+}
+
+static void pex_fixup_header(void *hdr, size_t hdrlen, const void *data, size_t len)
+{
+       if (hdrlen >= sizeof(struct ip6_hdr) + sizeof(struct udphdr))
+               pex_fixup_udpv6(hdr, hdrlen, data, len);
+       else if (hdrlen >= sizeof(struct ip) + sizeof(struct udphdr))
+               pex_fixup_udpv4(hdr, hdrlen, data, len);
+}
+
+int __pex_msg_send(int fd, const void *addr, void *ip_hdr, size_t ip_hdrlen)
 {
        struct pex_hdr *hdr = (struct pex_hdr *)pex_tx_buf;
        const struct sockaddr *sa = addr;
        size_t tx_len = sizeof(*hdr) + hdr->len;
        uint16_t orig_len = hdr->len;
-       size_t addr_len;
        int ret;
 
        if (fd < 0) {
                hdr->len -= sizeof(struct pex_ext_hdr);
-               fd = pex_fd.fd;
+               if (ip_hdrlen)
+                       fd = sa->sa_family == AF_INET6 ? pex_raw_v6_fd : pex_raw_v4_fd;
+               else
+                       fd = pex_fd.fd;
+
+               if (fd < 0)
+                       return -1;
        }
 
        hdr->len = htons(hdr->len);
        if (addr) {
+               struct iovec iov[2] = {
+                       { .iov_base = (void *)ip_hdr, .iov_len = ip_hdrlen },
+                       { .iov_base = pex_tx_buf, .iov_len = tx_len }
+               };
+               struct msghdr msg = {
+                       .msg_name = (void *)addr,
+                       .msg_iov = iov,
+                       .msg_iovlen = ARRAY_SIZE(iov),
+               };
+
                if (sa->sa_family == AF_INET6)
-                       addr_len = sizeof(struct sockaddr_in6);
+                       msg.msg_namelen = sizeof(struct sockaddr_in6);
                else
-                       addr_len = sizeof(struct sockaddr_in);
-               ret = sendto(fd, pex_tx_buf, tx_len, 0, sa, addr_len);
+                       msg.msg_namelen = sizeof(struct sockaddr_in);
+
+               if (ip_hdrlen) {
+                       pex_fixup_header(ip_hdr, ip_hdrlen, pex_tx_buf, tx_len);
+               } else {
+                       msg.msg_iov++;
+                       msg.msg_iovlen--;
+               }
+
+               ret = sendmsg(fd, &msg, 0);
        } else {
                ret = send(fd, pex_tx_buf, tx_len, 0);
        }
@@ -406,9 +544,28 @@ int pex_open(void *addr, size_t addr_len, pex_recv_cb_t cb, bool server)
 
        pex_recv_cb = cb;
 
+       if (server) {
+               pex_raw_v4_fd = fd = socket(PF_INET, SOCK_RAW, IPPROTO_UDP);
+               if (fd < 0)
+                       return -1;
+
+               setsockopt(fd, SOL_SOCKET, SO_BROADCAST, &yes, sizeof(yes));
+               setsockopt(fd, IPPROTO_IP, IP_HDRINCL, &yes, sizeof(yes));
+
+#ifdef linux
+               pex_raw_v6_fd = fd = socket(PF_INET6, SOCK_RAW, IPPROTO_UDP);
+               if (fd < 0)
+                       goto close_raw;
+
+               setsockopt(fd, SOL_SOCKET, SO_BROADCAST, &yes, sizeof(yes));
+               setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &no, sizeof(no));
+               setsockopt(fd, IPPROTO_IPV6, IPV6_HDRINCL, &yes, sizeof(yes));
+#endif
+       }
+
        pex_urandom = fopen("/dev/urandom", "r");
        if (!pex_urandom)
-               return -1;
+               goto close_raw;
 
        fd = socket(sa->sa_family == AF_INET ? PF_INET : PF_INET6, SOCK_DGRAM, IPPROTO_UDP);
        if (fd < 0)
@@ -445,6 +602,13 @@ close_socket:
        close(fd);
 close_urandom:
        fclose(pex_urandom);
+close_raw:
+       if (pex_raw_v4_fd >= 0)
+               close(pex_raw_v4_fd);
+       if (pex_raw_v6_fd >= 0)
+               close(pex_raw_v6_fd);
+       pex_raw_v4_fd = -1;
+       pex_raw_v6_fd = -1;
        return -1;
 }
 
@@ -453,6 +617,13 @@ void pex_close(void)
        if (!pex_fd.cb)
                return;
 
+       if (pex_raw_v4_fd >= 0)
+               close(pex_raw_v4_fd);
+       if (pex_raw_v6_fd >= 0)
+               close(pex_raw_v6_fd);
+       pex_raw_v4_fd = -1;
+       pex_raw_v6_fd = -1;
+
        fclose(pex_urandom);
        uloop_fd_delete(&pex_fd);
        close(pex_fd.fd);
index 7928f022b168524fafd166c0fca35b7e54300560..653eb049fcaff7ee628273db6a79554caa5b6799 100644 (file)
--- a/pex-msg.h
+++ b/pex-msg.h
@@ -21,6 +21,7 @@ enum pex_opcode {
        PEX_MSG_UPDATE_RESPONSE,
        PEX_MSG_UPDATE_RESPONSE_DATA,
        PEX_MSG_UPDATE_RESPONSE_NO_DATA,
+       PEX_MSG_ENDPOINT_NOTIFY,
 };
 
 #define PEX_ID_LEN             8
@@ -93,7 +94,7 @@ uint64_t pex_network_hash(const uint8_t *auth_key, uint64_t req_id);
 struct pex_hdr *__pex_msg_init(const uint8_t *pubkey, uint8_t opcode);
 struct pex_hdr *__pex_msg_init_ext(const uint8_t *pubkey, const uint8_t *auth_key,
                                   uint8_t opcode, bool ext);
-int __pex_msg_send(int fd, const void *addr);
+int __pex_msg_send(int fd, const void *addr, void *ip_hdr, size_t ip_hdrlen);
 void *pex_msg_append(size_t len);
 
 struct pex_update_request *
diff --git a/pex.c b/pex.c
index 62a30f48428822230739acc38a79240ccf8cc655..839567dd2763a7b4ec1c339b10e34eba8a7e87c4 100644 (file)
--- a/pex.c
+++ b/pex.c
@@ -5,6 +5,10 @@
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <arpa/inet.h>
+#include <netinet/in.h>
+#include <netinet/ip.h>
+#include <netinet/ip6.h>
+#include <netinet/udp.h>
 #include <fcntl.h>
 #include <stdlib.h>
 #include <inttypes.h>
@@ -70,7 +74,7 @@ static void pex_msg_send(struct network *net, struct network_peer *peer)
                return;
 
        pex_get_peer_addr(&sin6, net, peer);
-       if (__pex_msg_send(net->pex.fd.fd, &sin6) < 0)
+       if (__pex_msg_send(net->pex.fd.fd, &sin6, NULL, 0) < 0)
                D_PEER(net, peer, "pex_msg_send failed: %s", strerror(errno));
 }
 
@@ -82,7 +86,7 @@ static void pex_msg_send_ext(struct network *net, struct network_peer *peer,
        if (!addr)
                return pex_msg_send(net, peer);
 
-       if (__pex_msg_send(-1, addr) < 0)
+       if (__pex_msg_send(-1, addr, NULL, 0) < 0)
                D_NET(net, "pex_msg_send_ext(%s) failed: %s",
                      inet_ntop(addr->sin6_family, (const void *)&addr->sin6_addr, addrbuf,
                                sizeof(addrbuf)),
@@ -164,8 +168,22 @@ network_pex_handle_endpoint_change(struct network *net, struct network_peer *pee
 static void
 network_pex_host_request_update(struct network *net, struct network_pex_host *host)
 {
+       union {
+               struct {
+                       struct ip ip;
+                       struct udphdr udp;
+               } ipv4;
+               struct {
+                       struct ip6_hdr ip;
+                       struct udphdr udp;
+               } ipv6;
+       } packet = {};
+       struct udphdr *udp;
        char addrstr[INET6_ADDRSTRLEN];
+       union network_endpoint dest_ep;
+       union network_addr local_addr = {};
        uint64_t version = 0;
+       int len;
 
        if (net->net_data_len)
                version = net->net_data_version;
@@ -181,7 +199,57 @@ network_pex_host_request_update(struct network *net, struct network_pex_host *ho
                                         net->config.auth_key, &host->endpoint,
                                         version, true))
                return;
-       __pex_msg_send(-1, &host->endpoint);
+
+       __pex_msg_send(-1, &host->endpoint, NULL, 0);
+
+       if (!net->net_config.local_host)
+               return;
+
+       pex_msg_init_ext(net, PEX_MSG_ENDPOINT_NOTIFY, true);
+
+       memcpy(&dest_ep, &host->endpoint, sizeof(dest_ep));
+
+       /* work around issue with local address lookup for local broadcast */
+       if (host->endpoint.sa.sa_family == AF_INET) {
+               uint8_t *data = (uint8_t *)&dest_ep.in.sin_addr;
+
+               if (data[3] == 0xff)
+                       data[3] = 0xfe;
+       }
+       network_get_local_addr(&local_addr, &dest_ep);
+
+       memset(&dest_ep, 0, sizeof(dest_ep));
+       dest_ep.sa.sa_family = host->endpoint.sa.sa_family;
+       if (host->endpoint.sa.sa_family == AF_INET) {
+               packet.ipv4.ip = (struct ip){
+                       .ip_hl = 5,
+                       .ip_v = 4,
+                       .ip_ttl = 64,
+                       .ip_p = IPPROTO_UDP,
+                       .ip_src = local_addr.in,
+                       .ip_dst = host->endpoint.in.sin_addr,
+               };
+               dest_ep.in.sin_addr = host->endpoint.in.sin_addr;
+               udp = &packet.ipv4.udp;
+               len = sizeof(packet.ipv4);
+       } else {
+               packet.ipv6.ip = (struct ip6_hdr){
+                       .ip6_flow = htonl(6 << 28),
+                       .ip6_hops = 128,
+                       .ip6_nxt = IPPROTO_UDP,
+                       .ip6_src = local_addr.in6,
+                       .ip6_dst = host->endpoint.in6.sin6_addr,
+               };
+               dest_ep.in6.sin6_addr = host->endpoint.in6.sin6_addr;
+               udp = &packet.ipv6.udp;
+               len = sizeof(packet.ipv6);
+       }
+
+       udp->uh_sport = htons(net->net_config.local_host->peer.port);
+       udp->uh_dport = host->endpoint.in6.sin6_port;
+
+       if (__pex_msg_send(-1, &dest_ep, &packet, len) < 0)
+               D_NET(net, "pex_msg_send_raw failed: %s", strerror(errno));
 }
 
 static void
@@ -543,6 +611,8 @@ network_pex_recv(struct network *net, struct network_peer *peer, struct pex_hdr
                network_pex_recv_update_response(net, data, hdr->len,
                                              NULL, hdr->opcode);
                break;
+       case PEX_MSG_ENDPOINT_NOTIFY:
+               break;
        }
 }
 
@@ -740,6 +810,8 @@ global_pex_recv(struct pex_hdr *hdr, struct sockaddr_in6 *addr)
        struct network_peer *peer;
        struct network *net;
        void *data = (void *)(ehdr + 1);
+       char buf[INET6_ADDRSTRLEN];
+       int addr_len;
 
        if (hdr->version != 0)
                return;
@@ -768,6 +840,28 @@ global_pex_recv(struct pex_hdr *hdr, struct sockaddr_in6 *addr)
        case PEX_MSG_UPDATE_RESPONSE_NO_DATA:
                network_pex_recv_update_response(net, data, hdr->len, addr, hdr->opcode);
                break;
+       case PEX_MSG_ENDPOINT_NOTIFY:
+               peer = pex_msg_peer(net, hdr->id);
+               if (!peer)
+                       break;
+
+               if (IN6_IS_ADDR_V4MAPPED(&addr->sin6_addr)) {
+                       struct sockaddr_in *sin = (struct sockaddr_in *)addr;
+                       struct in_addr in = *(struct in_addr *)&addr->sin6_addr.s6_addr[12];
+                       int port = addr->sin6_port;
+
+                       memset(addr, 0, sizeof(*addr));
+                       sin->sin_port = port;
+                       sin->sin_family = AF_INET;
+                       sin->sin_addr = in;
+               }
+
+               D_PEER(net, peer, "receive endpoint notification from %s",
+                 inet_ntop(addr->sin6_family, network_endpoint_addr((void *)addr, &addr_len),
+                           buf, sizeof(buf)));
+
+               memcpy(&peer->state.next_endpoint, addr, sizeof(*addr));
+               break;
        }
 }