Skip to content

Commit

Permalink
net/tls: Split conf to rx + tx
Browse files Browse the repository at this point in the history
In TLS inline crypto, we can have one direction in software
and another in hardware. Thus, we split the TLS configuration to separate
structures for receive and transmit.

Signed-off-by: Boris Pismenny <borisp@mellanox.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
  • Loading branch information
Boris Pismenny authored and David S. Miller committed May 1, 2018
1 parent 2342a85 commit f66de3e
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 129 deletions.
51 changes: 33 additions & 18 deletions include/net/tls.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,10 @@ struct tls_device {
void (*unhash)(struct tls_device *device, struct sock *sk);
};

struct tls_sw_context {
struct tls_sw_context_tx {
struct crypto_aead *aead_send;
struct crypto_aead *aead_recv;
struct crypto_wait async_wait;

/* Receive context */
struct strparser strp;
void (*saved_data_ready)(struct sock *sk);
unsigned int (*sk_poll)(struct file *file, struct socket *sock,
struct poll_table_struct *wait);
struct sk_buff *recv_pkt;
u8 control;
bool decrypted;

/* Sending context */
char aad_space[TLS_AAD_SPACE_SIZE];

unsigned int sg_plaintext_size;
Expand All @@ -114,6 +103,19 @@ struct tls_sw_context {
struct scatterlist sg_aead_out[2];
};

struct tls_sw_context_rx {
struct crypto_aead *aead_recv;
struct crypto_wait async_wait;

struct strparser strp;
void (*saved_data_ready)(struct sock *sk);
unsigned int (*sk_poll)(struct file *file, struct socket *sock,
struct poll_table_struct *wait);
struct sk_buff *recv_pkt;
u8 control;
bool decrypted;
};

enum {
TLS_PENDING_CLOSED_RECORD
};
Expand All @@ -138,9 +140,15 @@ struct tls_context {
struct tls12_crypto_info_aes_gcm_128 crypto_recv_aes_gcm_128;
};

void *priv_ctx;
struct list_head list;
struct net_device *netdev;
refcount_t refcount;

void *priv_ctx_tx;
void *priv_ctx_rx;

u8 conf:3;
u8 tx_conf:3;
u8 rx_conf:3;

struct cipher_context tx;
struct cipher_context rx;
Expand Down Expand Up @@ -177,7 +185,8 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
int tls_sw_sendpage(struct sock *sk, struct page *page,
int offset, size_t size, int flags);
void tls_sw_close(struct sock *sk, long timeout);
void tls_sw_free_resources(struct sock *sk);
void tls_sw_free_resources_tx(struct sock *sk);
void tls_sw_free_resources_rx(struct sock *sk);
int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int nonblock, int flags, int *addr_len);
unsigned int tls_sw_poll(struct file *file, struct socket *sock,
Expand Down Expand Up @@ -297,16 +306,22 @@ static inline struct tls_context *tls_get_ctx(const struct sock *sk)
return icsk->icsk_ulp_data;
}

static inline struct tls_sw_context *tls_sw_ctx(
static inline struct tls_sw_context_rx *tls_sw_ctx_rx(
const struct tls_context *tls_ctx)
{
return (struct tls_sw_context_rx *)tls_ctx->priv_ctx_rx;
}

static inline struct tls_sw_context_tx *tls_sw_ctx_tx(
const struct tls_context *tls_ctx)
{
return (struct tls_sw_context *)tls_ctx->priv_ctx;
return (struct tls_sw_context_tx *)tls_ctx->priv_ctx_tx;
}

static inline struct tls_offload_context *tls_offload_ctx(
const struct tls_context *tls_ctx)
{
return (struct tls_offload_context *)tls_ctx->priv_ctx;
return (struct tls_offload_context *)tls_ctx->priv_ctx_tx;
}

int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
Expand Down
103 changes: 51 additions & 52 deletions net/tls/tls_main.c
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,9 @@ enum {
TLSV6,
TLS_NUM_PROTS,
};

enum {
TLS_BASE,
TLS_SW_TX,
TLS_SW_RX,
TLS_SW_RXTX,
TLS_SW,
TLS_HW_RECORD,
TLS_NUM_CONFIG,
};
Expand All @@ -65,14 +62,14 @@ static struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
static LIST_HEAD(device_list);
static DEFINE_MUTEX(device_mutex);
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG];
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static struct proto_ops tls_sw_proto_ops;

static inline void update_sk_prot(struct sock *sk, struct tls_context *ctx)
static void update_sk_prot(struct sock *sk, struct tls_context *ctx)
{
int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;

sk->sk_prot = &tls_prots[ip_ver][ctx->conf];
sk->sk_prot = &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf];
}

int wait_on_pending_writer(struct sock *sk, long *timeo)
Expand Down Expand Up @@ -245,10 +242,10 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
lock_sock(sk);
sk_proto_close = ctx->sk_proto_close;

if (ctx->conf == TLS_HW_RECORD)
if (ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD)
goto skip_tx_cleanup;

if (ctx->conf == TLS_BASE) {
if (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE) {
kfree(ctx);
ctx = NULL;
goto skip_tx_cleanup;
Expand All @@ -270,15 +267,17 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
}
}

