Skip to content

Commit

Permalink
Merge branch 'mptcp-refactor-token-container'
Browse files Browse the repository at this point in the history
Paolo Abeni says:

====================
mptcp: refactor token container

Currently the msk sockets are stored in a single radix tree, protected by a
global spin_lock. This series moves to an hash table, allocated at boot time,
with per bucker spin_lock - alike inet_hashtables, but using a different key:
the token itself.

The above improves scalability, as write operations will have a far later chance
to compete for lock acquisition, allows lockless lookup, and will allow
easier msk traversing - e.g. for diag interface implementation's sake.

This also introduces trivial, related, kunit tests and move the existing in
kernel's one to kunit.

v1 -> v2:
 - fixed a few extra and sparse warns
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
  • Loading branch information
David S. Miller committed Jun 26, 2020
2 parents be7aa9f + a8ee9c9 commit e562d08
Show file tree
Hide file tree
Showing 11 changed files with 487 additions and 190 deletions.
20 changes: 14 additions & 6 deletions net/mptcp/Kconfig
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,20 @@ config MPTCP_IPV6
select IPV6
default y

config MPTCP_HMAC_TEST
bool "Tests for MPTCP HMAC implementation"
endif

config MPTCP_KUNIT_TESTS
tristate "This builds the MPTCP KUnit tests" if !KUNIT_ALL_TESTS
select MPTCP
depends on KUNIT
default KUNIT_ALL_TESTS
help
This option enable boot time self-test for the HMAC implementation
used by the MPTCP code
Currently covers the MPTCP crypto and token helpers.
Only useful for kernel devs running KUnit test harness and are not
for inclusion into a production build.

Say N if you are unsure.
For more information on KUnit and unit tests in general please refer
to the KUnit documentation in Documentation/dev-tools/kunit/.

If unsure, say N.

endif
4 changes: 4 additions & 0 deletions net/mptcp/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@ obj-$(CONFIG_MPTCP) += mptcp.o

mptcp-y := protocol.o subflow.o options.o token.o crypto.o ctrl.o pm.o diag.o \
mib.o pm_netlink.o

mptcp_crypto_test-objs := crypto_test.o
mptcp_token_test-objs := token_test.o
obj-$(CONFIG_MPTCP_KUNIT_TESTS) += mptcp_crypto_test.o mptcp_token_test.o
63 changes: 2 additions & 61 deletions net/mptcp/crypto.c
Original file line number Diff line number Diff line change
Expand Up @@ -87,65 +87,6 @@ void mptcp_crypto_hmac_sha(u64 key1, u64 key2, u8 *msg, int len, void *hmac)
sha256_final(&state, (u8 *)hmac);
}

#ifdef CONFIG_MPTCP_HMAC_TEST
struct test_cast {
char *key;
char *msg;
char *result;
};

/* we can't reuse RFC 4231 test vectors, as we have constraint on the
* input and key size.
*/
static struct test_cast tests[] = {
{
.key = "0b0b0b0b0b0b0b0b",
.msg = "48692054",
.result = "8385e24fb4235ac37556b6b886db106284a1da671699f46db1f235ec622dcafa",
},
{
.key = "aaaaaaaaaaaaaaaa",
.msg = "dddddddd",
.result = "2c5e219164ff1dca1c4a92318d847bb6b9d44492984e1eb71aff9022f71046e9",
},
{
.key = "0102030405060708",
.msg = "cdcdcdcd",
.result = "e73b9ba9969969cefb04aa0d6df18ec2fcc075b6f23b4d8c4da736a5dbbc6e7d",
},
};

static int __init test_mptcp_crypto(void)
{
char hmac[32], hmac_hex[65];
u32 nonce1, nonce2;
u64 key1, key2;
u8 msg[8];
int i, j;

for (i = 0; i < ARRAY_SIZE(tests); ++i) {
/* mptcp hmap will convert to be before computing the hmac */
key1 = be64_to_cpu(*((__be64 *)&tests[i].key[0]));
key2 = be64_to_cpu(*((__be64 *)&tests[i].key[8]));
nonce1 = be32_to_cpu(*((__be32 *)&tests[i].msg[0]));
nonce2 = be32_to_cpu(*((__be32 *)&tests[i].msg[4]));

put_unaligned_be32(nonce1, &msg[0]);
put_unaligned_be32(nonce2, &msg[4]);

mptcp_crypto_hmac_sha(key1, key2, msg, 8, hmac);
for (j = 0; j < 32; ++j)
sprintf(&hmac_hex[j << 1], "%02x", hmac[j] & 0xff);
hmac_hex[64] = 0;

if (memcmp(hmac_hex, tests[i].result, 64))
pr_err("test %d failed, got %s expected %s", i,
hmac_hex, tests[i].result);
else
pr_info("test %d [ ok ]", i);
}
return 0;
}

