[UDP]: Randomize port selection.
authorStephen Hemminger <shemminger@linux-foundation.org>
Sat, 25 Aug 2007 06:09:41 +0000 (23:09 -0700)
committerDavid S. Miller <davem@sunset.davemloft.net>
Wed, 10 Oct 2007 23:48:31 +0000 (16:48 -0700)
This patch causes UDP port allocation to be randomized like TCP.
The earlier code would always choose same port (ie first empty list).

Signed-off-by: Stephen Hemminger <shemminger@linux-foundation.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
net/ipv4/udp.c
net/ipv4/udp_impl.h
net/ipv4/udplite.c

index 69d4bd10f9c6ef31e5ce3c9762486bd61ab98bda..a581b543bff7f23fb53304f4439d07166d01c4e9 100644 (file)
@@ -113,9 +113,8 @@ DEFINE_SNMP_STAT(struct udp_mib, udp_statistics) __read_mostly;
 struct hlist_head udp_hash[UDP_HTABLE_SIZE];
 DEFINE_RWLOCK(udp_hash_lock);
 
-static int udp_port_rover;
-
-static inline int __udp_lib_lport_inuse(__u16 num, struct hlist_head udptable[])
+static inline int __udp_lib_lport_inuse(__u16 num,
+                                       const struct hlist_head udptable[])
 {
        struct sock *sk;
        struct hlist_node *node;
@@ -132,11 +131,10 @@ static inline int __udp_lib_lport_inuse(__u16 num, struct hlist_head udptable[])
  *  @sk:          socket struct in question
  *  @snum:        port number to look up
  *  @udptable:    hash list table, must be of UDP_HTABLE_SIZE
- *  @port_rover:  pointer to record of last unallocated port
  *  @saddr_comp:  AF-dependent comparison of bound local IP addresses
  */
 int __udp_lib_get_port(struct sock *sk, unsigned short snum,
-                      struct hlist_head udptable[], int *port_rover,
+                      struct hlist_head udptable[],
                       int (*saddr_comp)(const struct sock *sk1,
                                         const struct sock *sk2 )    )
 {
@@ -146,49 +144,56 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
        int    error = 1;
 
        write_lock_bh(&udp_hash_lock);
-       if (snum == 0) {
-               int best_size_so_far, best, result, i;
-
-               if (*port_rover > sysctl_local_port_range[1] ||
-                   *port_rover < sysctl_local_port_range[0])
-                       *port_rover = sysctl_local_port_range[0];
-               best_size_so_far = 32767;
-               best = result = *port_rover;
-               for (i = 0; i < UDP_HTABLE_SIZE; i++, result++) {
-                       int size;
-
-                       head = &udptable[result & (UDP_HTABLE_SIZE - 1)];
-                       if (hlist_empty(head)) {
-                               if (result > sysctl_local_port_range[1])
-                                       result = sysctl_local_port_range[0] +
-                                               ((result - sysctl_local_port_range[0]) &
-                                                (UDP_HTABLE_SIZE - 1));
+
+       if (!snum) {
+               int i;
+               int low = sysctl_local_port_range[0];
+               int high = sysctl_local_port_range[1];
+               unsigned rover, best, best_size_so_far;
+
+               best_size_so_far = UINT_MAX;
+               best = rover = net_random() % (high - low) + low;
+
+               /* 1st pass: look for empty (or shortest) hash chain */
+               for (i = 0; i < UDP_HTABLE_SIZE; i++) {
+                       int size = 0;
+
+                       head = &udptable[rover & (UDP_HTABLE_SIZE - 1)];
+                       if (hlist_empty(head))
                                goto gotit;
-                       }
-                       size = 0;
+
                        sk_for_each(sk2, node, head) {
                                if (++size >= best_size_so_far)
                                        goto next;
                        }
                        best_size_so_far = size;
-                       best = result;
+                       best = rover;
                next:
-                       ;
+                       /* fold back if end of range */
+                       if (++rover > high)
+                               rover = low + ((rover - low)
+                                              & (UDP_HTABLE_SIZE - 1));
+
+
                }
-               result = best;
-               for (i = 0; i < (1 << 16) / UDP_HTABLE_SIZE;
-                    i++, result += UDP_HTABLE_SIZE) {
-                       if (result > sysctl_local_port_range[1])
-                               result = sysctl_local_port_range[0]
-                                       + ((result - sysctl_local_port_range[0]) &
-                                          (UDP_HTABLE_SIZE - 1));
-                       if (! __udp_lib_lport_inuse(result, udptable))
-                               break;
+
+               /* 2nd pass: find hole in shortest hash chain */
+               rover = best;
+               for (i = 0; i < (1 << 16) / UDP_HTABLE_SIZE; i++) {
+                       if (! __udp_lib_lport_inuse(rover, udptable))
+                               goto gotit;
+                       rover += UDP_HTABLE_SIZE;
+                       if (rover > high)
+                               rover = low + ((rover - low)
+                                              & (UDP_HTABLE_SIZE - 1));
                }
-               if (i >= (1 << 16) / UDP_HTABLE_SIZE)
-                       goto fail;
+
+
+               /* All ports in use! */
+               goto fail;
+
 gotit:
-               *port_rover = snum = result;
+               snum = rover;
        } else {
                head = &udptable[snum & (UDP_HTABLE_SIZE - 1)];
 
@@ -201,6 +206,7 @@ gotit:
                            (*saddr_comp)(sk, sk2)                             )
                                goto fail;
        }
+
        inet_sk(sk)->num = snum;
        sk->sk_hash = snum;
        if (sk_unhashed(sk)) {
@@ -217,7 +223,7 @@ fail:
 int udp_get_port(struct sock *sk, unsigned short snum,
                        int (*scmp)(const struct sock *, const struct sock *))
 {
-       return  __udp_lib_get_port(sk, snum, udp_hash, &udp_port_rover, scmp);
+       return  __udp_lib_get_port(sk, snum, udp_hash, scmp);
 }
 
 int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2)
index 820a477cfaa6e9899234b03ee8120c9ca158babd..6c55828e41badfb84a1ad4c4852015f3da5f3276 100644 (file)
@@ -9,7 +9,7 @@ extern int      __udp4_lib_rcv(struct sk_buff *, struct hlist_head [], int );
 extern void    __udp4_lib_err(struct sk_buff *, u32, struct hlist_head []);
 
 extern int     __udp_lib_get_port(struct sock *sk, unsigned short snum,
-                                  struct hlist_head udptable[], int *port_rover,
+                                  struct hlist_head udptable[],
                                   int (*)(const struct sock*,const struct sock*));
 extern int     ipv4_rcv_saddr_equal(const struct sock *, const struct sock *);
 
index f34fd686a8f15ac6e49b687742701c1037ab1417..94977205abb47581686bfd46bc1f5232100f005a 100644 (file)
 DEFINE_SNMP_STAT(struct udp_mib, udplite_statistics)   __read_mostly;
 
 struct hlist_head      udplite_hash[UDP_HTABLE_SIZE];
-static int             udplite_port_rover;
 
 int udplite_get_port(struct sock *sk, unsigned short p,
                     int (*c)(const struct sock *, const struct sock *))
 {
-       return  __udp_lib_get_port(sk, p, udplite_hash, &udplite_port_rover, c);
+       return  __udp_lib_get_port(sk, p, udplite_hash, c);
 }
 
 static int udplite_v4_get_port(struct sock *sk, unsigned short snum)