kfree(ctx->tx.rec_seq);
kfree(ctx->tx.iv);
kfree(ctx->rx.rec_seq);
kfree(ctx->rx.iv);
/* We need these for tls_sw_fallback handling of other packets */
if (ctx->tx_conf == TLS_SW) {
kfree(ctx->tx.rec_seq);
kfree(ctx->tx.iv);
tls_sw_free_resources_tx(sk);
}

if (ctx->conf == TLS_SW_TX ||
ctx->conf == TLS_SW_RX ||
ctx->conf == TLS_SW_RXTX) {
tls_sw_free_resources(sk);
if (ctx->rx_conf == TLS_SW) {
kfree(ctx->rx.rec_seq);
kfree(ctx->rx.iv);
tls_sw_free_resources_rx(sk);
}

skip_tx_cleanup:
Expand All @@ -287,7 +286,8 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
/* free ctx for TLS_HW_RECORD, used by tcp_set_state
* for sk->sk_prot->unhash [tls_hw_unhash]
*/
if (ctx && ctx->conf == TLS_HW_RECORD)
if (ctx && ctx->tx_conf == TLS_HW_RECORD &&
ctx->rx_conf == TLS_HW_RECORD)
kfree(ctx);
}

Expand Down Expand Up @@ -441,25 +441,21 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
goto err_crypto_info;
}

/* currently SW is default, we will have ethtool in future */
if (tx) {
rc = tls_set_sw_offload(sk, ctx, 1);
if (ctx->conf == TLS_SW_RX)
conf = TLS_SW_RXTX;
else
conf = TLS_SW_TX;
conf = TLS_SW;
} else {
rc = tls_set_sw_offload(sk, ctx, 0);
if (ctx->conf == TLS_SW_TX)
conf = TLS_SW_RXTX;
else
conf = TLS_SW_RX;
conf = TLS_SW;
}

if (rc)
goto err_crypto_info;

ctx->conf = conf;
if (tx)
ctx->tx_conf = conf;
else
ctx->rx_conf = conf;
update_sk_prot(sk, ctx);
if (tx) {
ctx->sk_write_space = sk->sk_write_space;
Expand Down Expand Up @@ -535,7 +531,8 @@ static int tls_hw_prot(struct sock *sk)
ctx->hash = sk->sk_prot->hash;
ctx->unhash = sk->sk_prot->unhash;
ctx->sk_proto_close = sk->sk_prot->close;
ctx->conf = TLS_HW_RECORD;
ctx->rx_conf = TLS_HW_RECORD;
ctx->tx_conf = TLS_HW_RECORD;
update_sk_prot(sk, ctx);
rc = 1;
break;
Expand Down Expand Up @@ -579,29 +576,30 @@ static int tls_hw_hash(struct sock *sk)
return err;
}

static void build_protos(struct proto *prot, struct proto *base)
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
struct proto *base)
{
prot[TLS_BASE] = *base;
prot[TLS_BASE].setsockopt = tls_setsockopt;
prot[TLS_BASE].getsockopt = tls_getsockopt;
prot[TLS_BASE].close = tls_sk_proto_close;

prot[TLS_SW_TX] = prot[TLS_BASE];
prot[TLS_SW_TX].sendmsg = tls_sw_sendmsg;
prot[TLS_SW_TX].sendpage = tls_sw_sendpage;

prot[TLS_SW_RX] = prot[TLS_BASE];
prot[TLS_SW_RX].recvmsg = tls_sw_recvmsg;
prot[TLS_SW_RX].close = tls_sk_proto_close;

prot[TLS_SW_RXTX] = prot[TLS_SW_TX];
prot[TLS_SW_RXTX].recvmsg = tls_sw_recvmsg;
prot[TLS_SW_RXTX].close = tls_sk_proto_close;

prot[TLS_HW_RECORD] = *base;
prot[TLS_HW_RECORD].hash = tls_hw_hash;
prot[TLS_HW_RECORD].unhash = tls_hw_unhash;
prot[TLS_HW_RECORD].close = tls_sk_proto_close;
prot[TLS_BASE][TLS_BASE] = *base;
prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt;
prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt;
prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close;

prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg;
prot[TLS_SW][TLS_BASE].sendpage = tls_sw_sendpage;

prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg;
prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close;

prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg;
prot[TLS_SW][TLS_SW].close = tls_sk_proto_close;

prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_hw_hash;
prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_hw_unhash;
prot[TLS_HW_RECORD][TLS_HW_RECORD].close = tls_sk_proto_close;
}

static int tls_init(struct sock *sk)
Expand Down Expand Up @@ -643,7 +641,8 @@ static int tls_init(struct sock *sk)
mutex_unlock(&tcpv6_prot_mutex);
}

ctx->conf = TLS_BASE;
ctx->tx_conf = TLS_BASE;
ctx->rx_conf = TLS_BASE;
update_sk_prot(sk, ctx);
out:
return rc;
Expand Down
Loading

0 comments on commit f66de3e

Please sign in to comment.