late_initcall(test_mptcp_crypto);
#if IS_MODULE(CONFIG_MPTCP_KUNIT_TESTS)
EXPORT_SYMBOL_GPL(mptcp_crypto_hmac_sha);
#endif
72 changes: 72 additions & 0 deletions net/mptcp/crypto_test.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// SPDX-License-Identifier: GPL-2.0
#include <kunit/test.h>

#include "protocol.h"

struct test_case {
char *key;
char *msg;
char *result;
};

/* we can't reuse RFC 4231 test vectors, as we have constraint on the
* input and key size.
*/
static struct test_case tests[] = {
{
.key = "0b0b0b0b0b0b0b0b",
.msg = "48692054",
.result = "8385e24fb4235ac37556b6b886db106284a1da671699f46db1f235ec622dcafa",
},
{
.key = "aaaaaaaaaaaaaaaa",
.msg = "dddddddd",
.result = "2c5e219164ff1dca1c4a92318d847bb6b9d44492984e1eb71aff9022f71046e9",
},
{
.key = "0102030405060708",
.msg = "cdcdcdcd",
.result = "e73b9ba9969969cefb04aa0d6df18ec2fcc075b6f23b4d8c4da736a5dbbc6e7d",
},
};

static void mptcp_crypto_test_basic(struct kunit *test)
{
char hmac[32], hmac_hex[65];
u32 nonce1, nonce2;
u64 key1, key2;
u8 msg[8];
int i, j;

for (i = 0; i < ARRAY_SIZE(tests); ++i) {
/* mptcp hmap will convert to be before computing the hmac */
key1 = be64_to_cpu(*((__be64 *)&tests[i].key[0]));
key2 = be64_to_cpu(*((__be64 *)&tests[i].key[8]));
nonce1 = be32_to_cpu(*((__be32 *)&tests[i].msg[0]));
nonce2 = be32_to_cpu(*((__be32 *)&tests[i].msg[4]));

put_unaligned_be32(nonce1, &msg[0]);
put_unaligned_be32(nonce2, &msg[4]);

mptcp_crypto_hmac_sha(key1, key2, msg, 8, hmac);
for (j = 0; j < 32; ++j)
sprintf(&hmac_hex[j << 1], "%02x", hmac[j] & 0xff);
hmac_hex[64] = 0;

KUNIT_EXPECT_STREQ(test, &hmac_hex[0], tests[i].result);
}
}

static struct kunit_case mptcp_crypto_test_cases[] = {
KUNIT_CASE(mptcp_crypto_test_basic),
{}
};

static struct kunit_suite mptcp_crypto_suite = {
.name = "mptcp-crypto",
.test_cases = mptcp_crypto_test_cases,
};

kunit_test_suite(mptcp_crypto_suite);

MODULE_LICENSE("GPL");
2 changes: 1 addition & 1 deletion net/mptcp/pm.c
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ void mptcp_pm_close(struct mptcp_sock *msk)
sock_put((struct sock *)msk);
}

