lists.openwall.net   lists  /  announce  owl-users  owl-dev  john-users  john-dev  passwdqc-users  yescrypt  popa3d-users  /  oss-security  kernel-hardening  musl  sabotage  tlsify  passwords  /  crypt-dev  xvendor  /  Bugtraq  Full-Disclosure  linux-kernel  linux-netdev  linux-ext4  linux-hardening  linux-cve-announce  PHC 
Open Source and information security mailing list archives
 
Hash Suite: Windows password security audit tool. GUI, reports in PDF.
[<prev] [next>] [<thread-prev] [thread-next>] [day] [month] [year] [list]
Message-id: <20180201000716.69301-12-cpaasch@apple.com>
Date:   Wed, 31 Jan 2018 16:07:13 -0800
From:   Christoph Paasch <cpaasch@...le.com>
To:     netdev@...r.kernel.org
Cc:     Eric Dumazet <edumazet@...gle.com>,
        Mat Martineau <mathew.j.martineau@...ux.intel.com>,
        Ivan Delalande <colona@...sta.com>
Subject: [RFC v2 11/14] tcp_md5: Move TCP-MD5 code out of TCP itself

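This moves the TCP-MD5 code out of the core TCP files, mostly by
copy-pasting it into functions placed in net/ipv4/tcp_md5.c. The MD5
structures and prototypes move to the new header include/linux/tcp_md5.h,
and the reply paths (tcp_v4_send_reset() and tcp_v4_send_ack()) now call
tcp_v4_md5_send_response_prepare()/_write() instead of open-coding the
MD5 option.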

Cc: Ivan Delalande <colona@...sta.com>
Signed-off-by: Christoph Paasch <cpaasch@...le.com>
Reviewed-by: Mat Martineau <mathew.j.martineau@...ux.intel.com>
---

Notes:
    v2: * Add SPDX identifier (Mat Martineau's feedback)

 include/linux/inet_diag.h |    1 +
 include/linux/tcp_md5.h   |  137 ++++++
 include/net/tcp.h         |   77 ----
 net/ipv4/Makefile         |    1 +
 net/ipv4/tcp.c            |  133 +-----
 net/ipv4/tcp_diag.c       |   81 +---
 net/ipv4/tcp_input.c      |   38 --
 net/ipv4/tcp_ipv4.c       |  520 ++-------------------
 net/ipv4/tcp_md5.c        | 1103 +++++++++++++++++++++++++++++++++++++++++++++
 net/ipv4/tcp_minisocks.c  |   23 +-
 net/ipv4/tcp_output.c     |    4 +-
 net/ipv6/tcp_ipv6.c       |  318 +------------
 12 files changed, 1304 insertions(+), 1132 deletions(-)
 create mode 100644 include/linux/tcp_md5.h
 create mode 100644 net/ipv4/tcp_md5.c

diff --git a/include/linux/inet_diag.h b/include/linux/inet_diag.h
index 39faaaf843e1..1ef6727e41c9 100644
--- a/include/linux/inet_diag.h
+++ b/include/linux/inet_diag.h
@@ -2,6 +2,7 @@
 #ifndef _INET_DIAG_H_
 #define _INET_DIAG_H_ 1
 
+#include <linux/user_namespace.h>
 #include <uapi/linux/inet_diag.h>
 
 struct net;
diff --git a/include/linux/tcp_md5.h b/include/linux/tcp_md5.h
new file mode 100644
index 000000000000..d4a2175030d0
--- /dev/null
+++ b/include/linux/tcp_md5.h
@@ -0,0 +1,137 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+#ifndef _LINUX_TCP_MD5_H
+#define _LINUX_TCP_MD5_H
+
+#include <linux/skbuff.h>
+
+#ifdef CONFIG_TCP_MD5SIG
+#include <linux/types.h>
+
+#include <net/tcp.h>
+
+union tcp_md5_addr {
+	struct in_addr  a4;
+#if IS_ENABLED(CONFIG_IPV6)
+	struct in6_addr	a6;
+#endif
+};
+
+/* - key database */
+struct tcp_md5sig_key {
+	struct hlist_node	node;
+	u8			keylen;
+	u8			family; /* AF_INET or AF_INET6 */
+	union tcp_md5_addr	addr;
+	u8			prefixlen;
+	u8			key[TCP_MD5SIG_MAXKEYLEN];
+	struct rcu_head		rcu;
+};
+
+/* - sock block */
+struct tcp_md5sig_info {
+	struct hlist_head	head;
+	struct rcu_head		rcu;
+};
+
+union tcp_md5sum_block {
+	struct tcp4_pseudohdr ip4;
+#if IS_ENABLED(CONFIG_IPV6)
+	struct tcp6_pseudohdr ip6;
+#endif
+};
+
+/* - pool: digest algorithm, hash description and scratch buffer */
+struct tcp_md5sig_pool {
+	struct ahash_request	*md5_req;
+	void			*scratch;
+};
+
+extern const struct tcp_sock_af_ops tcp_sock_ipv4_specific;
+extern const struct tcp_sock_af_ops tcp_sock_ipv6_specific;
+extern const struct tcp_sock_af_ops tcp_sock_ipv6_mapped_specific;
+
+/* - functions */
+int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key,
+			const struct sock *sk, const struct sk_buff *skb);
+
+struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk,
+					 const struct sock *addr_sk);
+
+void tcp_v4_md5_destroy_sock(struct sock *sk);
+
+int tcp_v4_md5_send_response_prepare(struct sk_buff *skb, u8 flags,
+				     unsigned int remaining,
+				     struct tcp_out_options *opts,
+				     const struct sock *sk);
+
+void tcp_v4_md5_send_response_write(__be32 *topt, struct sk_buff *skb,
+				    struct tcphdr *t1,
+				    struct tcp_out_options *opts,
+				    const struct sock *sk);
+
+int tcp_v6_md5_send_response_prepare(struct sk_buff *skb, u8 flags,
+				     unsigned int remaining,
+				     struct tcp_out_options *opts,
+				     const struct sock *sk);
+
+void tcp_v6_md5_send_response_write(__be32 *topt, struct sk_buff *skb,
+				    struct tcphdr *t1,
+				    struct tcp_out_options *opts,
+				    const struct sock *sk);
+
+bool tcp_v4_inbound_md5_hash(const struct sock *sk,
+			     const struct sk_buff *skb);
+
+void tcp_v4_md5_syn_recv_sock(const struct sock *listener, struct sock *sk);
+
+void tcp_v6_md5_syn_recv_sock(const struct sock *listener, struct sock *sk);
+
+void tcp_md5_time_wait(struct sock *sk, struct inet_timewait_sock *tw);
+
+struct tcp_md5sig_key *tcp_v6_md5_lookup(const struct sock *sk,
+					 const struct sock *addr_sk);
+
+int tcp_v6_md5_hash_skb(char *md5_hash,
+			const struct tcp_md5sig_key *key,
+			const struct sock *sk,
+			const struct sk_buff *skb);
+
+bool tcp_v6_inbound_md5_hash(const struct sock *sk,
+			     const struct sk_buff *skb);
+
+static inline void tcp_md5_twsk_destructor(struct tcp_timewait_sock *twsk)
+{
+	if (twsk->tw_md5_key)
+		kfree_rcu(twsk->tw_md5_key, rcu);
+}
+
+static inline void tcp_md5_add_header_len(const struct sock *listener,
+					  struct sock *sk)
+{
+	struct tcp_sock *tp = tcp_sk(sk);
+
+	if (tp->af_specific->md5_lookup(listener, sk))
+		tp->tcp_header_len += TCPOLEN_MD5SIG_ALIGNED;
+}
+
+int tcp_md5_diag_get_aux(struct sock *sk, bool net_admin, struct sk_buff *skb);
+
+int tcp_md5_diag_get_aux_size(struct sock *sk, bool net_admin);
+
+#else
+
+static inline bool tcp_v4_inbound_md5_hash(const struct sock *sk,
+					   const struct sk_buff *skb)
+{
+	return false;
+}
+
+static inline bool tcp_v6_inbound_md5_hash(const struct sock *sk,
+					   const struct sk_buff *skb)
+{
+	return false;
+}
+
+#endif
+
+#endif /* _LINUX_TCP_MD5_H */
diff --git a/include/net/tcp.h b/include/net/tcp.h
index 2a565883e2ef..d2738cb01cf2 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -406,7 +406,6 @@ void tcp_parse_options(const struct net *net, const struct sk_buff *skb,
 		       struct tcp_options_received *opt_rx,
 		       int estab, struct tcp_fastopen_cookie *foc,
 		       struct sock *sk);
-const u8 *tcp_parse_md5sig_option(const struct tcphdr *th);
 
 /*
  *	TCP v4 functions exported for the inet6 API
@@ -1416,30 +1415,6 @@ static inline void tcp_clear_all_retrans_hints(struct tcp_sock *tp)
 	tp->retransmit_skb_hint = NULL;
 }
 
-union tcp_md5_addr {
-	struct in_addr  a4;
-#if IS_ENABLED(CONFIG_IPV6)
-	struct in6_addr	a6;
-#endif
-};
-
-/* - key database */
-struct tcp_md5sig_key {
-	struct hlist_node	node;
-	u8			keylen;
-	u8			family; /* AF_INET or AF_INET6 */
-	union tcp_md5_addr	addr;
-	u8			prefixlen;
-	u8			key[TCP_MD5SIG_MAXKEYLEN];
-	struct rcu_head		rcu;
-};
-
-/* - sock block */
-struct tcp_md5sig_info {
-	struct hlist_head	head;
-	struct rcu_head		rcu;
-};
-
 /* - pseudo header */
 struct tcp4_pseudohdr {
 	__be32		saddr;
@@ -1456,58 +1431,6 @@ struct tcp6_pseudohdr {
 	__be32		protocol;	/* including padding */
 };
 
-union tcp_md5sum_block {
-	struct tcp4_pseudohdr ip4;
-#if IS_ENABLED(CONFIG_IPV6)
-	struct tcp6_pseudohdr ip6;
-#endif
-};
-
-/* - pool: digest algorithm, hash description and scratch buffer */
-struct tcp_md5sig_pool {
-	struct ahash_request	*md5_req;
-	void			*scratch;
-};
-
-/* - functions */
-int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key,
-			const struct sock *sk, const struct sk_buff *skb);
-int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
-		   int family, u8 prefixlen, const u8 *newkey, u8 newkeylen,
-		   gfp_t gfp);
-int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr,
-		   int family, u8 prefixlen);
-struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk,
-					 const struct sock *addr_sk);
-
-#ifdef CONFIG_TCP_MD5SIG
-struct tcp_md5sig_key *tcp_md5_do_lookup(const struct sock *sk,
-					 const union tcp_md5_addr *addr,
-					 int family);
-#define tcp_twsk_md5_key(twsk)	((twsk)->tw_md5_key)
-#else
-static inline struct tcp_md5sig_key *tcp_md5_do_lookup(const struct sock *sk,
-					 const union tcp_md5_addr *addr,
-					 int family)
-{
-	return NULL;
-}
-#define tcp_twsk_md5_key(twsk)	NULL
-#endif
-
-bool tcp_alloc_md5sig_pool(void);
-
-struct tcp_md5sig_pool *tcp_get_md5sig_pool(void);
-static inline void tcp_put_md5sig_pool(void)
-{
-	local_bh_enable();
-}
-
-int tcp_md5_hash_skb_data(struct tcp_md5sig_pool *, const struct sk_buff *,
-			  unsigned int header_len);
-int tcp_md5_hash_key(struct tcp_md5sig_pool *hp,
-		     const struct tcp_md5sig_key *key);
-
 /* From tcp_fastopen.c */
 void tcp_fastopen_cache_get(struct sock *sk, u16 *mss,
 			    struct tcp_fastopen_cookie *cookie);
diff --git a/net/ipv4/Makefile b/net/ipv4/Makefile
index 47a0a6649a9d..dd6bd3b29f5c 100644
--- a/net/ipv4/Makefile
+++ b/net/ipv4/Makefile
@@ -60,6 +60,7 @@ obj-$(CONFIG_TCP_CONG_LP) += tcp_lp.o
 obj-$(CONFIG_TCP_CONG_YEAH) += tcp_yeah.o
 obj-$(CONFIG_TCP_CONG_ILLINOIS) += tcp_illinois.o
 obj-$(CONFIG_NETLABEL) += cipso_ipv4.o
+obj-$(CONFIG_TCP_MD5SIG) += tcp_md5.o
 
 obj-$(CONFIG_XFRM) += xfrm4_policy.o xfrm4_state.o xfrm4_input.o \
 		      xfrm4_output.o xfrm4_protocol.o
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index f08542d91e1c..fc5c9cb19b9b 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -271,6 +271,7 @@
 #include <linux/slab.h>
 #include <linux/errqueue.h>
 #include <linux/static_key.h>
+#include <linux/tcp_md5.h>
 
 #include <net/icmp.h>
 #include <net/inet_common.h>
@@ -3370,138 +3371,6 @@ int compat_tcp_getsockopt(struct sock *sk, int level, int optname,
 EXPORT_SYMBOL(compat_tcp_getsockopt);
 #endif
 
-#ifdef CONFIG_TCP_MD5SIG
-static DEFINE_PER_CPU(struct tcp_md5sig_pool, tcp_md5sig_pool);
-static DEFINE_MUTEX(tcp_md5sig_mutex);
-static bool tcp_md5sig_pool_populated = false;
-
-static void __tcp_alloc_md5sig_pool(void)
-{
-	struct crypto_ahash *hash;
-	int cpu;
-
-	hash = crypto_alloc_ahash("md5", 0, CRYPTO_ALG_ASYNC);
-	if (IS_ERR(hash))
-		return;
-
-	for_each_possible_cpu(cpu) {
-		void *scratch = per_cpu(tcp_md5sig_pool, cpu).scratch;
-		struct ahash_request *req;
-
-		if (!scratch) {
-			scratch = kmalloc_node(sizeof(union tcp_md5sum_block) +
-					       sizeof(struct tcphdr),
-					       GFP_KERNEL,
-					       cpu_to_node(cpu));
-			if (!scratch)
-				return;
-			per_cpu(tcp_md5sig_pool, cpu).scratch = scratch;
-		}
-		if (per_cpu(tcp_md5sig_pool, cpu).md5_req)
-			continue;
-
-		req = ahash_request_alloc(hash, GFP_KERNEL);
-		if (!req)
-			return;
-
-		ahash_request_set_callback(req, 0, NULL, NULL);
-
-		per_cpu(tcp_md5sig_pool, cpu).md5_req = req;
-	}
-	/* before setting tcp_md5sig_pool_populated, we must commit all writes
-	 * to memory. See smp_rmb() in tcp_get_md5sig_pool()
-	 */
-	smp_wmb();
-	tcp_md5sig_pool_populated = true;
-}
-
-bool tcp_alloc_md5sig_pool(void)
-{
-	if (unlikely(!tcp_md5sig_pool_populated)) {
-		mutex_lock(&tcp_md5sig_mutex);
-
-		if (!tcp_md5sig_pool_populated)
-			__tcp_alloc_md5sig_pool();
-
-		mutex_unlock(&tcp_md5sig_mutex);
-	}
-	return tcp_md5sig_pool_populated;
-}
-EXPORT_SYMBOL(tcp_alloc_md5sig_pool);
-
-
-/**
- *	tcp_get_md5sig_pool - get md5sig_pool for this user
- *
- *	We use percpu structure, so if we succeed, we exit with preemption
- *	and BH disabled, to make sure another thread or softirq handling
- *	wont try to get same context.
- */
-struct tcp_md5sig_pool *tcp_get_md5sig_pool(void)
-{
-	local_bh_disable();
-
-	if (tcp_md5sig_pool_populated) {
-		/* coupled with smp_wmb() in __tcp_alloc_md5sig_pool() */
-		smp_rmb();
-		return this_cpu_ptr(&tcp_md5sig_pool);
-	}
-	local_bh_enable();
-	return NULL;
-}
-EXPORT_SYMBOL(tcp_get_md5sig_pool);
-
-int tcp_md5_hash_skb_data(struct tcp_md5sig_pool *hp,
-			  const struct sk_buff *skb, unsigned int header_len)
-{
-	struct scatterlist sg;
-	const struct tcphdr *tp = tcp_hdr(skb);
-	struct ahash_request *req = hp->md5_req;
-	unsigned int i;
-	const unsigned int head_data_len = skb_headlen(skb) > header_len ?
-					   skb_headlen(skb) - header_len : 0;
-	const struct skb_shared_info *shi = skb_shinfo(skb);
-	struct sk_buff *frag_iter;
-
-	sg_init_table(&sg, 1);
-
-	sg_set_buf(&sg, ((u8 *) tp) + header_len, head_data_len);
-	ahash_request_set_crypt(req, &sg, NULL, head_data_len);
-	if (crypto_ahash_update(req))
-		return 1;
-
-	for (i = 0; i < shi->nr_frags; ++i) {
-		const struct skb_frag_struct *f = &shi->frags[i];
-		unsigned int offset = f->page_offset;
-		struct page *page = skb_frag_page(f) + (offset >> PAGE_SHIFT);
-
-		sg_set_page(&sg, page, skb_frag_size(f),
-			    offset_in_page(offset));
-		ahash_request_set_crypt(req, &sg, NULL, skb_frag_size(f));
-		if (crypto_ahash_update(req))
-			return 1;
-	}
-
-	skb_walk_frags(skb, frag_iter)
-		if (tcp_md5_hash_skb_data(hp, frag_iter, 0))
-			return 1;
-
-	return 0;
-}
-EXPORT_SYMBOL(tcp_md5_hash_skb_data);
-
-int tcp_md5_hash_key(struct tcp_md5sig_pool *hp, const struct tcp_md5sig_key *key)
-{
-	struct scatterlist sg;
-
-	sg_init_one(&sg, key->key, key->keylen);
-	ahash_request_set_crypt(hp->md5_req, &sg, NULL, key->keylen);
-	return crypto_ahash_update(hp->md5_req);
-}
-EXPORT_SYMBOL(tcp_md5_hash_key);
-
-#endif
-
 struct hlist_head *tcp_extopt_get_list(const struct sock *sk)
 {
 	if (sk_fullsock(sk))
diff --git a/net/ipv4/tcp_diag.c b/net/ipv4/tcp_diag.c
index 81148f7a2323..82097a58976a 100644
--- a/net/ipv4/tcp_diag.c
+++ b/net/ipv4/tcp_diag.c
@@ -15,6 +15,7 @@
 #include <linux/inet_diag.h>
 
 #include <linux/tcp.h>
+#include <linux/tcp_md5.h>
 
 #include <net/netlink.h>
 #include <net/tcp.h>
@@ -37,70 +38,14 @@ static void tcp_diag_get_info(struct sock *sk, struct inet_diag_msg *r,
 		tcp_get_info(sk, info);
 }
 
-#ifdef CONFIG_TCP_MD5SIG
-static void tcp_diag_md5sig_fill(struct tcp_diag_md5sig *info,
-				 const struct tcp_md5sig_key *key)
-{
-	info->tcpm_family = key->family;
-	info->tcpm_prefixlen = key->prefixlen;
-	info->tcpm_keylen = key->keylen;
-	memcpy(info->tcpm_key, key->key, key->keylen);
-
-	if (key->family == AF_INET)
-		info->tcpm_addr[0] = key->addr.a4.s_addr;
-	#if IS_ENABLED(CONFIG_IPV6)
-	else if (key->family == AF_INET6)
-		memcpy(&info->tcpm_addr, &key->addr.a6,
-		       sizeof(info->tcpm_addr));
-	#endif
-}
-
-static int tcp_diag_put_md5sig(struct sk_buff *skb,
-			       const struct tcp_md5sig_info *md5sig)
-{
-	const struct tcp_md5sig_key *key;
-	struct tcp_diag_md5sig *info;
-	struct nlattr *attr;
-	int md5sig_count = 0;
-
-	hlist_for_each_entry_rcu(key, &md5sig->head, node)
-		md5sig_count++;
-	if (md5sig_count == 0)
-		return 0;
-
-	attr = nla_reserve(skb, INET_DIAG_MD5SIG,
-			   md5sig_count * sizeof(struct tcp_diag_md5sig));
-	if (!attr)
-		return -EMSGSIZE;
-
-	info = nla_data(attr);
-	memset(info, 0, md5sig_count * sizeof(struct tcp_diag_md5sig));
-	hlist_for_each_entry_rcu(key, &md5sig->head, node) {
-		tcp_diag_md5sig_fill(info++, key);
-		if (--md5sig_count == 0)
-			break;
-	}
-
-	return 0;
-}
-#endif
-
 static int tcp_diag_get_aux(struct sock *sk, bool net_admin,
 			    struct sk_buff *skb)
 {
 #ifdef CONFIG_TCP_MD5SIG
-	if (net_admin) {
-		struct tcp_md5sig_info *md5sig;
-		int err = 0;
-
-		rcu_read_lock();
-		md5sig = rcu_dereference(tcp_sk(sk)->md5sig_info);
-		if (md5sig)
-			err = tcp_diag_put_md5sig(skb, md5sig);
-		rcu_read_unlock();
-		if (err < 0)
-			return err;
-	}
+	int err = tcp_md5_diag_get_aux(sk, net_admin, skb);
+
+	if (err < 0)
+		return err;
 #endif
 
 	return 0;
@@ -111,21 +56,7 @@ static size_t tcp_diag_get_aux_size(struct sock *sk, bool net_admin)
 	size_t size = 0;
 
 #ifdef CONFIG_TCP_MD5SIG
-	if (net_admin && sk_fullsock(sk)) {
-		const struct tcp_md5sig_info *md5sig;
-		const struct tcp_md5sig_key *key;
-		size_t md5sig_count = 0;
-
-		rcu_read_lock();
-		md5sig = rcu_dereference(tcp_sk(sk)->md5sig_info);
-		if (md5sig) {
-			hlist_for_each_entry_rcu(key, &md5sig->head, node)
-				md5sig_count++;
-		}
-		rcu_read_unlock();
-		size += nla_total_size(md5sig_count *
-				       sizeof(struct tcp_diag_md5sig));
-	}
+	size += tcp_md5_diag_get_aux_size(sk, net_admin);
 #endif
 
 	return size;
diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c
index fd2693baee4a..1ac1d8d431ad 100644
--- a/net/ipv4/tcp_input.c
+++ b/net/ipv4/tcp_input.c
@@ -3867,44 +3867,6 @@ static bool tcp_fast_parse_options(const struct net *net,
 	return false;
 }
 
-#ifdef CONFIG_TCP_MD5SIG
-/*
- * Parse MD5 Signature option
- */
-const u8 *tcp_parse_md5sig_option(const struct tcphdr *th)
-{
-	int length = (th->doff << 2) - sizeof(*th);
-	const u8 *ptr = (const u8 *)(th + 1);
-
-	/* If the TCP option is too short, we can short cut */
-	if (length < TCPOLEN_MD5SIG)
-		return NULL;
-
-	while (length > 0) {
-		int opcode = *ptr++;
-		int opsize;
-
-		switch (opcode) {
-		case TCPOPT_EOL:
-			return NULL;
-		case TCPOPT_NOP:
-			length--;
-			continue;
-		default:
-			opsize = *ptr++;
-			if (opsize < 2 || opsize > length)
-				return NULL;
-			if (opcode == TCPOPT_MD5SIG)
-				return opsize == TCPOLEN_MD5SIG ? ptr : NULL;
-		}
-		ptr += opsize - 2;
-		length -= opsize;
-	}
-	return NULL;
-}
-EXPORT_SYMBOL(tcp_parse_md5sig_option);
-#endif
-
 /* Sorry, PAWS as specified is broken wrt. pure-ACKs -DaveM
  *
  * It is not fatal. If this ACK does _not_ change critical state (seqs, window)
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index 4211f8e38ef9..2446a4cb1749 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -62,6 +62,7 @@
 #include <linux/init.h>
 #include <linux/times.h>
 #include <linux/slab.h>
+#include <linux/tcp_md5.h>
 
 #include <net/net_namespace.h>
 #include <net/icmp.h>
@@ -87,11 +88,6 @@
 
 #include <trace/events/tcp.h>
 
-#ifdef CONFIG_TCP_MD5SIG
-static int tcp_v4_md5_hash_hdr(char *md5_hash, const struct tcp_md5sig_key *key,
-			       __be32 daddr, __be32 saddr, const struct tcphdr *th);
-#endif
-
 struct inet_hashinfo tcp_hashinfo;
 EXPORT_SYMBOL(tcp_hashinfo);
 
@@ -603,16 +599,13 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb)
 		__be32 opt[(MAX_TCP_OPTION_SPACE >> 2)];
 	} rep;
 	struct hlist_head *extopt_list = NULL;
+	struct tcp_out_options opts;
 	struct ip_reply_arg arg;
-#ifdef CONFIG_TCP_MD5SIG
-	struct tcp_md5sig_key *key = NULL;
-	const __u8 *hash_location = NULL;
-	unsigned char newhash[16];
-	int genhash;
-	struct sock *sk1 = NULL;
-#endif
 	struct net *net;
 	int offset = 0;
+#ifdef CONFIG_TCP_MD5SIG
+	int ret;
+#endif
 
 	/* Never send a reset in response to a reset. */
 	if (th->rst)
@@ -627,6 +620,8 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb)
 	if (sk)
 		extopt_list = tcp_extopt_get_list(sk);
 
+	memset(&opts, 0, sizeof(opts));
+
 	/* Swap the send and the receive. */
 	memset(&rep, 0, sizeof(rep));
 	rep.th.dest   = th->source;
@@ -647,55 +642,28 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb)
 	arg.iov[0].iov_len  = sizeof(rep.th);
 
 	net = sk ? sock_net(sk) : dev_net(skb_dst(skb)->dev);
-#ifdef CONFIG_TCP_MD5SIG
-	rcu_read_lock();
-	hash_location = tcp_parse_md5sig_option(th);
-	if (sk && sk_fullsock(sk)) {
-		key = tcp_md5_do_lookup(sk, (union tcp_md5_addr *)
-					&ip_hdr(skb)->saddr, AF_INET);
-	} else if (hash_location) {
-		/*
-		 * active side is lost. Try to find listening socket through
-		 * source port, and then find md5 key through listening socket.
-		 * we are not loose security here:
-		 * Incoming packet is checked with md5 hash with finding key,
-		 * no RST generated if md5 hash doesn't match.
-		 */
-		sk1 = __inet_lookup_listener(net, &tcp_hashinfo, NULL, 0,
-					     ip_hdr(skb)->saddr,
-					     th->source, ip_hdr(skb)->daddr,
-					     ntohs(th->source), inet_iif(skb),
-					     tcp_v4_sdif(skb));
-		/* don't send rst if it can't find key */
-		if (!sk1)
-			goto out;
-
-		key = tcp_md5_do_lookup(sk1, (union tcp_md5_addr *)
-					&ip_hdr(skb)->saddr, AF_INET);
-		if (!key)
-			goto out;
 
+#ifdef CONFIG_TCP_MD5SIG
+	ret = tcp_v4_md5_send_response_prepare(skb, 0,
+					       MAX_TCP_OPTION_SPACE - arg.iov[0].iov_len,
+					       &opts, sk);
 
-		genhash = tcp_v4_md5_hash_skb(newhash, key, NULL, skb);
-		if (genhash || memcmp(hash_location, newhash, 16) != 0)
-			goto out;
+	if (ret == -1)
+		return;
 
-	}
+	arg.iov[0].iov_len += ret;
 #endif
 
 	if (unlikely(extopt_list && !hlist_empty(extopt_list))) {
 		unsigned int remaining;
-		struct tcp_out_options opts;
 		int used;
 
 		remaining = sizeof(rep.opt);
 #ifdef CONFIG_TCP_MD5SIG
-		if (key)
+		if (opts.md5)
 			remaining -= TCPOLEN_MD5SIG_ALIGNED;
 #endif
 
-		memset(&opts, 0, sizeof(opts));
-
 		used = tcp_extopt_response_prepare(skb, TCPHDR_RST, remaining,
 						   &opts, sk);
 
@@ -707,19 +675,7 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb)
 	}
 
 #ifdef CONFIG_TCP_MD5SIG
-	if (key) {
-		rep.opt[offset++] = htonl((TCPOPT_NOP << 24) |
-					  (TCPOPT_NOP << 16) |
-					  (TCPOPT_MD5SIG << 8) |
-					  TCPOLEN_MD5SIG);
-		/* Update length and the length the header thinks exists */
-		arg.iov[0].iov_len += TCPOLEN_MD5SIG_ALIGNED;
-		rep.th.doff = arg.iov[0].iov_len / 4;
-
-		tcp_v4_md5_hash_hdr((__u8 *)&rep.opt[offset],
-				    key, ip_hdr(skb)->saddr,
-				    ip_hdr(skb)->daddr, &rep.th);
-	}
+	tcp_v4_md5_send_response_write(&rep.opt[offset], skb, &rep.th, &opts, sk);
 #endif
 	arg.csum = csum_tcpudp_nofold(ip_hdr(skb)->daddr,
 				      ip_hdr(skb)->saddr, /* XXX */
@@ -750,11 +706,6 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb)
 	__TCP_INC_STATS(net, TCP_MIB_OUTSEGS);
 	__TCP_INC_STATS(net, TCP_MIB_OUTRSTS);
 	local_bh_enable();
-
-#ifdef CONFIG_TCP_MD5SIG
-out:
-	rcu_read_unlock();
-#endif
 }
 
 /* The code following below sending ACKs in SYN-RECV and TIME-WAIT states
@@ -772,17 +723,19 @@ static void tcp_v4_send_ack(const struct sock *sk,
 		__be32 opt[(MAX_TCP_OPTION_SPACE >> 2)];
 	} rep;
 	struct hlist_head *extopt_list = NULL;
-#ifdef CONFIG_TCP_MD5SIG
-	struct tcp_md5sig_key *key;
-#endif
+	struct tcp_out_options opts;
 	struct net *net = sock_net(sk);
 	struct ip_reply_arg arg;
 	int offset = 0;
+#ifdef CONFIG_TCP_MD5SIG
+	int ret;
+#endif
 
 	extopt_list = tcp_extopt_get_list(sk);
 
 	memset(&rep.th, 0, sizeof(struct tcphdr));
 	memset(&arg, 0, sizeof(arg));
+	memset(&opts, 0, sizeof(opts));
 
 	arg.iov[0].iov_base = (unsigned char *)&rep;
 	arg.iov[0].iov_len  = sizeof(rep.th);
@@ -806,25 +759,24 @@ static void tcp_v4_send_ack(const struct sock *sk,
 	rep.th.window  = htons(win);
 
 #ifdef CONFIG_TCP_MD5SIG
-	if (sk->sk_state == TCP_TIME_WAIT) {
-		key = tcp_twsk_md5_key(tcp_twsk(sk));
-	} else if (sk->sk_state == TCP_NEW_SYN_RECV) {
-		key = tcp_md5_do_lookup(sk, (union tcp_md5_addr *)&ip_hdr(skb)->saddr,
-					AF_INET);
-	} else {
-		key = NULL;     /* Should not happen */
-	}
+	ret = tcp_v4_md5_send_response_prepare(skb, 0,
+					       MAX_TCP_OPTION_SPACE - arg.iov[0].iov_len,
+					       &opts, sk);
+
+	if (ret == -1)
+		return;
+
+	arg.iov[0].iov_len += ret;
 #endif
 
 	if (unlikely(extopt_list && !hlist_empty(extopt_list))) {
 		unsigned int remaining;
-		struct tcp_out_options opts;
 		int used;
 
 		remaining = sizeof(rep.th) + sizeof(rep.opt) - arg.iov[0].iov_len;
 
 #ifdef CONFIG_TCP_MD5SIG
-		if (key)
+		if (opts.md5)
 			remaining -= TCPOLEN_MD5SIG_ALIGNED;
 #endif
 
@@ -841,18 +793,11 @@ static void tcp_v4_send_ack(const struct sock *sk,
 	}
 
 #ifdef CONFIG_TCP_MD5SIG
-	if (key) {
-		rep.opt[offset++] = htonl((TCPOPT_NOP << 24) |
-					  (TCPOPT_NOP << 16) |
-					  (TCPOPT_MD5SIG << 8) |
-					  TCPOLEN_MD5SIG);
+	if (opts.md5) {
 		arg.iov[0].iov_len += TCPOLEN_MD5SIG_ALIGNED;
 		rep.th.doff = arg.iov[0].iov_len / 4;
-
-		tcp_v4_md5_hash_hdr((__u8 *) &rep.opt[offset],
-				    key, ip_hdr(skb)->saddr,
-				    ip_hdr(skb)->daddr, &rep.th);
 	}
+	tcp_v4_md5_send_response_write(&rep.opt[offset], skb, &rep.th, &opts, sk);
 #endif
 
 	arg.flags = reply_flags;
@@ -961,374 +906,6 @@ static void tcp_v4_reqsk_destructor(struct request_sock *req)
 	kfree(rcu_dereference_protected(inet_rsk(req)->ireq_opt, 1));
 }
 
-#ifdef CONFIG_TCP_MD5SIG
-/*
- * RFC2385 MD5 checksumming requires a mapping of
- * IP address->MD5 Key.
- * We need to maintain these in the sk structure.
- */
-
-/* Find the Key structure for an address.  */
-struct tcp_md5sig_key *tcp_md5_do_lookup(const struct sock *sk,
-					 const union tcp_md5_addr *addr,
-					 int family)
-{
-	const struct tcp_sock *tp = tcp_sk(sk);
-	struct tcp_md5sig_key *key;
-	const struct tcp_md5sig_info *md5sig;
-	__be32 mask;
-	struct tcp_md5sig_key *best_match = NULL;
-	bool match;
-
-	/* caller either holds rcu_read_lock() or socket lock */
-	md5sig = rcu_dereference_check(tp->md5sig_info,
-				       lockdep_sock_is_held(sk));
-	if (!md5sig)
-		return NULL;
-
-	hlist_for_each_entry_rcu(key, &md5sig->head, node) {
-		if (key->family != family)
-			continue;
-
-		if (family == AF_INET) {
-			mask = inet_make_mask(key->prefixlen);
-			match = (key->addr.a4.s_addr & mask) ==
-				(addr->a4.s_addr & mask);
-#if IS_ENABLED(CONFIG_IPV6)
-		} else if (family == AF_INET6) {
-			match = ipv6_prefix_equal(&key->addr.a6, &addr->a6,
-						  key->prefixlen);
-#endif
-		} else {
-			match = false;
-		}
-
-		if (match && (!best_match ||
-			      key->prefixlen > best_match->prefixlen))
-			best_match = key;
-	}
-	return best_match;
-}
-EXPORT_SYMBOL(tcp_md5_do_lookup);
-
-static struct tcp_md5sig_key *tcp_md5_do_lookup_exact(const struct sock *sk,
-						      const union tcp_md5_addr *addr,
-						      int family, u8 prefixlen)
-{
-	const struct tcp_sock *tp = tcp_sk(sk);
-	struct tcp_md5sig_key *key;
-	unsigned int size = sizeof(struct in_addr);
-	const struct tcp_md5sig_info *md5sig;
-
-	/* caller either holds rcu_read_lock() or socket lock */
-	md5sig = rcu_dereference_check(tp->md5sig_info,
-				       lockdep_sock_is_held(sk));
-	if (!md5sig)
-		return NULL;
-#if IS_ENABLED(CONFIG_IPV6)
-	if (family == AF_INET6)
-		size = sizeof(struct in6_addr);
-#endif
-	hlist_for_each_entry_rcu(key, &md5sig->head, node) {
-		if (key->family != family)
-			continue;
-		if (!memcmp(&key->addr, addr, size) &&
-		    key->prefixlen == prefixlen)
-			return key;
-	}
-	return NULL;
-}
-
-struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk,
-					 const struct sock *addr_sk)
-{
-	const union tcp_md5_addr *addr;
-
-	addr = (const union tcp_md5_addr *)&addr_sk->sk_daddr;
-	return tcp_md5_do_lookup(sk, addr, AF_INET);
-}
-EXPORT_SYMBOL(tcp_v4_md5_lookup);
-
-/* This can be called on a newly created socket, from other files */
-int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
-		   int family, u8 prefixlen, const u8 *newkey, u8 newkeylen,
-		   gfp_t gfp)
-{
-	/* Add Key to the list */
-	struct tcp_md5sig_key *key;
-	struct tcp_sock *tp = tcp_sk(sk);
-	struct tcp_md5sig_info *md5sig;
-
-	key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen);
-	if (key) {
-		/* Pre-existing entry - just update that one. */
-		memcpy(key->key, newkey, newkeylen);
-		key->keylen = newkeylen;
-		return 0;
-	}
-
-	md5sig = rcu_dereference_protected(tp->md5sig_info,
-					   lockdep_sock_is_held(sk));
-	if (!md5sig) {
-		md5sig = kmalloc(sizeof(*md5sig), gfp);
-		if (!md5sig)
-			return -ENOMEM;
-
-		sk_nocaps_add(sk, NETIF_F_GSO_MASK);
-		INIT_HLIST_HEAD(&md5sig->head);
-		rcu_assign_pointer(tp->md5sig_info, md5sig);
-	}
-
-	key = sock_kmalloc(sk, sizeof(*key), gfp);
-	if (!key)
-		return -ENOMEM;
-	if (!tcp_alloc_md5sig_pool()) {
-		sock_kfree_s(sk, key, sizeof(*key));
-		return -ENOMEM;
-	}
-
-	memcpy(key->key, newkey, newkeylen);
-	key->keylen = newkeylen;
-	key->family = family;
-	key->prefixlen = prefixlen;
-	memcpy(&key->addr, addr,
-	       (family == AF_INET6) ? sizeof(struct in6_addr) :
-				      sizeof(struct in_addr));
-	hlist_add_head_rcu(&key->node, &md5sig->head);
-	return 0;
-}
-EXPORT_SYMBOL(tcp_md5_do_add);
-
-int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr, int family,
-		   u8 prefixlen)
-{
-	struct tcp_md5sig_key *key;
-
-	key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen);
-	if (!key)
-		return -ENOENT;
-	hlist_del_rcu(&key->node);
-	atomic_sub(sizeof(*key), &sk->sk_omem_alloc);
-	kfree_rcu(key, rcu);
-	return 0;
-}
-EXPORT_SYMBOL(tcp_md5_do_del);
-
-static void tcp_clear_md5_list(struct sock *sk)
-{
-	struct tcp_sock *tp = tcp_sk(sk);
-	struct tcp_md5sig_key *key;
-	struct hlist_node *n;
-	struct tcp_md5sig_info *md5sig;
-
-	md5sig = rcu_dereference_protected(tp->md5sig_info, 1);
-
-	hlist_for_each_entry_safe(key, n, &md5sig->head, node) {
-		hlist_del_rcu(&key->node);
-		atomic_sub(sizeof(*key), &sk->sk_omem_alloc);
-		kfree_rcu(key, rcu);
-	}
-}
-
-static int tcp_v4_parse_md5_keys(struct sock *sk, int optname,
-				 char __user *optval, int optlen)
-{
-	struct tcp_md5sig cmd;
-	struct sockaddr_in *sin = (struct sockaddr_in *)&cmd.tcpm_addr;
-	u8 prefixlen = 32;
-
-	if (optlen < sizeof(cmd))
-		return -EINVAL;
-
-	if (copy_from_user(&cmd, optval, sizeof(cmd)))
-		return -EFAULT;
-
-	if (sin->sin_family != AF_INET)
-		return -EINVAL;
-
-	if (optname == TCP_MD5SIG_EXT &&
-	    cmd.tcpm_flags & TCP_MD5SIG_FLAG_PREFIX) {
-		prefixlen = cmd.tcpm_prefixlen;
-		if (prefixlen > 32)
-			return -EINVAL;
-	}
-
-	if (!cmd.tcpm_keylen)
-		return tcp_md5_do_del(sk, (union tcp_md5_addr *)&sin->sin_addr.s_addr,
-				      AF_INET, prefixlen);
-
-	if (cmd.tcpm_keylen > TCP_MD5SIG_MAXKEYLEN)
-		return -EINVAL;
-
-	return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin->sin_addr.s_addr,
-			      AF_INET, prefixlen, cmd.tcpm_key, cmd.tcpm_keylen,
-			      GFP_KERNEL);
-}
-
-static int tcp_v4_md5_hash_headers(struct tcp_md5sig_pool *hp,
-				   __be32 daddr, __be32 saddr,
-				   const struct tcphdr *th, int nbytes)
-{
-	struct tcp4_pseudohdr *bp;
-	struct scatterlist sg;
-	struct tcphdr *_th;
-
-	bp = hp->scratch;
-	bp->saddr = saddr;
-	bp->daddr = daddr;
-	bp->pad = 0;
-	bp->protocol = IPPROTO_TCP;
-	bp->len = cpu_to_be16(nbytes);
-
-	_th = (struct tcphdr *)(bp + 1);
-	memcpy(_th, th, sizeof(*th));
-	_th->check = 0;
-
-	sg_init_one(&sg, bp, sizeof(*bp) + sizeof(*th));
-	ahash_request_set_crypt(hp->md5_req, &sg, NULL,
-				sizeof(*bp) + sizeof(*th));
-	return crypto_ahash_update(hp->md5_req);
-}
-
-static int tcp_v4_md5_hash_hdr(char *md5_hash, const struct tcp_md5sig_key *key,
-			       __be32 daddr, __be32 saddr, const struct tcphdr *th)
-{
-	struct tcp_md5sig_pool *hp;
-	struct ahash_request *req;
-
-	hp = tcp_get_md5sig_pool();
-	if (!hp)
-		goto clear_hash_noput;
-	req = hp->md5_req;
-
-	if (crypto_ahash_init(req))
-		goto clear_hash;
-	if (tcp_v4_md5_hash_headers(hp, daddr, saddr, th, th->doff << 2))
-		goto clear_hash;
-	if (tcp_md5_hash_key(hp, key))
-		goto clear_hash;
-	ahash_request_set_crypt(req, NULL, md5_hash, 0);
-	if (crypto_ahash_final(req))
-		goto clear_hash;
-
-	tcp_put_md5sig_pool();
-	return 0;
-
-clear_hash:
-	tcp_put_md5sig_pool();
-clear_hash_noput:
-	memset(md5_hash, 0, 16);
-	return 1;
-}
-
-int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key,
-			const struct sock *sk,
-			const struct sk_buff *skb)
-{
-	struct tcp_md5sig_pool *hp;
-	struct ahash_request *req;
-	const struct tcphdr *th = tcp_hdr(skb);
-	__be32 saddr, daddr;
-
-	if (sk) { /* valid for establish/request sockets */
-		saddr = sk->sk_rcv_saddr;
-		daddr = sk->sk_daddr;
-	} else {
-		const struct iphdr *iph = ip_hdr(skb);
-		saddr = iph->saddr;
-		daddr = iph->daddr;
-	}
-
-	hp = tcp_get_md5sig_pool();
-	if (!hp)
-		goto clear_hash_noput;
-	req = hp->md5_req;
-
-	if (crypto_ahash_init(req))
-		goto clear_hash;
-
-	if (tcp_v4_md5_hash_headers(hp, daddr, saddr, th, skb->len))
-		goto clear_hash;
-	if (tcp_md5_hash_skb_data(hp, skb, th->doff << 2))
-		goto clear_hash;
-	if (tcp_md5_hash_key(hp, key))
-		goto clear_hash;
-	ahash_request_set_crypt(req, NULL, md5_hash, 0);
-	if (crypto_ahash_final(req))
-		goto clear_hash;
-
-	tcp_put_md5sig_pool();
-	return 0;
-
-clear_hash:
-	tcp_put_md5sig_pool();
-clear_hash_noput:
-	memset(md5_hash, 0, 16);
-	return 1;
-}
-EXPORT_SYMBOL(tcp_v4_md5_hash_skb);
-
-#endif
-
-/* Called with rcu_read_lock() */
-static bool tcp_v4_inbound_md5_hash(const struct sock *sk,
-				    const struct sk_buff *skb)
-{
-#ifdef CONFIG_TCP_MD5SIG
-	/*
-	 * This gets called for each TCP segment that arrives
-	 * so we want to be efficient.
-	 * We have 3 drop cases:
-	 * o No MD5 hash and one expected.
-	 * o MD5 hash and we're not expecting one.
-	 * o MD5 hash and its wrong.
-	 */
-	const __u8 *hash_location = NULL;
-	struct tcp_md5sig_key *hash_expected;
-	const struct iphdr *iph = ip_hdr(skb);
-	const struct tcphdr *th = tcp_hdr(skb);
-	int genhash;
-	unsigned char newhash[16];
-
-	hash_expected = tcp_md5_do_lookup(sk, (union tcp_md5_addr *)&iph->saddr,
-					  AF_INET);
-	hash_location = tcp_parse_md5sig_option(th);
-
-	/* We've parsed the options - do we have a hash? */
-	if (!hash_expected && !hash_location)
-		return false;
-
-	if (hash_expected && !hash_location) {
-		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND);
-		return true;
-	}
-
-	if (!hash_expected && hash_location) {
-		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED);
-		return true;
-	}
-
-	/* Okay, so this is hash_expected and hash_location -
-	 * so we need to calculate the checksum.
-	 */
-	genhash = tcp_v4_md5_hash_skb(newhash,
-				      hash_expected,
-				      NULL, skb);
-
-	if (genhash || memcmp(hash_location, newhash, 16) != 0) {
-		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE);
-		net_info_ratelimited("MD5 Hash failed for (%pI4, %d)->(%pI4, %d)%s\n",
-				     &iph->saddr, ntohs(th->source),
-				     &iph->daddr, ntohs(th->dest),
-				     genhash ? " tcp_v4_calc_md5_hash failed"
-				     : "");
-		return true;
-	}
-	return false;
-#endif
-	return false;
-}
-
 static void tcp_v4_init_req(struct request_sock *req,
 			    const struct sock *sk_listener,
 			    struct sk_buff *skb)
@@ -1404,9 +981,6 @@ struct sock *tcp_v4_syn_recv_sock(const struct sock *sk, struct sk_buff *skb,
 	struct inet_sock *newinet;
 	struct tcp_sock *newtp;
 	struct sock *newsk;
-#ifdef CONFIG_TCP_MD5SIG
-	struct tcp_md5sig_key *key;
-#endif
 	struct ip_options_rcu *inet_opt;
 
 	if (sk_acceptq_is_full(sk))
@@ -1453,20 +1027,7 @@ struct sock *tcp_v4_syn_recv_sock(const struct sock *sk, struct sk_buff *skb,
 	tcp_initialize_rcv_mss(newsk);
 
 #ifdef CONFIG_TCP_MD5SIG
-	/* Copy over the MD5 key from the original socket */
-	key = tcp_md5_do_lookup(sk, (union tcp_md5_addr *)&newinet->inet_daddr,
-				AF_INET);
-	if (key) {
-		/*
-		 * We're using one, so create a matching key
-		 * on the newsk structure. If we fail to get
-		 * memory, then we end up not copying the key
-		 * across. Shucks.
-		 */
-		tcp_md5_do_add(newsk, (union tcp_md5_addr *)&newinet->inet_daddr,
-			       AF_INET, 32, key->key, key->keylen, GFP_ATOMIC);
-		sk_nocaps_add(newsk, NETIF_F_GSO_MASK);
-	}
+	tcp_v4_md5_syn_recv_sock(sk, newsk);
 #endif
 
 	if (__inet_inherit_port(sk, newsk) < 0)
@@ -1930,14 +1491,6 @@ const struct inet_connection_sock_af_ops ipv4_specific = {
 };
 EXPORT_SYMBOL(ipv4_specific);
 
-#ifdef CONFIG_TCP_MD5SIG
-static const struct tcp_sock_af_ops tcp_sock_ipv4_specific = {
-	.md5_lookup		= tcp_v4_md5_lookup,
-	.calc_md5_hash		= tcp_v4_md5_hash_skb,
-	.md5_parse		= tcp_v4_parse_md5_keys,
-};
-#endif
-
 /* NOTE: A lot of things set to zero explicitly by call to
  *       sk_alloc() so need not be done here.
  */
@@ -1980,12 +1533,7 @@ void tcp_v4_destroy_sock(struct sock *sk)
 	if (unlikely(!hlist_empty(&tp->tcp_option_list)))
 		tcp_extopt_destroy(sk);
 #ifdef CONFIG_TCP_MD5SIG
-	/* Clean up the MD5 key list, if any */
-	if (tp->md5sig_info) {
-		tcp_clear_md5_list(sk);
-		kfree_rcu(rcu_dereference_protected(tp->md5sig_info, 1), rcu);
-		tp->md5sig_info = NULL;
-	}
+	tcp_v4_md5_destroy_sock(sk);
 #endif
 
 	/* Clean up a referenced TCP bind bucket. */
diff --git a/net/ipv4/tcp_md5.c b/net/ipv4/tcp_md5.c
new file mode 100644
index 000000000000..d50580536978
--- /dev/null
+++ b/net/ipv4/tcp_md5.c
@@ -0,0 +1,1103 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+#include <linux/inet_diag.h>
+#include <linux/inetdevice.h>
+#include <linux/tcp.h>
+#include <linux/tcp_md5.h>
+
+#include <crypto/hash.h>
+
+#include <net/inet6_hashtables.h>
+
+static DEFINE_PER_CPU(struct tcp_md5sig_pool, tcp_md5sig_pool);
+static DEFINE_MUTEX(tcp_md5sig_mutex);	/* serializes pool population */
+static bool tcp_md5sig_pool_populated;	/* set once all CPUs have a pool */
+/* MD5 key copied onto a TIME_WAIT socket by tcp_md5_time_wait() */
+#define tcp_twsk_md5_key(twsk)	((twsk)->tw_md5_key)
+
+static void __tcp_alloc_md5sig_pool(void)
+{
+	struct crypto_ahash *hash;
+	int cpu;
+
+	hash = crypto_alloc_ahash("md5", 0, CRYPTO_ALG_ASYNC);
+	if (IS_ERR(hash))
+		return;
+	/* Fill in whatever a previous, partial attempt left missing. */
+	for_each_possible_cpu(cpu) {
+		void *scratch = per_cpu(tcp_md5sig_pool, cpu).scratch;
+		struct ahash_request *req;
+
+		if (!scratch) {
+			scratch = kmalloc_node(sizeof(union tcp_md5sum_block) +
+					       sizeof(struct tcphdr),
+					       GFP_KERNEL,
+					       cpu_to_node(cpu));
+			if (!scratch)
+				return;
+			per_cpu(tcp_md5sig_pool, cpu).scratch = scratch;
+		}
+		if (per_cpu(tcp_md5sig_pool, cpu).md5_req)
+			continue;
+		/* NOTE(review): bailing out on ENOMEM never frees @hash. */
+		req = ahash_request_alloc(hash, GFP_KERNEL);
+		if (!req)
+			return;
+
+		ahash_request_set_callback(req, 0, NULL, NULL);
+
+		per_cpu(tcp_md5sig_pool, cpu).md5_req = req;
+	}
+	/* Commit the per-CPU scratch/request pointers to memory before
+	 * publishing tcp_md5sig_pool_populated; see tcp_get_md5sig_pool().
+	 */
+	smp_wmb();
+	tcp_md5sig_pool_populated = true;
+}
+
+static bool tcp_alloc_md5sig_pool(void)
+{
+	/* Fast path: once published, the flag is never cleared in this file. */
+	if (likely(tcp_md5sig_pool_populated))
+		return true;
+	/* Slow path: re-check under the mutex so only one caller populates. */
+	mutex_lock(&tcp_md5sig_mutex);
+	if (!tcp_md5sig_pool_populated)
+		__tcp_alloc_md5sig_pool();
+	mutex_unlock(&tcp_md5sig_mutex);
+	return tcp_md5sig_pool_populated;
+}
+
+static void tcp_put_md5sig_pool(void)
+{
+	local_bh_enable();	/* undoes tcp_get_md5sig_pool()'s local_bh_disable() */
+}
+
+/**
+ *	tcp_get_md5sig_pool - get md5sig_pool for this user
+ *
+ *	We use a per-CPU structure, so on success we return with preemption and
+ *	BH disabled (undo with tcp_put_md5sig_pool()) so that no other thread or
+ *	softirq can grab the same context. Returns NULL if not populated yet.
+ */
+static struct tcp_md5sig_pool *tcp_get_md5sig_pool(void)
+{
+	local_bh_disable();
+
+	if (tcp_md5sig_pool_populated) {
+		/* coupled with smp_wmb() in __tcp_alloc_md5sig_pool() */
+		smp_rmb();
+		return this_cpu_ptr(&tcp_md5sig_pool);
+	}
+	local_bh_enable();
+	return NULL;
+}
+
+static struct tcp_md5sig_key *tcp_md5_do_lookup_exact(const struct sock *sk,
+						      const union tcp_md5_addr *addr,
+						      int family, u8 prefixlen)
+{
+	const struct tcp_sock *tp = tcp_sk(sk);
+	const struct tcp_md5sig_info *md5sig;
+	struct tcp_md5sig_key *key;
+	unsigned int addrlen;
+
+	/* Only an exact (address, prefixlen) match counts here, unlike
+	 * tcp_md5_do_lookup(). Caller holds rcu_read_lock() or the socket lock.
+	 */
+	md5sig = rcu_dereference_check(tp->md5sig_info,
+				       lockdep_sock_is_held(sk));
+	if (!md5sig)
+		return NULL;
+
+	addrlen = IS_ENABLED(CONFIG_IPV6) && family == AF_INET6 ?
+		  sizeof(struct in6_addr) : sizeof(struct in_addr);
+	hlist_for_each_entry_rcu(key, &md5sig->head, node) {
+		if (key->family == family &&
+		    key->prefixlen == prefixlen &&
+		    !memcmp(&key->addr, addr, addrlen))
+			return key;
+	}
+	return NULL;
+}
+
+/* Add or update the key for addr/prefixlen; also used on new child sockets */
+static int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
+			  int family, u8 prefixlen, const u8 *newkey,
+			  u8 newkeylen, gfp_t gfp)
+{
+	/* Add Key to the list */
+	struct tcp_md5sig_key *key;
+	struct tcp_sock *tp = tcp_sk(sk);
+	struct tcp_md5sig_info *md5sig;
+
+	key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen);
+	if (key) {
+		/* Pre-existing entry: update in place, racing with RCU readers. */
+		memcpy(key->key, newkey, newkeylen);
+		WRITE_ONCE(key->keylen, newkeylen);
+		return 0;
+	}
+	/* New entry: create the per-socket key list on first use. */
+	md5sig = rcu_dereference_protected(tp->md5sig_info,
+					   lockdep_sock_is_held(sk));
+	if (!md5sig) {
+		md5sig = kmalloc(sizeof(*md5sig), gfp);
+		if (!md5sig)
+			return -ENOMEM;
+
+		sk_nocaps_add(sk, NETIF_F_GSO_MASK);
+		INIT_HLIST_HEAD(&md5sig->head);
+		rcu_assign_pointer(tp->md5sig_info, md5sig);
+	}
+	/* __GFP_ZERO: racing in-place updates never expose uninitialized key[] */
+	key = sock_kmalloc(sk, sizeof(*key), gfp | __GFP_ZERO);
+	if (!key)
+		return -ENOMEM;
+	if (!tcp_alloc_md5sig_pool()) {
+		sock_kfree_s(sk, key, sizeof(*key));
+		return -ENOMEM;
+	}
+	/* Fill in the key, then publish it to RCU readers. */
+	memcpy(key->key, newkey, newkeylen);
+	key->keylen = newkeylen;
+	key->family = family;
+	key->prefixlen = prefixlen;
+	memcpy(&key->addr, addr,
+	       (family == AF_INET6) ? sizeof(struct in6_addr) :
+				      sizeof(struct in_addr));
+	hlist_add_head_rcu(&key->node, &md5sig->head);
+	return 0;
+}
+
+static void tcp_clear_md5_list(struct sock *sk)
+{
+	struct tcp_sock *tp = tcp_sk(sk);
+	struct tcp_md5sig_key *key;
+	struct hlist_node *n;
+	struct tcp_md5sig_info *md5sig;
+
+	md5sig = rcu_dereference_protected(tp->md5sig_info, 1);
+	/* md5sig is not NULL-checked: callers must ensure a key list exists. */
+	hlist_for_each_entry_safe(key, n, &md5sig->head, node) {
+		hlist_del_rcu(&key->node);
+		atomic_sub(sizeof(*key), &sk->sk_omem_alloc);
+		kfree_rcu(key, rcu);
+	}
+}
+
+static int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr,
+			  int family, u8 prefixlen)
+{
+	struct tcp_md5sig_key *key = tcp_md5_do_lookup_exact(sk, addr, family,
+							      prefixlen);
+
+	if (!key)
+		return -ENOENT;
+	hlist_del_rcu(&key->node);			/* unlink for new readers */
+	atomic_sub(sizeof(*key), &sk->sk_omem_alloc);	/* undo sock_kmalloc() charge */
+	kfree_rcu(key, rcu);				/* free after a grace period */
+	return 0;
+}
+
+static int tcp_md5_hash_key(struct tcp_md5sig_pool *hp, const struct tcp_md5sig_key *key)
+{
+	u8 keylen = READ_ONCE(key->keylen); /* tcp_md5_do_add() may update it */
+	struct scatterlist sg;
+
+	sg_init_one(&sg, key->key, keylen);
+	ahash_request_set_crypt(hp->md5_req, &sg, NULL, keylen);
+	return crypto_ahash_update(hp->md5_req);
+}
+
+static int tcp_v4_parse_md5_keys(struct sock *sk, int optname,
+				 char __user *optval, int optlen)
+{
+	struct tcp_md5sig cmd;
+	struct sockaddr_in *sin = (struct sockaddr_in *)&cmd.tcpm_addr;
+	u8 prefixlen = 32;
+
+	if (optlen < sizeof(cmd))
+		return -EINVAL;
+	/* Only sizeof(cmd) bytes are copied; any extra optlen is ignored. */
+	if (copy_from_user(&cmd, optval, sizeof(cmd)))
+		return -EFAULT;
+
+	if (sin->sin_family != AF_INET)
+		return -EINVAL;
+	/* Default is a /32 host key; TCP_MD5SIG_EXT may request a prefix. */
+	if (optname == TCP_MD5SIG_EXT &&
+	    cmd.tcpm_flags & TCP_MD5SIG_FLAG_PREFIX) {
+		prefixlen = cmd.tcpm_prefixlen;
+		if (prefixlen > 32)
+			return -EINVAL;
+	}
+	/* A zero key length deletes the matching key. */
+	if (!cmd.tcpm_keylen)
+		return tcp_md5_do_del(sk, (union tcp_md5_addr *)&sin->sin_addr.s_addr,
+				      AF_INET, prefixlen);
+
+	if (cmd.tcpm_keylen > TCP_MD5SIG_MAXKEYLEN)
+		return -EINVAL;
+
+	return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin->sin_addr.s_addr,
+			      AF_INET, prefixlen, cmd.tcpm_key, cmd.tcpm_keylen,
+			      GFP_KERNEL);
+}
+
+#if IS_ENABLED(CONFIG_IPV6)
+static int tcp_v6_parse_md5_keys(struct sock *sk, int optname,
+				 char __user *optval, int optlen)
+{
+	struct tcp_md5sig cmd;
+	struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&cmd.tcpm_addr;
+	u8 prefixlen;
+
+	if (optlen < sizeof(cmd))
+		return -EINVAL;
+
+	if (copy_from_user(&cmd, optval, sizeof(cmd)))
+		return -EFAULT;
+
+	if (sin6->sin6_family != AF_INET6)
+		return -EINVAL;
+	/* Default: full host match (32 bits for v4-mapped, else 128). */
+	if (optname == TCP_MD5SIG_EXT &&
+	    cmd.tcpm_flags & TCP_MD5SIG_FLAG_PREFIX) {
+		prefixlen = cmd.tcpm_prefixlen;
+		if (prefixlen > 128 || (ipv6_addr_v4mapped(&sin6->sin6_addr) &&
+					prefixlen > 32))
+			return -EINVAL;
+	} else {
+		prefixlen = ipv6_addr_v4mapped(&sin6->sin6_addr) ? 32 : 128;
+	}
+	/* A zero key length deletes the matching key. */
+	if (!cmd.tcpm_keylen) {
+		if (ipv6_addr_v4mapped(&sin6->sin6_addr))
+			return tcp_md5_do_del(sk, (union tcp_md5_addr *)&sin6->sin6_addr.s6_addr32[3],
+					      AF_INET, prefixlen);
+		return tcp_md5_do_del(sk, (union tcp_md5_addr *)&sin6->sin6_addr,
+				      AF_INET6, prefixlen);
+	}
+
+	if (cmd.tcpm_keylen > TCP_MD5SIG_MAXKEYLEN)
+		return -EINVAL;
+	/* v4-mapped peers are keyed as AF_INET on the embedded IPv4 address. */
+	if (ipv6_addr_v4mapped(&sin6->sin6_addr))
+		return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr.s6_addr32[3],
+				      AF_INET, prefixlen, cmd.tcpm_key,
+				      cmd.tcpm_keylen, GFP_KERNEL);
+
+	return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr,
+			      AF_INET6, prefixlen, cmd.tcpm_key,
+			      cmd.tcpm_keylen, GFP_KERNEL);
+}
+#endif /* IS_ENABLED(CONFIG_IPV6) */
+
+static int tcp_v4_md5_hash_headers(struct tcp_md5sig_pool *hp,
+				   __be32 daddr, __be32 saddr,
+				   const struct tcphdr *th, int nbytes)
+{
+	struct tcp4_pseudohdr *bp;
+	struct scatterlist sg;
+	struct tcphdr *_th;
+
+	bp = hp->scratch;	/* IPv4 pseudo-header; @nbytes is its length */
+	bp->saddr = saddr;
+	bp->daddr = daddr;
+	bp->pad = 0;
+	bp->protocol = IPPROTO_TCP;
+	bp->len = cpu_to_be16(nbytes);
+	/* The TCP header follows, copied with its checksum field zeroed. */
+	_th = (struct tcphdr *)(bp + 1);
+	memcpy(_th, th, sizeof(*th));
+	_th->check = 0;
+	/* Only the fixed header is hashed here, not the TCP options. */
+	sg_init_one(&sg, bp, sizeof(*bp) + sizeof(*th));
+	ahash_request_set_crypt(hp->md5_req, &sg, NULL,
+				sizeof(*bp) + sizeof(*th));
+	return crypto_ahash_update(hp->md5_req);
+}
+
+#if IS_ENABLED(CONFIG_IPV6)
+static int tcp_v6_md5_hash_headers(struct tcp_md5sig_pool *hp,
+				   const struct in6_addr *daddr,
+				   const struct in6_addr *saddr,
+				   const struct tcphdr *th, int nbytes)
+{
+	struct tcp6_pseudohdr *bp;
+	struct scatterlist sg;
+	struct tcphdr *_th;
+
+	bp = hp->scratch;
+	/* 1. IPv6 TCP pseudo-header (RFC2460); @nbytes is its length */
+	bp->saddr = *saddr;
+	bp->daddr = *daddr;
+	bp->protocol = cpu_to_be32(IPPROTO_TCP);
+	bp->len = cpu_to_be32(nbytes);
+	/* 2. The TCP header, copied with its checksum field zeroed */
+	_th = (struct tcphdr *)(bp + 1);
+	memcpy(_th, th, sizeof(*th));
+	_th->check = 0;
+	/* Only the fixed header is hashed here, not the TCP options. */
+	sg_init_one(&sg, bp, sizeof(*bp) + sizeof(*th));
+	ahash_request_set_crypt(hp->md5_req, &sg, NULL,
+				sizeof(*bp) + sizeof(*th));
+	return crypto_ahash_update(hp->md5_req);
+}
+#endif /* IS_ENABLED(CONFIG_IPV6) */
+
+static int tcp_v4_md5_hash_hdr(char *md5_hash, const struct tcp_md5sig_key *key,
+			       __be32 daddr, __be32 saddr,
+			       const struct tcphdr *th)
+{
+	struct tcp_md5sig_pool *hp;
+	struct ahash_request *req;
+
+	hp = tcp_get_md5sig_pool();
+	if (!hp)
+		goto clear_hash_noput;
+	req = hp->md5_req;
+	/* MD5 over pseudo-header + fixed TCP header, then the key. */
+	if (crypto_ahash_init(req))
+		goto clear_hash;
+	if (tcp_v4_md5_hash_headers(hp, daddr, saddr, th, th->doff << 2))
+		goto clear_hash;
+	if (tcp_md5_hash_key(hp, key))
+		goto clear_hash;
+	ahash_request_set_crypt(req, NULL, md5_hash, 0);
+	if (crypto_ahash_final(req))
+		goto clear_hash;
+
+	tcp_put_md5sig_pool();
+	return 0;
+	/* On failure the 16-byte digest is zeroed and 1 is returned. */
+clear_hash:
+	tcp_put_md5sig_pool();
+clear_hash_noput:
+	memset(md5_hash, 0, 16);
+	return 1;
+}
+
+#if IS_ENABLED(CONFIG_IPV6)
+static int tcp_v6_md5_hash_hdr(char *md5_hash, const struct tcp_md5sig_key *key,
+			       const struct in6_addr *daddr,
+			       const struct in6_addr *saddr, const struct tcphdr *th)
+{
+	struct tcp_md5sig_pool *hp;
+	struct ahash_request *req;
+
+	hp = tcp_get_md5sig_pool();
+	if (!hp)
+		goto clear_hash_noput;
+	req = hp->md5_req;
+	/* MD5 over pseudo-header + fixed TCP header, then the key. */
+	if (crypto_ahash_init(req))
+		goto clear_hash;
+	if (tcp_v6_md5_hash_headers(hp, daddr, saddr, th, th->doff << 2))
+		goto clear_hash;
+	if (tcp_md5_hash_key(hp, key))
+		goto clear_hash;
+	ahash_request_set_crypt(req, NULL, md5_hash, 0);
+	if (crypto_ahash_final(req))
+		goto clear_hash;
+
+	tcp_put_md5sig_pool();
+	return 0;
+	/* On failure the 16-byte digest is zeroed and 1 is returned. */
+clear_hash:
+	tcp_put_md5sig_pool();
+clear_hash_noput:
+	memset(md5_hash, 0, 16);
+	return 1;
+}
+#endif /* IS_ENABLED(CONFIG_IPV6) */
+
+/* RFC2385 MD5 checksumming requires a mapping of
+ * IP address->MD5 Key.
+ * We need to maintain these in the sk structure.
+ */
+
+/* Find the most specific (longest-prefix) key matching @addr, or NULL. */
+static struct tcp_md5sig_key *tcp_md5_do_lookup(const struct sock *sk,
+						const union tcp_md5_addr *addr,
+						int family)
+{
+	const struct tcp_sock *tp = tcp_sk(sk);
+	struct tcp_md5sig_key *key;
+	const struct tcp_md5sig_info *md5sig;
+	__be32 mask;
+	struct tcp_md5sig_key *best_match = NULL;
+	bool match;
+
+	/* caller either holds rcu_read_lock() or socket lock */
+	md5sig = rcu_dereference_check(tp->md5sig_info,
+				       lockdep_sock_is_held(sk));
+	if (!md5sig)
+		return NULL;
+
+	hlist_for_each_entry_rcu(key, &md5sig->head, node) {
+		if (key->family != family)
+			continue;
+
+		if (family == AF_INET) {
+			mask = inet_make_mask(key->prefixlen);
+			match = (key->addr.a4.s_addr & mask) ==
+				(addr->a4.s_addr & mask);
+#if IS_ENABLED(CONFIG_IPV6)
+		} else if (family == AF_INET6) {
+			match = ipv6_prefix_equal(&key->addr.a6, &addr->a6,
+						  key->prefixlen);
+#endif
+		} else {
+			match = false;
+		}
+		/* Keep the key with the longest matching prefix. */
+		if (match && (!best_match ||
+			      key->prefixlen > best_match->prefixlen))
+			best_match = key;
+	}
+	return best_match;
+}
+
+/* Parse MD5 Signature option */
+static const u8 *tcp_parse_md5sig_option(const struct tcphdr *th)
+{
+	int length = (th->doff << 2) - sizeof(*th);
+	const u8 *ptr = (const u8 *)(th + 1);
+
+	/* Once fewer than TCPOLEN_MD5SIG bytes remain, no MD5 option can
+	 * follow; stopping here also avoids reading the size byte past the
+	 * end of the option space when only one byte is left.
+	 */
+	while (length >= TCPOLEN_MD5SIG) {
+		int opcode = *ptr++;
+		int opsize;
+
+		switch (opcode) {
+		case TCPOPT_EOL:
+			return NULL;
+		case TCPOPT_NOP:
+			length--;
+			continue;
+		default:
+			opsize = *ptr++;
+			if (opsize < 2 || opsize > length)
+				return NULL;
+			if (opcode == TCPOPT_MD5SIG)
+				return opsize == TCPOLEN_MD5SIG ? ptr : NULL;
+		}
+		ptr += opsize - 2;
+		length -= opsize;
+	}
+	return NULL;
+}
+
+#if IS_ENABLED(CONFIG_IPV6)
+static struct tcp_md5sig_key *tcp_v6_md5_do_lookup(const struct sock *sk,
+						   const struct in6_addr *addr)
+{
+	return tcp_md5_do_lookup(sk, (const union tcp_md5_addr *)addr, AF_INET6);
+}
+#endif /* IS_ENABLED(CONFIG_IPV6) */
+
+static int tcp_md5_hash_skb_data(struct tcp_md5sig_pool *hp,
+				 const struct sk_buff *skb,
+				 unsigned int header_len)
+{
+	struct scatterlist sg;
+	const struct tcphdr *tp = tcp_hdr(skb);
+	struct ahash_request *req = hp->md5_req;
+	unsigned int i;
+	const unsigned int head_data_len = skb_headlen(skb) > header_len ?
+					   skb_headlen(skb) - header_len : 0;
+	const struct skb_shared_info *shi = skb_shinfo(skb);
+	struct sk_buff *frag_iter;
+
+	sg_init_table(&sg, 1);
+	/* Linear data after the header, then page frags, then the frag list. */
+	sg_set_buf(&sg, ((u8 *)tp) + header_len, head_data_len);
+	ahash_request_set_crypt(req, &sg, NULL, head_data_len);
+	if (crypto_ahash_update(req))
+		return 1;
+
+	for (i = 0; i < shi->nr_frags; ++i) {
+		const struct skb_frag_struct *f = &shi->frags[i];
+		unsigned int offset = f->page_offset;
+		struct page *page = skb_frag_page(f) + (offset >> PAGE_SHIFT);
+
+		sg_set_page(&sg, page, skb_frag_size(f),
+			    offset_in_page(offset));
+		ahash_request_set_crypt(req, &sg, NULL, skb_frag_size(f));
+		if (crypto_ahash_update(req))
+			return 1;
+	}
+	/* Each frag-list skb is hashed recursively, with no header to skip. */
+	skb_walk_frags(skb, frag_iter)
+		if (tcp_md5_hash_skb_data(hp, frag_iter, 0))
+			return 1;
+
+	return 0;
+}
+
+int tcp_v4_md5_send_response_prepare(struct sk_buff *skb, u8 flags,
+				     unsigned int remaining,
+				     struct tcp_out_options *opts,
+				     const struct sock *sk)
+{
+	const struct tcphdr *th = tcp_hdr(skb);
+	const struct iphdr *iph = ip_hdr(skb);
+	const __u8 *hash_location = NULL;
+
+	rcu_read_lock();
+	hash_location = tcp_parse_md5sig_option(th);
+	if (sk && sk_fullsock(sk)) {
+		opts->md5 = tcp_md5_do_lookup(sk,
+					      (union tcp_md5_addr *)&iph->saddr,
+					      AF_INET);
+	} else if (sk && sk->sk_state == TCP_TIME_WAIT) {
+		struct tcp_timewait_sock *tcptw = tcp_twsk(sk);
+
+		opts->md5 = tcp_twsk_md5_key(tcptw);
+	} else if (sk && sk->sk_state == TCP_NEW_SYN_RECV) {
+		opts->md5 = tcp_md5_do_lookup(sk,
+					      (union tcp_md5_addr *)&iph->saddr,
+					      AF_INET);
+	} else if (hash_location) {
+		unsigned char newhash[16];
+		struct sock *sk1;
+		int genhash;
+
+		/* active side is lost. Try to find listening socket through
+		 * source port, and then find md5 key through listening socket.
+		 * We do not lose security here:
+		 * Incoming packet is checked with md5 hash with finding key,
+		 * no RST generated if md5 hash doesn't match.
+		 */
+		sk1 = __inet_lookup_listener(dev_net(skb_dst(skb)->dev),
+					     &tcp_hashinfo, NULL, 0,
+					     iph->saddr,
+					     th->source, iph->daddr,
+					     ntohs(th->source), inet_iif(skb),
+					     tcp_v4_sdif(skb));
+		/* don't send rst if it can't find key */
+		if (!sk1)
+			goto out_err;
+
+		opts->md5 = tcp_md5_do_lookup(sk1, (union tcp_md5_addr *)
+					      &iph->saddr, AF_INET);
+		if (!opts->md5)
+			goto out_err;
+
+		genhash = tcp_v4_md5_hash_skb(newhash, opts->md5, NULL, skb);
+		if (genhash || memcmp(hash_location, newhash, 16) != 0)
+			goto out_err;
+	}
+	/* Keep rcu_read_lock() held: tcp_v4_md5_send_response_write() drops it. */
+	if (opts->md5)
+		return TCPOLEN_MD5SIG_ALIGNED;
+
+	rcu_read_unlock();
+	return 0;
+	/* -1: no key or hash mismatch; caller presumably sends no reply. */
+out_err:
+	rcu_read_unlock();
+	return -1;
+}
+
+void tcp_v4_md5_send_response_write(__be32 *topt, struct sk_buff *skb,
+				    struct tcphdr *t1,
+				    struct tcp_out_options *opts,
+				    const struct sock *sk)
+{
+	if (opts->md5) {
+		*topt++ = htonl((TCPOPT_NOP << 24) |
+				(TCPOPT_NOP << 16) |
+				(TCPOPT_MD5SIG << 8) |
+				TCPOLEN_MD5SIG);
+		/* Reply direction: received saddr is our daddr and vice versa. */
+		tcp_v4_md5_hash_hdr((__u8 *)topt, opts->md5,
+				    ip_hdr(skb)->saddr,
+				    ip_hdr(skb)->daddr, t1);
+		rcu_read_unlock();	/* taken in tcp_v4_md5_send_response_prepare() */
+	}
+}
+
+#if IS_ENABLED(CONFIG_IPV6)
+int tcp_v6_md5_send_response_prepare(struct sk_buff *skb, u8 flags,
+				     unsigned int remaining,
+				     struct tcp_out_options *opts,
+				     const struct sock *sk)
+{
+	const struct tcphdr *th = tcp_hdr(skb);
+	struct ipv6hdr *ipv6h = ipv6_hdr(skb);
+	const __u8 *hash_location = NULL;
+
+	rcu_read_lock();
+	hash_location = tcp_parse_md5sig_option(th);
+	if (sk && sk_fullsock(sk)) {
+		opts->md5 = tcp_v6_md5_do_lookup(sk, &ipv6h->saddr);
+	} else if (sk && sk->sk_state == TCP_TIME_WAIT) {
+		struct tcp_timewait_sock *tcptw = tcp_twsk(sk);
+
+		opts->md5 = tcp_twsk_md5_key(tcptw);
+	} else if (sk && sk->sk_state == TCP_NEW_SYN_RECV) {
+		opts->md5 = tcp_v6_md5_do_lookup(sk, &ipv6h->saddr);
+	} else if (hash_location) {
+		unsigned char newhash[16];
+		struct sock *sk1;
+		int genhash;
+
+		/* active side is lost. Try to find listening socket through
+		 * source port, and then find md5 key through listening socket.
+		 * We do not lose security here:
+		 * Incoming packet is checked with md5 hash with finding key,
+		 * no RST generated if md5 hash doesn't match.
+		 */
+		sk1 = inet6_lookup_listener(dev_net(skb_dst(skb)->dev),
+					    &tcp_hashinfo, NULL, 0,
+					    &ipv6h->saddr,
+					    th->source, &ipv6h->daddr,
+					    ntohs(th->source), tcp_v6_iif(skb),
+					    tcp_v6_sdif(skb));
+		if (!sk1)
+			goto out_err;
+
+		opts->md5 = tcp_v6_md5_do_lookup(sk1, &ipv6h->saddr);
+		if (!opts->md5)
+			goto out_err;
+
+		genhash = tcp_v6_md5_hash_skb(newhash, opts->md5, NULL, skb);
+		if (genhash || memcmp(hash_location, newhash, 16) != 0)
+			goto out_err;
+	}
+	/* Keep rcu_read_lock() held: tcp_v6_md5_send_response_write() drops it. */
+	if (opts->md5)
+		return TCPOLEN_MD5SIG_ALIGNED;
+
+	rcu_read_unlock();
+	return 0;
+	/* -1: no key or hash mismatch; caller presumably sends no reply. */
+out_err:
+	rcu_read_unlock();
+	return -1;
+}
+EXPORT_SYMBOL_GPL(tcp_v6_md5_send_response_prepare);
+
+void tcp_v6_md5_send_response_write(__be32 *topt, struct sk_buff *skb,
+				    struct tcphdr *t1,
+				    struct tcp_out_options *opts,
+				    const struct sock *sk)
+{
+	if (opts->md5) {
+		*topt++ = htonl((TCPOPT_NOP << 24) | (TCPOPT_NOP << 16) |
+				(TCPOPT_MD5SIG << 8) | TCPOLEN_MD5SIG);
+		tcp_v6_md5_hash_hdr((__u8 *)topt, opts->md5,
+				    &ipv6_hdr(skb)->saddr,
+				    &ipv6_hdr(skb)->daddr, t1);
+		/* Drop the RCU lock taken in tcp_v6_md5_send_response_prepare() */
+		rcu_read_unlock();
+	}
+}
+EXPORT_SYMBOL_GPL(tcp_v6_md5_send_response_write);
+#endif /* IS_ENABLED(CONFIG_IPV6) */
+
+struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk,
+					 const struct sock *addr_sk)
+{
+	/* Key that @sk uses towards @addr_sk's remote IPv4 address. */
+	return tcp_md5_do_lookup(sk,
+				 (const union tcp_md5_addr *)&addr_sk->sk_daddr,
+				 AF_INET);
+}
+EXPORT_SYMBOL(tcp_v4_md5_lookup);
+
+int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key,
+			const struct sock *sk,
+			const struct sk_buff *skb)
+{
+	struct tcp_md5sig_pool *hp;
+	struct ahash_request *req;
+	const struct tcphdr *th = tcp_hdr(skb);
+	__be32 saddr, daddr;
+
+	if (sk) { /* valid for establish/request sockets */
+		saddr = sk->sk_rcv_saddr;
+		daddr = sk->sk_daddr;
+	} else {
+		const struct iphdr *iph = ip_hdr(skb);
+
+		saddr = iph->saddr;
+		daddr = iph->daddr;
+	}
+
+	hp = tcp_get_md5sig_pool();
+	if (!hp)
+		goto clear_hash_noput;
+	req = hp->md5_req;
+
+	if (crypto_ahash_init(req))
+		goto clear_hash;
+	/* MD5(pseudo-header + fixed TCP header, data past the options, key) */
+	if (tcp_v4_md5_hash_headers(hp, daddr, saddr, th, skb->len))
+		goto clear_hash;
+	if (tcp_md5_hash_skb_data(hp, skb, th->doff << 2))
+		goto clear_hash;
+	if (tcp_md5_hash_key(hp, key))
+		goto clear_hash;
+	ahash_request_set_crypt(req, NULL, md5_hash, 0);
+	if (crypto_ahash_final(req))
+		goto clear_hash;
+
+	tcp_put_md5sig_pool();
+	return 0;
+	/* On failure the 16-byte digest is zeroed and 1 is returned. */
+clear_hash:
+	tcp_put_md5sig_pool();
+clear_hash_noput:
+	memset(md5_hash, 0, 16);
+	return 1;
+}
+EXPORT_SYMBOL(tcp_v4_md5_hash_skb);
+
+#if IS_ENABLED(CONFIG_IPV6)
+int tcp_v6_md5_hash_skb(char *md5_hash,
+			const struct tcp_md5sig_key *key,
+			const struct sock *sk,
+			const struct sk_buff *skb)
+{
+	const struct in6_addr *saddr, *daddr;
+	struct tcp_md5sig_pool *hp;
+	struct ahash_request *req;
+	const struct tcphdr *th = tcp_hdr(skb);
+
+	if (sk) { /* valid for establish/request sockets */
+		saddr = &sk->sk_v6_rcv_saddr;
+		daddr = &sk->sk_v6_daddr;
+	} else {
+		const struct ipv6hdr *ip6h = ipv6_hdr(skb);
+
+		saddr = &ip6h->saddr;
+		daddr = &ip6h->daddr;
+	}
+
+	hp = tcp_get_md5sig_pool();
+	if (!hp)
+		goto clear_hash_noput;
+	req = hp->md5_req;
+
+	if (crypto_ahash_init(req))
+		goto clear_hash;
+	/* MD5(pseudo-header + fixed TCP header, data past the options, key) */
+	if (tcp_v6_md5_hash_headers(hp, daddr, saddr, th, skb->len))
+		goto clear_hash;
+	if (tcp_md5_hash_skb_data(hp, skb, th->doff << 2))
+		goto clear_hash;
+	if (tcp_md5_hash_key(hp, key))
+		goto clear_hash;
+	ahash_request_set_crypt(req, NULL, md5_hash, 0);
+	if (crypto_ahash_final(req))
+		goto clear_hash;
+
+	tcp_put_md5sig_pool();
+	return 0;
+	/* On failure the 16-byte digest is zeroed and 1 is returned. */
+clear_hash:
+	tcp_put_md5sig_pool();
+clear_hash_noput:
+	memset(md5_hash, 0, 16);
+	return 1;
+}
+EXPORT_SYMBOL_GPL(tcp_v6_md5_hash_skb);
+#endif /* IS_ENABLED(CONFIG_IPV6) */
+
+/* Called with rcu_read_lock(); returns true if the segment must be dropped */
+bool tcp_v4_inbound_md5_hash(const struct sock *sk,
+			     const struct sk_buff *skb)
+{
+	/* This gets called for each TCP segment that arrives
+	 * so we want to be efficient.
+	 * We have 3 drop cases:
+	 * o No MD5 hash and one expected.
+	 * o MD5 hash and we're not expecting one.
+	 * o MD5 hash and it's wrong.
+	 */
+	const __u8 *hash_location = NULL;
+	struct tcp_md5sig_key *hash_expected;
+	const struct iphdr *iph = ip_hdr(skb);
+	const struct tcphdr *th = tcp_hdr(skb);
+	int genhash;
+	unsigned char newhash[16];
+
+	hash_expected = tcp_md5_do_lookup(sk, (union tcp_md5_addr *)&iph->saddr,
+					  AF_INET);
+	hash_location = tcp_parse_md5sig_option(th);
+
+	/* We've parsed the options - do we have a hash? */
+	if (!hash_expected && !hash_location)
+		return false;
+
+	if (hash_expected && !hash_location) {
+		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND);
+		return true;
+	}
+
+	if (!hash_expected && hash_location) {
+		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED);
+		return true;
+	}
+
+	/* Okay, so this is hash_expected and hash_location -
+	 * so we need to calculate the checksum.
+	 */
+	genhash = tcp_v4_md5_hash_skb(newhash,
+				      hash_expected,
+				      NULL, skb);
+
+	if (genhash || memcmp(hash_location, newhash, 16) != 0) {
+		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE);
+		net_info_ratelimited("MD5 Hash failed for (%pI4, %d)->(%pI4, %d)%s\n",
+				     &iph->saddr, ntohs(th->source),
+				     &iph->daddr, ntohs(th->dest),
+				     genhash ? " tcp_v4_calc_md5_hash failed"
+				     : "");
+		return true;
+	}
+	return false;
+}
+
+#if IS_ENABLED(CONFIG_IPV6)
+bool tcp_v6_inbound_md5_hash(const struct sock *sk,
+			     const struct sk_buff *skb)
+{
+	const __u8 *hash_location = NULL;
+	struct tcp_md5sig_key *hash_expected;
+	const struct ipv6hdr *ip6h = ipv6_hdr(skb);
+	const struct tcphdr *th = tcp_hdr(skb);
+	int genhash;
+	u8 newhash[16];
+
+	hash_expected = tcp_v6_md5_do_lookup(sk, &ip6h->saddr);
+	hash_location = tcp_parse_md5sig_option(th);
+
+	/* We've parsed the options - do we have a hash? */
+	if (!hash_expected && !hash_location)
+		return false;
+
+	if (hash_expected && !hash_location) {
+		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND);
+		return true;
+	}
+
+	if (!hash_expected && hash_location) {
+		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED);
+		return true;
+	}
+
+	/* Both present: check the signature (true means drop the segment) */
+	genhash = tcp_v6_md5_hash_skb(newhash,
+				      hash_expected,
+				      NULL, skb);
+
+	if (genhash || memcmp(hash_location, newhash, 16) != 0) {
+		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE);
+		net_info_ratelimited("MD5 Hash %s for [%pI6c]:%u->[%pI6c]:%u\n",
+				     genhash ? "failed" : "mismatch",
+				     &ip6h->saddr, ntohs(th->source),
+				     &ip6h->daddr, ntohs(th->dest));
+		return true;
+	}
+
+	return false;
+}
+EXPORT_SYMBOL_GPL(tcp_v6_inbound_md5_hash);
+#endif /* IS_ENABLED(CONFIG_IPV6) */
+
+void tcp_v4_md5_destroy_sock(struct sock *sk)
+{
+	struct tcp_md5sig_info *md5sig;
+
+	md5sig = rcu_dereference_protected(tcp_sk(sk)->md5sig_info, 1);
+	if (md5sig) {	/* free every key, then the list head, via RCU */
+		tcp_clear_md5_list(sk);
+		kfree_rcu(md5sig, rcu);
+		RCU_INIT_POINTER(tcp_sk(sk)->md5sig_info, NULL);
+	}
+}
+
+void tcp_v4_md5_syn_recv_sock(const struct sock *listener, struct sock *sk)
+{
+	struct inet_sock *inet = inet_sk(sk);
+	struct tcp_md5sig_key *key;
+
+	/* Copy over the listener's best (longest-prefix) key for the peer */
+	key = tcp_md5_do_lookup(listener, (union tcp_md5_addr *)&inet->inet_daddr,
+				AF_INET);
+	if (key) {
+		/* We're using one, so install a matching /32 key for the
+		 * peer's exact address on the new socket. If we fail to get
+		 * memory, the key is simply not copied across (the return
+		 * value of tcp_md5_do_add() is ignored).
+		 */
+		tcp_md5_do_add(sk, (union tcp_md5_addr *)&inet->inet_daddr,
+			       AF_INET, 32, key->key, key->keylen, GFP_ATOMIC);
+		sk_nocaps_add(sk, NETIF_F_GSO_MASK);
+	}
+}
+
+#if IS_ENABLED(CONFIG_IPV6)
+void tcp_v6_md5_syn_recv_sock(const struct sock *listener, struct sock *sk)
+{
+	struct tcp_md5sig_key *key;
+
+	/* Copy over the MD5 key from the original socket */
+	key = tcp_v6_md5_do_lookup(listener, &sk->sk_v6_daddr);
+	if (key) {
+		/* We're using one, so create a matching key on the new
+		 * socket, deriving the allocation mask from the listener as
+		 * tcp_v6_syn_recv_sock() did. If we fail to get memory,
+		 * then we end up not copying the key across. Shucks.
+		 */
+		tcp_md5_do_add(sk, (union tcp_md5_addr *)&sk->sk_v6_daddr,
+			       AF_INET6, 128, key->key, key->keylen,
+			       sk_gfp_mask(listener, GFP_ATOMIC));
+	}
+}
+EXPORT_SYMBOL_GPL(tcp_v6_md5_syn_recv_sock);
+
+struct tcp_md5sig_key *tcp_v6_md5_lookup(const struct sock *sk,
+					 const struct sock *addr_sk)
+{
+	return tcp_v6_md5_do_lookup(sk, &addr_sk->sk_v6_daddr);
+}
+EXPORT_SYMBOL_GPL(tcp_v6_md5_lookup);
+#endif
+
+void tcp_md5_time_wait(struct sock *sk, struct inet_timewait_sock *tw)
+{
+	struct tcp_timewait_sock *tcptw = tcp_twsk((struct sock *)tw);
+	struct tcp_sock *tp = tcp_sk(sk);
+	struct tcp_md5sig_key *key;
+
+	/* The timewait bucket does not have the key DB from the sock
+	 * structure, so make a quick copy of the md5 key in use (if any)
+	 * for the timewait ACK generating code. In this file keys are only
+	 * added by tcp_md5_do_add(), which already populated the hash pool.
+	 */
+	tcptw->tw_md5_key = NULL;
+	key = tp->af_specific->md5_lookup(sk, sk);
+	if (key) {
+		tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC);
+		BUG_ON(tcptw->tw_md5_key && !tcp_alloc_md5sig_pool());
+	}
+}
+
+static void tcp_diag_md5sig_fill(struct tcp_diag_md5sig *info,
+				 const struct tcp_md5sig_key *key)
+{
+	/* Copy one key into its diag record; other families leave addr as is. */
+	info->tcpm_family = key->family;
+	info->tcpm_prefixlen = key->prefixlen;
+	info->tcpm_keylen = key->keylen;
+	memcpy(info->tcpm_key, key->key, key->keylen);
+
+	if (key->family == AF_INET)
+		info->tcpm_addr[0] = key->addr.a4.s_addr;
+#if IS_ENABLED(CONFIG_IPV6)
+	else if (key->family == AF_INET6)
+		memcpy(info->tcpm_addr, &key->addr.a6, sizeof(info->tcpm_addr));
+#endif
+}
+
+static int tcp_diag_put_md5sig(struct sk_buff *skb,
+			       const struct tcp_md5sig_info *md5sig)
+{
+	const struct tcp_md5sig_key *key;
+	struct tcp_diag_md5sig *info;
+	struct nlattr *attr;
+	int md5sig_count = 0;
+
+	hlist_for_each_entry_rcu(key, &md5sig->head, node)
+		md5sig_count++;
+	if (md5sig_count == 0)
+		return 0;
+	/* Reserve one tcp_diag_md5sig record per key counted above. */
+	attr = nla_reserve(skb, INET_DIAG_MD5SIG,
+			   md5sig_count * sizeof(struct tcp_diag_md5sig));
+	if (!attr)
+		return -EMSGSIZE;
+	/* The list may change meanwhile: never fill more than we reserved. */
+	info = nla_data(attr);
+	memset(info, 0, md5sig_count * sizeof(struct tcp_diag_md5sig));
+	hlist_for_each_entry_rcu(key, &md5sig->head, node) {
+		tcp_diag_md5sig_fill(info++, key);
+		if (--md5sig_count == 0)
+			break;
+	}
+
+	return 0;
+}
+
+int tcp_md5_diag_get_aux(struct sock *sk, bool net_admin, struct sk_buff *skb)
+{
+	struct tcp_md5sig_info *md5sig;
+	int err = 0;
+
+	/* Only dump key material when @net_admin is set. */
+	if (!net_admin)
+		return 0;
+
+	rcu_read_lock();
+	md5sig = rcu_dereference(tcp_sk(sk)->md5sig_info);
+	if (md5sig)
+		err = tcp_diag_put_md5sig(skb, md5sig);
+	rcu_read_unlock();
+
+	return err < 0 ? err : 0;
+}
+EXPORT_SYMBOL_GPL(tcp_md5_diag_get_aux);
+
+int tcp_md5_diag_get_aux_size(struct sock *sk, bool net_admin)
+{
+	const struct tcp_md5sig_info *md5sig;
+	const struct tcp_md5sig_key *key;
+	size_t nkeys = 0;
+
+	/* Nothing is dumped without net_admin, or for non-full sockets. */
+	if (!net_admin || !sk_fullsock(sk))
+		return 0;
+
+	/* One tcp_diag_md5sig record per key currently on the list. */
+	rcu_read_lock();
+	md5sig = rcu_dereference(tcp_sk(sk)->md5sig_info);
+	if (md5sig) {
+		hlist_for_each_entry_rcu(key, &md5sig->head, node)
+			nkeys++;
+	}
+	rcu_read_unlock();
+
+	return nla_total_size(nkeys *
+			      sizeof(struct tcp_diag_md5sig));
+}
+EXPORT_SYMBOL_GPL(tcp_md5_diag_get_aux_size);
+
+const struct tcp_sock_af_ops tcp_sock_ipv4_specific = {
+	.md5_lookup	= tcp_v4_md5_lookup,
+	.calc_md5_hash	= tcp_v4_md5_hash_skb,
+	.md5_parse	= tcp_v4_parse_md5_keys,
+};
+
+#if IS_ENABLED(CONFIG_IPV6)
+const struct tcp_sock_af_ops tcp_sock_ipv6_specific = {
+	.md5_lookup	=	tcp_v6_md5_lookup,
+	.calc_md5_hash	=	tcp_v6_md5_hash_skb,
+	.md5_parse	=	tcp_v6_parse_md5_keys,
+};
+EXPORT_SYMBOL_GPL(tcp_sock_ipv6_specific);
+/* v4-mapped variant: IPv4 lookup and hashing, IPv6 setsockopt parsing. */
+const struct tcp_sock_af_ops tcp_sock_ipv6_mapped_specific = {
+	.md5_lookup	=	tcp_v4_md5_lookup,
+	.calc_md5_hash	=	tcp_v4_md5_hash_skb,
+	.md5_parse	=	tcp_v6_parse_md5_keys,
+};
+EXPORT_SYMBOL_GPL(tcp_sock_ipv6_mapped_specific);
+#endif /* IS_ENABLED(CONFIG_IPV6) */
diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c
index 5e08dce49a00..072dbcebfbaf 100644
--- a/net/ipv4/tcp_minisocks.c
+++ b/net/ipv4/tcp_minisocks.c
@@ -22,6 +22,7 @@
 #include <linux/module.h>
 #include <linux/slab.h>
 #include <linux/sysctl.h>
+#include <linux/tcp_md5.h>
 #include <linux/workqueue.h>
 #include <linux/static_key.h>
 #include <net/tcp.h>
@@ -295,21 +296,7 @@ void tcp_time_wait(struct sock *sk, int state, int timeo)
 			INIT_HLIST_HEAD(&tp->tcp_option_list);
 		}
 #ifdef CONFIG_TCP_MD5SIG
-		/*
-		 * The timewait bucket does not have the key DB from the
-		 * sock structure. We just make a quick copy of the
-		 * md5 key being used (if indeed we are using one)
-		 * so the timewait ack generating code has the key.
-		 */
-		do {
-			struct tcp_md5sig_key *key;
-			tcptw->tw_md5_key = NULL;
-			key = tp->af_specific->md5_lookup(sk, sk);
-			if (key) {
-				tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC);
-				BUG_ON(tcptw->tw_md5_key && !tcp_alloc_md5sig_pool());
-			}
-		} while (0);
+		tcp_md5_time_wait(sk, tw);
 #endif
 
 		/* Get the TIME_WAIT timeout firing. */
@@ -348,8 +335,7 @@ void tcp_twsk_destructor(struct sock *sk)
 	struct tcp_timewait_sock *twsk = tcp_twsk(sk);
 
 #ifdef CONFIG_TCP_MD5SIG
-	if (twsk->tw_md5_key)
-		kfree_rcu(twsk->tw_md5_key, rcu);
+	tcp_md5_twsk_destructor(twsk);
 #endif
 
 	if (unlikely(!hlist_empty(&twsk->tcp_option_list)))
@@ -538,8 +524,7 @@ struct sock *tcp_create_openreq_child(const struct sock *sk,
 		newtp->tsoffset = treq->ts_off;
 #ifdef CONFIG_TCP_MD5SIG
 		newtp->md5sig_info = NULL;	/*XXX*/
-		if (newtp->af_specific->md5_lookup(sk, newsk))
-			newtp->tcp_header_len += TCPOLEN_MD5SIG_ALIGNED;
+		tcp_md5_add_header_len(sk, newsk);
 #endif
 		if (unlikely(!hlist_empty(&treq->tcp_option_list)))
 			newtp->tcp_header_len += tcp_extopt_add_header(req_to_sk(req), newsk);
diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c
index 97e6aecc03eb..c7fb7a0e1610 100644
--- a/net/ipv4/tcp_output.c
+++ b/net/ipv4/tcp_output.c
@@ -42,6 +42,7 @@
 #include <linux/gfp.h>
 #include <linux/module.h>
 #include <linux/static_key.h>
+#include <linux/tcp_md5.h>
 
 #include <trace/events/tcp.h>
 
@@ -3243,8 +3244,7 @@ static void tcp_connect_init(struct sock *sk)
 		tp->tcp_header_len += TCPOLEN_TSTAMP_ALIGNED;
 
 #ifdef CONFIG_TCP_MD5SIG
-	if (tp->af_specific->md5_lookup(sk, sk))
-		tp->tcp_header_len += TCPOLEN_MD5SIG_ALIGNED;
+	tcp_md5_add_header_len(sk, sk);
 #endif
 
 	if (unlikely(!hlist_empty(&tp->tcp_option_list)))
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index 8c6d0362299e..26b19475d91c 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -43,6 +43,7 @@
 #include <linux/ipv6.h>
 #include <linux/icmpv6.h>
 #include <linux/random.h>
+#include <linux/tcp_md5.h>
 
 #include <net/tcp.h>
 #include <net/ndisc.h>
@@ -79,10 +80,6 @@ static int	tcp_v6_do_rcv(struct sock *sk, struct sk_buff *skb);
 
 static const struct inet_connection_sock_af_ops ipv6_mapped;
 static const struct inet_connection_sock_af_ops ipv6_specific;
-#ifdef CONFIG_TCP_MD5SIG
-static const struct tcp_sock_af_ops tcp_sock_ipv6_specific;
-static const struct tcp_sock_af_ops tcp_sock_ipv6_mapped_specific;
-#endif
 
 static void inet6_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb)
 {
@@ -500,218 +497,6 @@ static void tcp_v6_reqsk_destructor(struct request_sock *req)
 	kfree_skb(inet_rsk(req)->pktopts);
 }
 
-#ifdef CONFIG_TCP_MD5SIG
-static struct tcp_md5sig_key *tcp_v6_md5_do_lookup(const struct sock *sk,
-						   const struct in6_addr *addr)
-{
-	return tcp_md5_do_lookup(sk, (union tcp_md5_addr *)addr, AF_INET6);
-}
-
-static struct tcp_md5sig_key *tcp_v6_md5_lookup(const struct sock *sk,
-						const struct sock *addr_sk)
-{
-	return tcp_v6_md5_do_lookup(sk, &addr_sk->sk_v6_daddr);
-}
-
-static int tcp_v6_parse_md5_keys(struct sock *sk, int optname,
-				 char __user *optval, int optlen)
-{
-	struct tcp_md5sig cmd;
-	struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&cmd.tcpm_addr;
-	u8 prefixlen;
-
-	if (optlen < sizeof(cmd))
-		return -EINVAL;
-
-	if (copy_from_user(&cmd, optval, sizeof(cmd)))
-		return -EFAULT;
-
-	if (sin6->sin6_family != AF_INET6)
-		return -EINVAL;
-
-	if (optname == TCP_MD5SIG_EXT &&
-	    cmd.tcpm_flags & TCP_MD5SIG_FLAG_PREFIX) {
-		prefixlen = cmd.tcpm_prefixlen;
-		if (prefixlen > 128 || (ipv6_addr_v4mapped(&sin6->sin6_addr) &&
-					prefixlen > 32))
-			return -EINVAL;
-	} else {
-		prefixlen = ipv6_addr_v4mapped(&sin6->sin6_addr) ? 32 : 128;
-	}
-
-	if (!cmd.tcpm_keylen) {
-		if (ipv6_addr_v4mapped(&sin6->sin6_addr))
-			return tcp_md5_do_del(sk, (union tcp_md5_addr *)&sin6->sin6_addr.s6_addr32[3],
-					      AF_INET, prefixlen);
-		return tcp_md5_do_del(sk, (union tcp_md5_addr *)&sin6->sin6_addr,
-				      AF_INET6, prefixlen);
-	}
-
-	if (cmd.tcpm_keylen > TCP_MD5SIG_MAXKEYLEN)
-		return -EINVAL;
-
-	if (ipv6_addr_v4mapped(&sin6->sin6_addr))
-		return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr.s6_addr32[3],
-				      AF_INET, prefixlen, cmd.tcpm_key,
-				      cmd.tcpm_keylen, GFP_KERNEL);
-
-	return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr,
-			      AF_INET6, prefixlen, cmd.tcpm_key,
-			      cmd.tcpm_keylen, GFP_KERNEL);
-}
-
-static int tcp_v6_md5_hash_headers(struct tcp_md5sig_pool *hp,
-				   const struct in6_addr *daddr,
-				   const struct in6_addr *saddr,
-				   const struct tcphdr *th, int nbytes)
-{
-	struct tcp6_pseudohdr *bp;
-	struct scatterlist sg;
-	struct tcphdr *_th;
-
-	bp = hp->scratch;
-	/* 1. TCP pseudo-header (RFC2460) */
-	bp->saddr = *saddr;
-	bp->daddr = *daddr;
-	bp->protocol = cpu_to_be32(IPPROTO_TCP);
-	bp->len = cpu_to_be32(nbytes);
-
-	_th = (struct tcphdr *)(bp + 1);
-	memcpy(_th, th, sizeof(*th));
-	_th->check = 0;
-
-	sg_init_one(&sg, bp, sizeof(*bp) + sizeof(*th));
-	ahash_request_set_crypt(hp->md5_req, &sg, NULL,
-				sizeof(*bp) + sizeof(*th));
-	return crypto_ahash_update(hp->md5_req);
-}
-
-static int tcp_v6_md5_hash_hdr(char *md5_hash, const struct tcp_md5sig_key *key,
-			       const struct in6_addr *daddr, struct in6_addr *saddr,
-			       const struct tcphdr *th)
-{
-	struct tcp_md5sig_pool *hp;
-	struct ahash_request *req;
-
-	hp = tcp_get_md5sig_pool();
-	if (!hp)
-		goto clear_hash_noput;
-	req = hp->md5_req;
-
-	if (crypto_ahash_init(req))
-		goto clear_hash;
-	if (tcp_v6_md5_hash_headers(hp, daddr, saddr, th, th->doff << 2))
-		goto clear_hash;
-	if (tcp_md5_hash_key(hp, key))
-		goto clear_hash;
-	ahash_request_set_crypt(req, NULL, md5_hash, 0);
-	if (crypto_ahash_final(req))
-		goto clear_hash;
-
-	tcp_put_md5sig_pool();
-	return 0;
-
-clear_hash:
-	tcp_put_md5sig_pool();
-clear_hash_noput:
-	memset(md5_hash, 0, 16);
-	return 1;
-}
-
-static int tcp_v6_md5_hash_skb(char *md5_hash,
-			       const struct tcp_md5sig_key *key,
-			       const struct sock *sk,
-			       const struct sk_buff *skb)
-{
-	const struct in6_addr *saddr, *daddr;
-	struct tcp_md5sig_pool *hp;
-	struct ahash_request *req;
-	const struct tcphdr *th = tcp_hdr(skb);
-
-	if (sk) { /* valid for establish/request sockets */
-		saddr = &sk->sk_v6_rcv_saddr;
-		daddr = &sk->sk_v6_daddr;
-	} else {
-		const struct ipv6hdr *ip6h = ipv6_hdr(skb);
-		saddr = &ip6h->saddr;
-		daddr = &ip6h->daddr;
-	}
-
-	hp = tcp_get_md5sig_pool();
-	if (!hp)
-		goto clear_hash_noput;
-	req = hp->md5_req;
-
-	if (crypto_ahash_init(req))
-		goto clear_hash;
-
-	if (tcp_v6_md5_hash_headers(hp, daddr, saddr, th, skb->len))
-		goto clear_hash;
-	if (tcp_md5_hash_skb_data(hp, skb, th->doff << 2))
-		goto clear_hash;
-	if (tcp_md5_hash_key(hp, key))
-		goto clear_hash;
-	ahash_request_set_crypt(req, NULL, md5_hash, 0);
-	if (crypto_ahash_final(req))
-		goto clear_hash;
-
-	tcp_put_md5sig_pool();
-	return 0;
-
-clear_hash:
-	tcp_put_md5sig_pool();
-clear_hash_noput:
-	memset(md5_hash, 0, 16);
-	return 1;
-}
-
-#endif
-
-static bool tcp_v6_inbound_md5_hash(const struct sock *sk,
-				    const struct sk_buff *skb)
-{
-#ifdef CONFIG_TCP_MD5SIG
-	const __u8 *hash_location = NULL;
-	struct tcp_md5sig_key *hash_expected;
-	const struct ipv6hdr *ip6h = ipv6_hdr(skb);
-	const struct tcphdr *th = tcp_hdr(skb);
-	int genhash;
-	u8 newhash[16];
-
-	hash_expected = tcp_v6_md5_do_lookup(sk, &ip6h->saddr);
-	hash_location = tcp_parse_md5sig_option(th);
-
-	/* We've parsed the options - do we have a hash? */
-	if (!hash_expected && !hash_location)
-		return false;
-
-	if (hash_expected && !hash_location) {
-		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND);
-		return true;
-	}
-
-	if (!hash_expected && hash_location) {
-		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED);
-		return true;
-	}
-
-	/* check the signature */
-	genhash = tcp_v6_md5_hash_skb(newhash,
-				      hash_expected,
-				      NULL, skb);
-
-	if (genhash || memcmp(hash_location, newhash, 16) != 0) {
-		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE);
-		net_info_ratelimited("MD5 Hash %s for [%pI6c]:%u->[%pI6c]:%u\n",
-				     genhash ? "failed" : "mismatch",
-				     &ip6h->saddr, ntohs(th->source),
-				     &ip6h->daddr, ntohs(th->dest));
-		return true;
-	}
-#endif
-	return false;
-}
-
 static void tcp_v6_init_req(struct request_sock *req,
 			    const struct sock *sk_listener,
 			    struct sk_buff *skb)
@@ -786,56 +571,24 @@ static void tcp_v6_send_response(const struct sock *sk, struct sk_buff *skb, u32
 	__be32 *topt;
 	struct hlist_head *extopt_list = NULL;
 	struct tcp_out_options extraopts;
-#ifdef CONFIG_TCP_MD5SIG
-	struct tcp_md5sig_key *key = NULL;
-	const __u8 *hash_location = NULL;
-	struct ipv6hdr *ipv6h = ipv6_hdr(skb);
-#endif
+
+	memset(&extraopts, 0, sizeof(extraopts));
 
 	if (tsecr)
 		tot_len += TCPOLEN_TSTAMP_ALIGNED;
 #ifdef CONFIG_TCP_MD5SIG
-	rcu_read_lock();
-	hash_location = tcp_parse_md5sig_option(th);
-	if (sk && sk_fullsock(sk)) {
-		key = tcp_v6_md5_do_lookup(sk, &ipv6h->saddr);
-	} else if (sk && sk->sk_state == TCP_TIME_WAIT) {
-		struct tcp_timewait_sock *tcptw = tcp_twsk(sk);
-
-		key = tcp_twsk_md5_key(tcptw);
-	} else if (sk && sk->sk_state == TCP_NEW_SYN_RECV) {
-		key = tcp_v6_md5_do_lookup(sk, &ipv6h->saddr);
-	} else if (hash_location) {
-		unsigned char newhash[16];
-		struct sock *sk1 = NULL;
-		int genhash;
-
-		/* active side is lost. Try to find listening socket through
-		 * source port, and then find md5 key through listening socket.
-		 * we are not loose security here:
-		 * Incoming packet is checked with md5 hash with finding key,
-		 * no RST generated if md5 hash doesn't match.
-		 */
-		sk1 = inet6_lookup_listener(dev_net(skb_dst(skb)->dev),
-					    &tcp_hashinfo, NULL, 0,
-					    &ipv6h->saddr,
-					    th->source, &ipv6h->daddr,
-					    ntohs(th->source), tcp_v6_iif(skb),
-					    tcp_v6_sdif(skb));
-		if (!sk1)
-			goto out;
+{
+	int ret;
 
-		key = tcp_v6_md5_do_lookup(sk1, &ipv6h->saddr);
-		if (!key)
-			goto out;
+	ret = tcp_v6_md5_send_response_prepare(skb, 0,
+					       MAX_TCP_OPTION_SPACE - tot_len,
+					       &extraopts, sk);
 
-		genhash = tcp_v6_md5_hash_skb(newhash, key, NULL, skb);
-		if (genhash || memcmp(hash_location, newhash, 16) != 0)
-			goto out;
-	}
+	if (ret == -1)
+		goto out;
 
-	if (key)
-		tot_len += TCPOLEN_MD5SIG_ALIGNED;
+	tot_len += ret;
+}
 #endif
 
 	if (sk)
@@ -849,8 +602,6 @@ static void tcp_v6_send_response(const struct sock *sk, struct sk_buff *skb, u32
 		if (!rst || !th->ack)
 			extraflags |= TCPHDR_ACK;
 
-		memset(&extraopts, 0, sizeof(extraopts));
-
 		used = tcp_extopt_response_prepare(skb, extraflags, remaining,
 						   &extraopts, sk);
 
@@ -888,13 +639,8 @@ static void tcp_v6_send_response(const struct sock *sk, struct sk_buff *skb, u32
 	}
 
 #ifdef CONFIG_TCP_MD5SIG
-	if (key) {
-		*topt++ = htonl((TCPOPT_NOP << 24) | (TCPOPT_NOP << 16) |
-				(TCPOPT_MD5SIG << 8) | TCPOLEN_MD5SIG);
-		tcp_v6_md5_hash_hdr((__u8 *)topt, key,
-				    &ipv6_hdr(skb)->saddr,
-				    &ipv6_hdr(skb)->daddr, t1);
-	}
+	if (extraopts.md5)
+		tcp_v6_md5_send_response_write(topt, skb, t1, &extraopts, sk);
 #endif
 
 	if (unlikely(extopt_list && !hlist_empty(extopt_list)))
@@ -942,10 +688,6 @@ static void tcp_v6_send_response(const struct sock *sk, struct sk_buff *skb, u32
 
 out:
 	kfree_skb(buff);
-
-#ifdef CONFIG_TCP_MD5SIG
-	rcu_read_unlock();
-#endif
 }
 
 static void tcp_v6_send_reset(const struct sock *sk, struct sk_buff *skb)
@@ -1071,9 +813,6 @@ static struct sock *tcp_v6_syn_recv_sock(const struct sock *sk, struct sk_buff *
 	struct inet_sock *newinet;
 	struct tcp_sock *newtp;
 	struct sock *newsk;
-#ifdef CONFIG_TCP_MD5SIG
-	struct tcp_md5sig_key *key;
-#endif
 	struct flowi6 fl6;
 
 	if (skb->protocol == htons(ETH_P_IP)) {
@@ -1218,18 +957,7 @@ static struct sock *tcp_v6_syn_recv_sock(const struct sock *sk, struct sk_buff *
 	newinet->inet_rcv_saddr = LOOPBACK4_IPV6;
 
 #ifdef CONFIG_TCP_MD5SIG
-	/* Copy over the MD5 key from the original socket */
-	key = tcp_v6_md5_do_lookup(sk, &newsk->sk_v6_daddr);
-	if (key) {
-		/* We're using one, so create a matching key
-		 * on the newsk structure. If we fail to get
-		 * memory, then we end up not copying the key
-		 * across. Shucks.
-		 */
-		tcp_md5_do_add(newsk, (union tcp_md5_addr *)&newsk->sk_v6_daddr,
-			       AF_INET6, 128, key->key, key->keylen,
-			       sk_gfp_mask(sk, GFP_ATOMIC));
-	}
+	tcp_v6_md5_syn_recv_sock(sk, newsk);
 #endif
 
 	if (__inet_inherit_port(sk, newsk) < 0) {
@@ -1691,14 +1419,6 @@ static const struct inet_connection_sock_af_ops ipv6_specific = {
 	.mtu_reduced	   = tcp_v6_mtu_reduced,
 };
 
-#ifdef CONFIG_TCP_MD5SIG
-static const struct tcp_sock_af_ops tcp_sock_ipv6_specific = {
-	.md5_lookup	=	tcp_v6_md5_lookup,
-	.calc_md5_hash	=	tcp_v6_md5_hash_skb,
-	.md5_parse	=	tcp_v6_parse_md5_keys,
-};
-#endif
-
 /*
  *	TCP over IPv4 via INET6 API
  */
@@ -1721,14 +1441,6 @@ static const struct inet_connection_sock_af_ops ipv6_mapped = {
 	.mtu_reduced	   = tcp_v4_mtu_reduced,
 };
 
-#ifdef CONFIG_TCP_MD5SIG
-static const struct tcp_sock_af_ops tcp_sock_ipv6_mapped_specific = {
-	.md5_lookup	=	tcp_v4_md5_lookup,
-	.calc_md5_hash	=	tcp_v4_md5_hash_skb,
-	.md5_parse	=	tcp_v6_parse_md5_keys,
-};
-#endif
-
 /* NOTE: A lot of things set to zero explicitly by call to
  *       sk_alloc() so need not be done here.
  */
-- 
2.16.1

Powered by blists - more mailing lists

Powered by Openwall GNU/*/Linux Powered by OpenVZ