1 From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
2 From: "Jason A. Donenfeld" <Jason@zx2c4.com>
3 Date: Tue, 23 Jun 2020 03:59:45 -0600
4 Subject: [PATCH] wireguard: device: avoid circular netns references
6 commit 900575aa33a3eaaef802b31de187a85c4a4b4bd0 upstream.
8 Before, we took a reference to the creating netns if the new netns was
9 different. This caused circular references when two
10 wireguard interfaces swapped namespaces. The solution is instead to
11 take no extra references at all, and simply invalidate the
12 creating netns pointer when that netns is deleted.
14 In order to prevent this from happening again, this commit improves the
15 rough object leak tracking by allowing it to account for created and
16 destroyed interfaces, in addition to peers and keys. That then makes it
17 possible to check for the object leak when two interfaces take a
18 reference to each other's namespaces.
20 Fixes: e7096c131e51 ("net: WireGuard secure network tunnel")
21 Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
22 Signed-off-by: David S. Miller <davem@davemloft.net>
23 Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
25 drivers/net/wireguard/device.c | 58 ++++++++++------------
26 drivers/net/wireguard/device.h | 3 +-
27 drivers/net/wireguard/netlink.c | 14 ++++--
28 drivers/net/wireguard/socket.c | 25 +++++++---
29 tools/testing/selftests/wireguard/netns.sh | 13 ++++-
30 5 files changed, 67 insertions(+), 46 deletions(-)
32 --- a/drivers/net/wireguard/device.c
33 +++ b/drivers/net/wireguard/device.c
34 @@ -45,17 +45,18 @@ static int wg_open(struct net_device *de
36 dev_v6->cnf.addr_gen_mode = IN6_ADDR_GEN_MODE_NONE;
38 + mutex_lock(&wg->device_update_lock);
39 ret = wg_socket_init(wg, wg->incoming_port);
42 - mutex_lock(&wg->device_update_lock);
44 list_for_each_entry(peer, &wg->peer_list, peer_list) {
45 wg_packet_send_staged_packets(peer);
46 if (peer->persistent_keepalive_interval)
47 wg_packet_send_keepalive(peer);
50 mutex_unlock(&wg->device_update_lock);
55 #ifdef CONFIG_PM_SLEEP
56 @@ -225,6 +226,7 @@ static void wg_destruct(struct net_devic
57 list_del(&wg->device_list);
59 mutex_lock(&wg->device_update_lock);
60 + rcu_assign_pointer(wg->creating_net, NULL);
61 wg->incoming_port = 0;
62 wg_socket_reinit(wg, NULL, NULL);
63 /* The final references are cleared in the below calls to destroy_workqueue. */
64 @@ -240,13 +242,11 @@ static void wg_destruct(struct net_devic
65 skb_queue_purge(&wg->incoming_handshakes);
66 free_percpu(dev->tstats);
67 free_percpu(wg->incoming_handshakes_worker);
68 - if (wg->have_creating_net_ref)
69 - put_net(wg->creating_net);
70 kvfree(wg->index_hashtable);
71 kvfree(wg->peer_hashtable);
72 mutex_unlock(&wg->device_update_lock);
74 - pr_debug("%s: Interface deleted\n", dev->name);
75 + pr_debug("%s: Interface destroyed\n", dev->name);
79 @@ -292,7 +292,7 @@ static int wg_newlink(struct net *src_ne
80 struct wg_device *wg = netdev_priv(dev);
83 - wg->creating_net = src_net;
84 + rcu_assign_pointer(wg->creating_net, src_net);
85 init_rwsem(&wg->static_identity.lock);
86 mutex_init(&wg->socket_update_lock);
87 mutex_init(&wg->device_update_lock);
88 @@ -393,30 +393,26 @@ static struct rtnl_link_ops link_ops __r
89 .newlink = wg_newlink,
92 -static int wg_netdevice_notification(struct notifier_block *nb,
93 - unsigned long action, void *data)
94 +static void wg_netns_pre_exit(struct net *net)
96 - struct net_device *dev = ((struct netdev_notifier_info *)data)->dev;
97 - struct wg_device *wg = netdev_priv(dev);
101 - if (action != NETDEV_REGISTER || dev->netdev_ops != &netdev_ops)
103 + struct wg_device *wg;
105 - if (dev_net(dev) == wg->creating_net && wg->have_creating_net_ref) {
106 - put_net(wg->creating_net);
107 - wg->have_creating_net_ref = false;
108 - } else if (dev_net(dev) != wg->creating_net &&
109 - !wg->have_creating_net_ref) {
110 - wg->have_creating_net_ref = true;
111 - get_net(wg->creating_net);
113 + list_for_each_entry(wg, &device_list, device_list) {
114 + if (rcu_access_pointer(wg->creating_net) == net) {
115 + pr_debug("%s: Creating namespace exiting\n", wg->dev->name);
116 + netif_carrier_off(wg->dev);
117 + mutex_lock(&wg->device_update_lock);
118 + rcu_assign_pointer(wg->creating_net, NULL);
119 + wg_socket_reinit(wg, NULL, NULL);
120 + mutex_unlock(&wg->device_update_lock);
127 -static struct notifier_block netdevice_notifier = {
128 - .notifier_call = wg_netdevice_notification
129 +static struct pernet_operations pernet_ops = {
130 + .pre_exit = wg_netns_pre_exit
133 int __init wg_device_init(void)
134 @@ -429,18 +425,18 @@ int __init wg_device_init(void)
138 - ret = register_netdevice_notifier(&netdevice_notifier);
139 + ret = register_pernet_device(&pernet_ops);
143 ret = rtnl_link_register(&link_ops);
145 - goto error_netdevice;
151 - unregister_netdevice_notifier(&netdevice_notifier);
153 + unregister_pernet_device(&pernet_ops);
155 #ifdef CONFIG_PM_SLEEP
156 unregister_pm_notifier(&pm_notifier);
157 @@ -451,7 +447,7 @@ error_pm:
158 void wg_device_uninit(void)
160 rtnl_link_unregister(&link_ops);
161 - unregister_netdevice_notifier(&netdevice_notifier);
162 + unregister_pernet_device(&pernet_ops);
163 #ifdef CONFIG_PM_SLEEP
164 unregister_pm_notifier(&pm_notifier);
166 --- a/drivers/net/wireguard/device.h
167 +++ b/drivers/net/wireguard/device.h
168 @@ -40,7 +40,7 @@ struct wg_device {
169 struct net_device *dev;
170 struct crypt_queue encrypt_queue, decrypt_queue;
171 struct sock __rcu *sock4, *sock6;
172 - struct net *creating_net;
173 + struct net __rcu *creating_net;
174 struct noise_static_identity static_identity;
175 struct workqueue_struct *handshake_receive_wq, *handshake_send_wq;
176 struct workqueue_struct *packet_crypt_wq;
177 @@ -56,7 +56,6 @@ struct wg_device {
178 unsigned int num_peers, device_update_gen;
181 - bool have_creating_net_ref;
184 int wg_device_init(void);
185 --- a/drivers/net/wireguard/netlink.c
186 +++ b/drivers/net/wireguard/netlink.c
187 @@ -517,11 +517,15 @@ static int wg_set_device(struct sk_buff
188 if (flags & ~__WGDEVICE_F_ALL)
192 - if ((info->attrs[WGDEVICE_A_LISTEN_PORT] ||
193 - info->attrs[WGDEVICE_A_FWMARK]) &&
194 - !ns_capable(wg->creating_net->user_ns, CAP_NET_ADMIN))
196 + if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_FWMARK]) {
199 + net = rcu_dereference(wg->creating_net);
200 + ret = !net || !ns_capable(net->user_ns, CAP_NET_ADMIN) ? -EPERM : 0;
206 ++wg->device_update_gen;
208 --- a/drivers/net/wireguard/socket.c
209 +++ b/drivers/net/wireguard/socket.c
210 @@ -347,6 +347,7 @@ static void set_sock_opts(struct socket
212 int wg_socket_init(struct wg_device *wg, u16 port)
216 struct udp_tunnel_sock_cfg cfg = {
218 @@ -371,37 +372,47 @@ int wg_socket_init(struct wg_device *wg,
223 + net = rcu_dereference(wg->creating_net);
224 + net = net ? maybe_get_net(net) : NULL;
226 + if (unlikely(!net))
229 #if IS_ENABLED(CONFIG_IPV6)
233 - ret = udp_sock_create(wg->creating_net, &port4, &new4);
234 + ret = udp_sock_create(net, &port4, &new4);
236 pr_err("%s: Could not create IPv4 socket\n", wg->dev->name);
241 - setup_udp_tunnel_sock(wg->creating_net, new4, &cfg);
242 + setup_udp_tunnel_sock(net, new4, &cfg);
244 #if IS_ENABLED(CONFIG_IPV6)
245 if (ipv6_mod_enabled()) {
246 port6.local_udp_port = inet_sk(new4->sk)->inet_sport;
247 - ret = udp_sock_create(wg->creating_net, &port6, &new6);
248 + ret = udp_sock_create(net, &port6, &new6);
250 udp_tunnel_sock_release(new4);
251 if (ret == -EADDRINUSE && !port && retries++ < 100)
253 pr_err("%s: Could not create IPv6 socket\n",
259 - setup_udp_tunnel_sock(wg->creating_net, new6, &cfg);
260 + setup_udp_tunnel_sock(net, new6, &cfg);
264 wg_socket_reinit(wg, new4->sk, new6 ? new6->sk : NULL);
272 void wg_socket_reinit(struct wg_device *wg, struct sock *new4,
273 --- a/tools/testing/selftests/wireguard/netns.sh
274 +++ b/tools/testing/selftests/wireguard/netns.sh
275 @@ -587,9 +587,20 @@ ip0 link set wg0 up
279 +# Ensure there aren't circular reference loops
280 +ip1 link add wg1 type wireguard
281 +ip2 link add wg2 type wireguard
282 +ip1 link set wg1 netns $netns2
283 +ip2 link set wg2 netns $netns1
284 +pp ip netns delete $netns1
285 +pp ip netns delete $netns2
286 +pp ip netns add $netns1
287 +pp ip netns add $netns2
289 +sleep 2 # Wait for cleanup and grace periods
291 while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do
292 - [[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ [0-9]+)\ .*(created|destroyed).* ]] || continue
293 + [[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ ?[0-9]*)\ .*(created|destroyed).* ]] || continue
294 objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}"