void mptcp_pm_init(void)
void __init mptcp_pm_init(void)
{
pm_wq = alloc_workqueue("pm_wq", WQ_UNBOUND | WQ_MEM_RECLAIM, 8);
if (!pm_wq)
Expand Down
2 changes: 1 addition & 1 deletion net/mptcp/pm_netlink.c
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ static struct pernet_operations mptcp_pm_pernet_ops = {
.size = sizeof(struct pm_nl_pernet),
};

void mptcp_pm_nl_init(void)
void __init mptcp_pm_nl_init(void)
{
if (register_pernet_subsys(&mptcp_pm_pernet_ops) < 0)
panic("Failed to register MPTCP PM pernet subsystem.\n");
Expand Down
49 changes: 28 additions & 21 deletions net/mptcp/protocol.c
Original file line number Diff line number Diff line change
Expand Up @@ -1448,20 +1448,6 @@ struct sock *mptcp_sk_clone(const struct sock *sk,
msk->token = subflow_req->token;
msk->subflow = NULL;

if (unlikely(mptcp_token_new_accept(subflow_req->token, nsk))) {
nsk->sk_state = TCP_CLOSE;
bh_unlock_sock(nsk);

/* we can't call into mptcp_close() here - possible BH context
* free the sock directly.
* sk_clone_lock() sets nsk refcnt to two, hence call sk_free()
* too.
*/
sk_common_release(nsk);
sk_free(nsk);
return NULL;
}

msk->write_seq = subflow_req->idsn + 1;
atomic64_set(&msk->snd_una, msk->write_seq);
if (mp_opt->mp_capable) {
Expand Down Expand Up @@ -1547,7 +1533,7 @@ static void mptcp_destroy(struct sock *sk)
{
struct mptcp_sock *msk = mptcp_sk(sk);

mptcp_token_destroy(msk->token);
mptcp_token_destroy(msk);
if (msk->cached_ext)
__skb_ext_put(msk->cached_ext);

Expand Down Expand Up @@ -1636,6 +1622,20 @@ static void mptcp_release_cb(struct sock *sk)
}
}

static int mptcp_hash(struct sock *sk)
{
/* should never be called,
* we hash the TCP subflows not the master socket
*/
WARN_ON_ONCE(1);
return 0;
}

static void mptcp_unhash(struct sock *sk)
{
/* called from sk_common_release(), but nothing to do here */
}

static int mptcp_get_port(struct sock *sk, unsigned short snum)
{
struct mptcp_sock *msk = mptcp_sk(sk);
Expand Down Expand Up @@ -1679,7 +1679,6 @@ void mptcp_finish_connect(struct sock *ssk)
*/
WRITE_ONCE(msk->remote_key, subflow->remote_key);
WRITE_ONCE(msk->local_key, subflow->local_key);
WRITE_ONCE(msk->token, subflow->token);
WRITE_ONCE(msk->write_seq, subflow->idsn + 1);
WRITE_ONCE(msk->ack_seq, ack_seq);
WRITE_ONCE(msk->can_ack, 1);
Expand Down Expand Up @@ -1761,8 +1760,8 @@ static struct proto mptcp_prot = {
.sendmsg = mptcp_sendmsg,
.recvmsg = mptcp_recvmsg,
.release_cb = mptcp_release_cb,
.hash = inet_hash,
.unhash = inet_unhash,
.hash = mptcp_hash,
.unhash = mptcp_unhash,
.get_port = mptcp_get_port,
.sockets_allocated = &mptcp_sockets_allocated,
.memory_allocated = &tcp_memory_allocated,
Expand All @@ -1771,6 +1770,7 @@ static struct proto mptcp_prot = {
.sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_tcp_wmem),
.sysctl_mem = sysctl_tcp_mem,
.obj_size = sizeof(struct mptcp_sock),
.slab_flags = SLAB_TYPESAFE_BY_RCU,
.no_autobind = true,
};

Expand Down Expand Up @@ -1800,6 +1800,7 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
int addr_len, int flags)
{
struct mptcp_sock *msk = mptcp_sk(sock->sk);
struct mptcp_subflow_context *subflow;
struct socket *ssock;
int err;

Expand All @@ -1812,19 +1813,23 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
goto do_connect;
}

mptcp_token_destroy(msk);
ssock = __mptcp_socket_create(msk, TCP_SYN_SENT);
if (IS_ERR(ssock)) {
err = PTR_ERR(ssock);
goto unlock;
}

subflow = mptcp_subflow_ctx(ssock->sk);
#ifdef CONFIG_TCP_MD5SIG
/* no MPTCP if MD5SIG is enabled on this socket or we may run out of
* TCP option space.
*/
if (rcu_access_pointer(tcp_sk(ssock->sk)->md5sig_info))
mptcp_subflow_ctx(ssock->sk)->request_mptcp = 0;
subflow->request_mptcp = 0;
#endif
if (subflow->request_mptcp && mptcp_token_new_connect(ssock->sk))
subflow->request_mptcp = 0;

do_connect:
err = ssock->ops->connect(ssock, uaddr, addr_len, flags);
Expand Down Expand Up @@ -1888,6 +1893,7 @@ static int mptcp_listen(struct socket *sock, int backlog)
pr_debug("msk=%p", msk);

lock_sock(sock->sk);
mptcp_token_destroy(msk);
ssock = __mptcp_socket_create(msk, TCP_LISTEN);
if (IS_ERR(ssock)) {
err = PTR_ERR(ssock);
Expand Down Expand Up @@ -2077,7 +2083,7 @@ static struct inet_protosw mptcp_protosw = {
.flags = INET_PROTOSW_ICSK,
};

void mptcp_proto_init(void)
void __init mptcp_proto_init(void)
{
mptcp_prot.h.hashinfo = tcp_prot.h.hashinfo;

Expand All @@ -2086,6 +2092,7 @@ void mptcp_proto_init(void)

mptcp_subflow_init();
mptcp_pm_init();
mptcp_token_init();

if (proto_register(&mptcp_prot, 1) != 0)
panic("Failed to register MPTCP proto.\n");
Expand Down Expand Up @@ -2139,7 +2146,7 @@ static struct inet_protosw mptcp_v6_protosw = {
.flags = INET_PROTOSW_ICSK,
};

int mptcp_proto_v6_init(void)
int __init mptcp_proto_v6_init(void)
{
int err;

Expand Down
Loading

0 comments on commit e562d08

Please sign in to comment.