lists.openwall.net   lists  /  announce  owl-users  owl-dev  john-users  john-dev  passwdqc-users  yescrypt  popa3d-users  /  oss-security  kernel-hardening  musl  sabotage  tlsify  passwords  /  crypt-dev  xvendor  /  Bugtraq  Full-Disclosure  linux-kernel  linux-netdev  linux-ext4  linux-hardening  linux-cve-announce  PHC 
Open Source and information security mailing list archives
 
Hash Suite: Windows password security audit tool. GUI, reports in PDF.
[<prev] [next>] [<thread-prev] [thread-next>] [day] [month] [year] [list]
Message-ID: <CANn89iLcJMA9WL7L2NBBkz+qki8XaRSvOoG+oy53FgXA_spR2Q@mail.gmail.com>
Date:   Sat, 11 Jun 2022 11:06:54 -0700
From:   Eric Dumazet <edumazet@...gle.com>
To:     Joanne Koong <joannelkoong@...il.com>
Cc:     netdev <netdev@...r.kernel.org>, Martin KaFai Lau <kafai@...com>,
        Jakub Kicinski <kuba@...nel.org>,
        David Miller <davem@...emloft.net>,
        Paolo Abeni <pabeni@...hat.com>,
        Mat Martineau <mathew.j.martineau@...ux.intel.com>
Subject: Re: [PATCH net-next v3 2/3] net: Add bhash2 hashbucket locks

On Fri, Jun 10, 2022 at 7:17 PM Joanne Koong <joannelkoong@...il.com> wrote:
>
> Currently, the bhash2 hashbucket uses its corresponding bhash
> hashbucket's lock to serialize concurrent accesses. However, the
> bhash2 hashbucket can be accessed concurrently by multiple
> processes that hash to different bhash hashbuckets but to the
> same bhash2 hashbucket.
>
> As such, each bhash2 hashbucket will need to have its own lock
> instead of using its corresponding bhash hashbucket's lock.
>
> Fixes: d5a42de8bdbe ("net: Add a second bind table hashed by port and address")
> Signed-off-by: Joanne Koong <joannelkoong@...il.com>
> ---
>  include/net/inet_hashtables.h   |  25 +++----
>  net/dccp/proto.c                |   3 +-
>  net/ipv4/inet_connection_sock.c |  60 +++++++++-------
>  net/ipv4/inet_hashtables.c      | 119 +++++++++++++++-----------------
>  net/ipv4/tcp.c                  |   7 +-
>  5 files changed, 107 insertions(+), 107 deletions(-)
>
> diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
> index 2c331ce6ca73..c5b112f0938b 100644
> --- a/include/net/inet_hashtables.h
> +++ b/include/net/inet_hashtables.h
> @@ -124,15 +124,6 @@ struct inet_bind_hashbucket {
>         struct hlist_head       chain;
>  };
>
> -/* This is synchronized using the inet_bind_hashbucket's spinlock.
> - * Instead of having separate spinlocks, the inet_bind2_hashbucket can share
> - * the inet_bind_hashbucket's given that in every case where the bhash2 table
> - * is useful, a lookup in the bhash table also occurs.
> - */
> -struct inet_bind2_hashbucket {
> -       struct hlist_head       chain;
> -};
> -
>  /* Sockets can be hashed in established or listening table.
>   * We must use different 'nulls' end-of-chain value for all hash buckets :
>   * A socket might transition from ESTABLISH to LISTEN state without
> @@ -169,7 +160,7 @@ struct inet_hashinfo {
>          * conflicts.
>          */
>         struct kmem_cache               *bind2_bucket_cachep;
> -       struct inet_bind2_hashbucket    *bhash2;
> +       struct inet_bind_hashbucket     *bhash2;
>         unsigned int                    bhash_size;
>
>         /* The 2nd listener table hashed by local port and address */
> @@ -240,7 +231,7 @@ static inline bool check_bind_bucket_match(struct inet_bind_bucket *tb,
>
>  struct inet_bind2_bucket *
>  inet_bind2_bucket_create(struct kmem_cache *cachep, struct net *net,
> -                        struct inet_bind2_hashbucket *head,
> +                        struct inet_bind_hashbucket *head,
>                          const unsigned short port, int l3mdev,
>                          const struct sock *sk);
>
> @@ -248,12 +239,12 @@ void inet_bind2_bucket_destroy(struct kmem_cache *cachep,
>                                struct inet_bind2_bucket *tb);
>
>  struct inet_bind2_bucket *
> -inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
> +inet_bind2_bucket_find(struct inet_bind_hashbucket *head,
> +                      struct inet_hashinfo *hinfo, struct net *net,
>                        const unsigned short port, int l3mdev,
> -                      struct sock *sk,
> -                      struct inet_bind2_hashbucket **head);
> +                      struct sock *sk);
>
> -bool check_bind2_bucket_match_nulladdr(struct inet_bind2_bucket *tb,
> +bool check_bind2_bucket_match_addr_any(struct inet_bind2_bucket *tb,
>                                        struct net *net,
>                                        const unsigned short port,
>                                        int l3mdev,
> @@ -265,6 +256,10 @@ static inline u32 inet_bhashfn(const struct net *net, const __u16 lport,
>         return (lport + net_hash_mix(net)) & (bhash_size - 1);
>  }
>
> +struct inet_bind_hashbucket *
> +inet_bhashfn_portaddr(struct inet_hashinfo *hinfo, const struct sock *sk,
> +                     const struct net *net, unsigned short port);
> +
>  void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb,
>                     struct inet_bind2_bucket *tb2, const unsigned short snum);
>
> diff --git a/net/dccp/proto.c b/net/dccp/proto.c
> index 2e78458900f2..f4f2ad5f9c08 100644
> --- a/net/dccp/proto.c
> +++ b/net/dccp/proto.c
> @@ -1182,7 +1182,7 @@ static int __init dccp_init(void)
>                 goto out_free_dccp_locks;
>         }
>
> -       dccp_hashinfo.bhash2 = (struct inet_bind2_hashbucket *)
> +       dccp_hashinfo.bhash2 = (struct inet_bind_hashbucket *)
>                 __get_free_pages(GFP_ATOMIC | __GFP_NOWARN, bhash_order);
>
>         if (!dccp_hashinfo.bhash2) {
> @@ -1193,6 +1193,7 @@ static int __init dccp_init(void)
>         for (i = 0; i < dccp_hashinfo.bhash_size; i++) {
>                 spin_lock_init(&dccp_hashinfo.bhash[i].lock);
>                 INIT_HLIST_HEAD(&dccp_hashinfo.bhash[i].chain);
> +               spin_lock_init(&dccp_hashinfo.bhash2[i].lock);
>                 INIT_HLIST_HEAD(&dccp_hashinfo.bhash2[i].chain);
>         }
>
> diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c
> index c0b7e6c21360..24a42e4d8234 100644
> --- a/net/ipv4/inet_connection_sock.c
> +++ b/net/ipv4/inet_connection_sock.c
> @@ -131,14 +131,14 @@ static bool use_bhash2_on_bind(const struct sock *sk)
>         return sk->sk_rcv_saddr != htonl(INADDR_ANY);
>  }
>
> -static u32 get_bhash2_nulladdr_hash(const struct sock *sk, struct net *net,
> +static u32 get_bhash2_addr_any_hash(const struct sock *sk, struct net *net,
>                                     int port)
>  {
>  #if IS_ENABLED(CONFIG_IPV6)
> -       struct in6_addr nulladdr = {};
> +       struct in6_addr addr_any = {};
>
>         if (sk->sk_family == AF_INET6)
> -               return ipv6_portaddr_hash(net, &nulladdr, port);
> +               return ipv6_portaddr_hash(net, &addr_any, port);
>  #endif
>         return ipv4_portaddr_hash(net, 0, port);
>  }
> @@ -204,18 +204,18 @@ static bool check_bhash2_conflict(const struct sock *sk,
>         return false;
>  }
>
> -/* This should be called only when the corresponding inet_bind_bucket spinlock
> - * is held
> - */
> +/* This should be called only when the tb and tb2 hashbuckets' locks are held */
>  static int inet_csk_bind_conflict(const struct sock *sk, int port,
>                                   struct inet_bind_bucket *tb,
>                                   struct inet_bind2_bucket *tb2, /* may be null */
> +                                 struct inet_bind_hashbucket *head_tb2,
>                                   bool relax, bool reuseport_ok)
>  {
>         struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
>         kuid_t uid = sock_i_uid((struct sock *)sk);
>         struct sock_reuseport *reuseport_cb;
> -       struct inet_bind2_hashbucket *head2;
> +       struct inet_bind_hashbucket *head_addr_any;
> +       bool addr_any_conflict = false;
>         bool reuseport_cb_ok;
>         struct sock *sk2;
>         struct net *net;
> @@ -254,33 +254,39 @@ static int inet_csk_bind_conflict(const struct sock *sk, int port,
>         /* check there's no conflict with an existing IPV6_ADDR_ANY (if ipv6) or
>          * INADDR_ANY (if ipv4) socket.
>          */
> -       hash = get_bhash2_nulladdr_hash(sk, net, port);
> -       head2 = &hinfo->bhash2[hash & (hinfo->bhash_size - 1)];
> +       hash = get_bhash2_addr_any_hash(sk, net, port);
> +       head_addr_any = &hinfo->bhash2[hash & (hinfo->bhash_size - 1)];
>
>         l3mdev = inet_sk_bound_l3mdev(sk);
> -       inet_bind_bucket_for_each(tb2, &head2->chain)
> -               if (check_bind2_bucket_match_nulladdr(tb2, net, port, l3mdev, sk))
> +
> +       if (head_addr_any != head_tb2)
> +               spin_lock_bh(&head_addr_any->lock);

I do not think you need the _bh variant here. See my next comment for the rationale.

> +
> +       inet_bind_bucket_for_each(tb2, &head_addr_any->chain)
> +               if (check_bind2_bucket_match_addr_any(tb2, net, port, l3mdev, sk))
>                         break;
>
>         if (tb2 && check_bhash2_conflict(sk, tb2, uid, relax, reuseport_cb_ok,
>                                          reuseport_ok))
> -               return true;
> +               addr_any_conflict = true;
>
> -       return false;
> +       if (head_addr_any != head_tb2)
> +               spin_unlock_bh(&head_addr_any->lock);
> +
> +       return addr_any_conflict;
>  }
>
>  /*
>   * Find an open port number for the socket.  Returns with the
> - * inet_bind_hashbucket lock held.
> + * inet_bind_hashbucket locks held if successful.
>   */
>  static struct inet_bind_hashbucket *
>  inet_csk_find_open_port(struct sock *sk, struct inet_bind_bucket **tb_ret,
>                         struct inet_bind2_bucket **tb2_ret,
> -                       struct inet_bind2_hashbucket **head2_ret, int *port_ret)
> +                       struct inet_bind_hashbucket **head2_ret, int *port_ret)
>  {
>         struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
> -       struct inet_bind2_hashbucket *head2;
> -       struct inet_bind_hashbucket *head;
> +       struct inet_bind_hashbucket *head, *head2;
>         struct net *net = sock_net(sk);
>         int i, low, high, attempt_half;
>         struct inet_bind2_bucket *tb2;
> @@ -325,19 +331,22 @@ inet_csk_find_open_port(struct sock *sk, struct inet_bind_bucket **tb_ret,
>                         continue;
>                 head = &hinfo->bhash[inet_bhashfn(net, port,
>                                                   hinfo->bhash_size)];
> +               head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
>                 spin_lock_bh(&head->lock);
> -               tb2 = inet_bind2_bucket_find(hinfo, net, port, l3mdev, sk,
> -                                            &head2);
> +               spin_lock_bh(&head2->lock);

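Note: there is no need to disable BH twice. Once the outer spin_lock_bh() has disabled BH, the inner lock can use plain spin_lock()/spin_unlock().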

  spin_lock_bh(&lock1);
  spin_lock_bh(&lock2);
  spin_unlock_bh(&lock2);
  spin_unlock_bh(&lock1);

Can instead be:

  spin_lock_bh(&lock1);
  spin_lock(&lock2);
  spin_unlock(&lock2);
  spin_unlock_bh(&lock1);

> +
> +               tb2 = inet_bind2_bucket_find(head2, hinfo, net, port, l3mdev, sk);
>                 inet_bind_bucket_for_each(tb, &head->chain)
>                         if (check_bind_bucket_match(tb, net, port, l3mdev)) {
>                                 if (!inet_csk_bind_conflict(sk, port, tb, tb2,
> -                                                           relax, false))
> +                                                           head2, relax, false))
>                                         goto success;
>                                 goto next_port;
>                         }
>                 tb = NULL;
>                 goto success;
>  next_port:
> +               spin_unlock_bh(&head2->lock);
>                 spin_unlock_bh(&head->lock);
>                 cond_resched();
>         }
> @@ -459,10 +468,9 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum)
>         bool reuse = sk->sk_reuse && sk->sk_state != TCP_LISTEN;
>         struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
>         bool bhash_created = false, bhash2_created = false;
> +       struct inet_bind_hashbucket *head, *head2;
>         struct inet_bind2_bucket *tb2 = NULL;
> -       struct inet_bind2_hashbucket *head2;
>         struct inet_bind_bucket *tb = NULL;
> -       struct inet_bind_hashbucket *head;
>         struct net *net = sock_net(sk);
>         int ret = 1, port = snum;
>         bool found_port = false;
> @@ -480,13 +488,14 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum)
>         } else {
>                 head = &hinfo->bhash[inet_bhashfn(net, port,
>                                                   hinfo->bhash_size)];
> +               head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
>                 spin_lock_bh(&head->lock);
>                 inet_bind_bucket_for_each(tb, &head->chain)
>                         if (check_bind_bucket_match(tb, net, port, l3mdev))
>                                 break;
>
> -               tb2 = inet_bind2_bucket_find(hinfo, net, port, l3mdev, sk,
> -                                            &head2);
> +               spin_lock_bh(&head2->lock);
> +               tb2 = inet_bind2_bucket_find(head2, hinfo, net, port, l3mdev, sk);
>         }
>
>         if (!tb) {
> @@ -513,7 +522,7 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum)
>                 if ((tb->fastreuse > 0 && reuse) ||
>                     sk_reuseport_match(tb, sk))
>                         goto success;
> -               if (inet_csk_bind_conflict(sk, port, tb, tb2, true, true))
> +               if (inet_csk_bind_conflict(sk, port, tb, tb2, head2, true, true))
>                         goto fail_unlock;
>         }
>  success:
> @@ -533,6 +542,7 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum)
>                         inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep,
>                                                   tb2);
>         }
> +       spin_unlock_bh(&head2->lock);
>         spin_unlock_bh(&head->lock);
>         return ret;
>  }
> diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
> index 73f18134b2d5..8fe8010c1a00 100644
> --- a/net/ipv4/inet_hashtables.c
> +++ b/net/ipv4/inet_hashtables.c
> @@ -83,7 +83,7 @@ struct inet_bind_bucket *inet_bind_bucket_create(struct kmem_cache *cachep,
>
>  struct inet_bind2_bucket *inet_bind2_bucket_create(struct kmem_cache *cachep,
>                                                    struct net *net,
> -                                                  struct inet_bind2_hashbucket *head,
> +                                                  struct inet_bind_hashbucket *head,
>                                                    const unsigned short port,
>                                                    int l3mdev,
>                                                    const struct sock *sk)
> @@ -127,9 +127,7 @@ void inet_bind_bucket_destroy(struct kmem_cache *cachep, struct inet_bind_bucket
>         }
>  }
>
> -/* Caller must hold the lock for the corresponding hashbucket in the bhash table
> - * with local BH disabled
> - */
> +/* Caller must hold hashbucket lock for this tb with local BH disabled */
>  void inet_bind2_bucket_destroy(struct kmem_cache *cachep, struct inet_bind2_bucket *tb)
>  {
>         if (hlist_empty(&tb->owners)) {
> @@ -157,6 +155,9 @@ static void __inet_put_port(struct sock *sk)
>         const int bhash = inet_bhashfn(sock_net(sk), inet_sk(sk)->inet_num,
>                         hashinfo->bhash_size);
>         struct inet_bind_hashbucket *head = &hashinfo->bhash[bhash];
> +       struct inet_bind_hashbucket *head2 =
> +               inet_bhashfn_portaddr(hashinfo, sk, sock_net(sk),
> +                                     inet_sk(sk)->inet_num);
>         struct inet_bind2_bucket *tb2;
>         struct inet_bind_bucket *tb;
>
> @@ -167,12 +168,15 @@ static void __inet_put_port(struct sock *sk)
>         inet_sk(sk)->inet_num = 0;
>         inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb);
>
> +       spin_lock(&head2->lock);
>         if (inet_csk(sk)->icsk_bind2_hash) {
>                 tb2 = inet_csk(sk)->icsk_bind2_hash;
>                 __sk_del_bind2_node(sk);
>                 inet_csk(sk)->icsk_bind2_hash = NULL;
>                 inet_bind2_bucket_destroy(hashinfo->bind2_bucket_cachep, tb2);
>         }
> +       spin_unlock(&head2->lock);
> +
>         spin_unlock(&head->lock);
>  }
>
> @@ -191,7 +195,9 @@ int __inet_inherit_port(const struct sock *sk, struct sock *child)
>         const int bhash = inet_bhashfn(sock_net(sk), port,
>                                        table->bhash_size);
>         struct inet_bind_hashbucket *head = &table->bhash[bhash];
> -       struct inet_bind2_hashbucket *head_bhash2;
> +       struct inet_bind_hashbucket *head2 =
> +               inet_bhashfn_portaddr(table, child, sock_net(sk),
> +                                     port);
>         bool created_inet_bind_bucket = false;
>         struct net *net = sock_net(sk);
>         struct inet_bind2_bucket *tb2;
> @@ -199,9 +205,11 @@ int __inet_inherit_port(const struct sock *sk, struct sock *child)
>         int l3mdev;
>
>         spin_lock(&head->lock);
> +       spin_lock(&head2->lock);
>         tb = inet_csk(sk)->icsk_bind_hash;
>         tb2 = inet_csk(sk)->icsk_bind2_hash;
>         if (unlikely(!tb || !tb2)) {
> +               spin_unlock(&head2->lock);
>                 spin_unlock(&head->lock);
>                 return -ENOENT;
>         }
> @@ -221,6 +229,7 @@ int __inet_inherit_port(const struct sock *sk, struct sock *child)
>                         tb = inet_bind_bucket_create(table->bind_bucket_cachep,
>                                                      net, head, port, l3mdev);
>                         if (!tb) {
> +                               spin_unlock(&head2->lock);
>                                 spin_unlock(&head->lock);
>                                 return -ENOMEM;
>                         }
> @@ -233,17 +242,17 @@ int __inet_inherit_port(const struct sock *sk, struct sock *child)
>                 l3mdev = inet_sk_bound_l3mdev(sk);
>
>  bhash2_find:
> -               tb2 = inet_bind2_bucket_find(table, net, port, l3mdev, child,
> -                                            &head_bhash2);
> +               tb2 = inet_bind2_bucket_find(head2, table, net, port, l3mdev, child);
>                 if (!tb2) {
>                         tb2 = inet_bind2_bucket_create(table->bind2_bucket_cachep,
> -                                                      net, head_bhash2, port,
> +                                                      net, head2, port,
>                                                        l3mdev, child);
>                         if (!tb2)
>                                 goto error;
>                 }
>         }
>         inet_bind_hash(child, tb, tb2, port);
> +       spin_unlock(&head2->lock);
>         spin_unlock(&head->lock);
>
>         return 0;
> @@ -251,6 +260,7 @@ int __inet_inherit_port(const struct sock *sk, struct sock *child)
>  error:
>         if (created_inet_bind_bucket)
>                 inet_bind_bucket_destroy(table->bind_bucket_cachep, tb);
> +       spin_unlock(&head2->lock);
>         spin_unlock(&head->lock);
>         return -ENOMEM;
>  }
> @@ -771,24 +781,24 @@ static bool check_bind2_bucket_match(struct inet_bind2_bucket *tb,
>                         tb->l3mdev == l3mdev && tb->rcv_saddr == sk->sk_rcv_saddr;
>  }
>
> -bool check_bind2_bucket_match_nulladdr(struct inet_bind2_bucket *tb,
> +bool check_bind2_bucket_match_addr_any(struct inet_bind2_bucket *tb,
>                                        struct net *net, const unsigned short port,
>                                        int l3mdev, const struct sock *sk)
>  {
>  #if IS_ENABLED(CONFIG_IPV6)
> -       struct in6_addr nulladdr = {};
> +       struct in6_addr addr_any = {};
>
>         if (sk->sk_family == AF_INET6)
>                 return net_eq(ib2_net(tb), net) && tb->port == port &&
>                         tb->l3mdev == l3mdev &&
> -                       ipv6_addr_equal(&tb->v6_rcv_saddr, &nulladdr);
> +                       ipv6_addr_equal(&tb->v6_rcv_saddr, &addr_any);
>         else
>  #endif
>                 return net_eq(ib2_net(tb), net) && tb->port == port &&
>                         tb->l3mdev == l3mdev && tb->rcv_saddr == 0;
>  }
>
> -static struct inet_bind2_hashbucket *
> +struct inet_bind_hashbucket *
>  inet_bhashfn_portaddr(struct inet_hashinfo *hinfo, const struct sock *sk,
>                       const struct net *net, unsigned short port)
>  {
> @@ -803,55 +813,21 @@ inet_bhashfn_portaddr(struct inet_hashinfo *hinfo, const struct sock *sk,
>         return &hinfo->bhash2[hash & (hinfo->bhash_size - 1)];
>  }
>
> -/* This should only be called when the spinlock for the socket's corresponding
> - * bind_hashbucket is held
> - */
> +/* The socket's bhash2 hashbucket spinlock must be held when this is called */
>  struct inet_bind2_bucket *
> -inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
> -                      const unsigned short port, int l3mdev, struct sock *sk,
> -                      struct inet_bind2_hashbucket **head)
> +inet_bind2_bucket_find(struct inet_bind_hashbucket *head,
> +                      struct inet_hashinfo *hinfo, struct net *net,
> +                      const unsigned short port, int l3mdev, struct sock *sk)
>  {
>         struct inet_bind2_bucket *bhash2 = NULL;
> -       struct inet_bind2_hashbucket *h;
>
> -       h = inet_bhashfn_portaddr(hinfo, sk, net, port);
> -       inet_bind_bucket_for_each(bhash2, &h->chain) {
> +       inet_bind_bucket_for_each(bhash2, &head->chain)
>                 if (check_bind2_bucket_match(bhash2, net, port, l3mdev, sk))
>                         break;
> -       }
> -
> -       if (head)
> -               *head = h;
>
>         return bhash2;
>  }
>
> -/* the lock for the socket's corresponding bhash entry must be held */
> -static int __inet_bhash2_update_saddr(struct sock *sk,
> -                                     struct inet_hashinfo *hinfo,
> -                                     struct net *net, int port, int l3mdev)
> -{
> -       struct inet_bind2_hashbucket *head2;
> -       struct inet_bind2_bucket *tb2;
> -
> -       tb2 = inet_bind2_bucket_find(hinfo, net, port, l3mdev, sk,
> -                                    &head2);
> -       if (!tb2) {
> -               tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep,
> -                                              net, head2, port, l3mdev, sk);
> -               if (!tb2)
> -                       return -ENOMEM;
> -       }
> -
> -       /* Remove the socket's old entry from bhash2 */
> -       __sk_del_bind2_node(sk);
> -
> -       sk_add_bind2_node(sk, &tb2->owners);
> -       inet_csk(sk)->icsk_bind2_hash = tb2;
> -
> -       return 0;
> -}
> -
>  /* This should be called if/when a socket's rcv saddr changes after it has
>   * been binded.
>   */
> @@ -862,17 +838,31 @@ int inet_bhash2_update_saddr(struct sock *sk)
>         struct inet_bind_hashbucket *head;
>         int port = inet_sk(sk)->inet_num;
>         struct net *net = sock_net(sk);
> -       int err;
> +       struct inet_bind2_bucket *tb2;
>
> -       head = &hinfo->bhash[inet_bhashfn(net, port, hinfo->bhash_size)];
> +       head = inet_bhashfn_portaddr(hinfo, sk, net, port);
>
>         spin_lock_bh(&head->lock);
>
> -       err = __inet_bhash2_update_saddr(sk, hinfo, net, port, l3mdev);
> +       tb2 = inet_bind2_bucket_find(head, hinfo, net, port, l3mdev, sk);
> +       if (!tb2) {
> +               tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep,
> +                                              net, head, port, l3mdev, sk);
> +               if (!tb2) {
> +                       spin_unlock_bh(&head->lock);
> +                       return -ENOMEM;
> +               }
> +       }
> +
> +       /* Remove the socket's old entry from bhash2 */
> +       __sk_del_bind2_node(sk);
> +
> +       sk_add_bind2_node(sk, &tb2->owners);
> +       inet_csk(sk)->icsk_bind2_hash = tb2;
>
>         spin_unlock_bh(&head->lock);
>
> -       return err;
> +       return 0;
>  }
>
>  /* RFC 6056 3.3.4.  Algorithm 4: Double-Hash Port Selection Algorithm
> @@ -894,9 +884,8 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
>                         struct sock *, __u16, struct inet_timewait_sock **))
>  {
>         struct inet_hashinfo *hinfo = death_row->hashinfo;
> +       struct inet_bind_hashbucket *head, *head2;
>         struct inet_timewait_sock *tw = NULL;
> -       struct inet_bind2_hashbucket *head2;
> -       struct inet_bind_hashbucket *head;
>         int port = inet_sk(sk)->inet_num;
>         struct net *net = sock_net(sk);
>         struct inet_bind2_bucket *tb2;
> @@ -907,8 +896,6 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
>         int l3mdev;
>         u32 index;
>
> -       l3mdev = inet_sk_bound_l3mdev(sk);
> -
>         if (port) {
>                 head = &hinfo->bhash[inet_bhashfn(net, port,
>                                                   hinfo->bhash_size)];
> @@ -917,8 +904,7 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
>                 spin_lock_bh(&head->lock);
>
>                 if (prev_inaddr_any) {
> -                       ret = __inet_bhash2_update_saddr(sk, hinfo, net, port,
> -                                                        l3mdev);
> +                       ret = inet_bhash2_update_saddr(sk);
>                         if (ret) {
>                                 spin_unlock_bh(&head->lock);
>                                 return ret;
> @@ -937,6 +923,8 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
>                 return ret;
>         }
>
> +       l3mdev = inet_sk_bound_l3mdev(sk);
> +
>         inet_get_local_port_range(net, &low, &high);
>         high++; /* [32768, 60999] -> [32768, 61000[ */
>         remaining = high - low;
> @@ -1006,7 +994,10 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
>         /* Find the corresponding tb2 bucket since we need to
>          * add the socket to the bhash2 table as well
>          */
> -       tb2 = inet_bind2_bucket_find(hinfo, net, port, l3mdev, sk, &head2);
> +       head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
> +       spin_lock(&head2->lock);
> +
> +       tb2 = inet_bind2_bucket_find(head2, hinfo, net, port, l3mdev, sk);
>         if (!tb2) {
>                 tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep, net,
>                                                head2, port, l3mdev, sk);
> @@ -1024,6 +1015,9 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
>
>         /* Head lock still held and bh's disabled */
>         inet_bind_hash(sk, tb, tb2, port);
> +
> +       spin_unlock(&head2->lock);
> +
>         if (sk_unhashed(sk)) {
>                 inet_sk(sk)->inet_sport = htons(port);
>                 inet_ehash_nolisten(sk, (struct sock *)tw, NULL);
> @@ -1037,6 +1031,7 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
>         return 0;
>
>  error:
> +       spin_unlock_bh(&head2->lock);

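Are you sure _bh was used on the spin_lock() side? head2->lock is taken with plain spin_lock(&head2->lock) above, so the matching unlock here should be spin_unlock(), not spin_unlock_bh().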

>         if (tb_created)
>                 inet_bind_bucket_destroy(hinfo->bind_bucket_cachep, tb);
>         spin_unlock_bh(&head->lock);
> diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
> index 9984d23a7f3e..1ebba8c27642 100644
> --- a/net/ipv4/tcp.c
> +++ b/net/ipv4/tcp.c
> @@ -4633,8 +4633,7 @@ void __init tcp_init(void)
>                 panic("TCP: failed to alloc ehash_locks");
>         tcp_hashinfo.bhash =
>                 alloc_large_system_hash("TCP bind bhash tables",
> -                                       sizeof(struct inet_bind_hashbucket) +
> -                                       sizeof(struct inet_bind2_hashbucket),
> +                                       2 * sizeof(struct inet_bind_hashbucket),
>                                         tcp_hashinfo.ehash_mask + 1,
>                                         17, /* one slot per 128 KB of memory */
>                                         0,
> @@ -4643,11 +4642,11 @@ void __init tcp_init(void)
>                                         0,
>                                         64 * 1024);
>         tcp_hashinfo.bhash_size = 1U << tcp_hashinfo.bhash_size;
> -       tcp_hashinfo.bhash2 =
> -               (struct inet_bind2_hashbucket *)(tcp_hashinfo.bhash + tcp_hashinfo.bhash_size);
> +       tcp_hashinfo.bhash2 = tcp_hashinfo.bhash + tcp_hashinfo.bhash_size;
>         for (i = 0; i < tcp_hashinfo.bhash_size; i++) {
>                 spin_lock_init(&tcp_hashinfo.bhash[i].lock);
>                 INIT_HLIST_HEAD(&tcp_hashinfo.bhash[i].chain);
> +               spin_lock_init(&tcp_hashinfo.bhash2[i].lock);
>                 INIT_HLIST_HEAD(&tcp_hashinfo.bhash2[i].chain);
>         }
>
> --
> 2.30.2
>

Powered by blists - more mailing lists

Powered by Openwall GNU/*/Linux Powered by OpenVZ