crypto: arm64/aes-blk - add support for CTS-CBC mode
Currently, we rely on the generic CTS chaining mode wrapper to
instantiate the cts(cbc(aes)) skcipher. Due to the high performance
of the ARMv8 Crypto Extensions AES instructions (~1 cycle per byte),
any overhead in the chaining mode layers is amplified, and so it pays
off considerably to fold the CTS handling into the SIMD routines.

On Cortex-A53, this results in a ~50% speedup for smaller input sizes.

Signed-off-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
Ard Biesheuvel authored and Herbert Xu committed Sep 21, 2018
1 parent 6e7de6a commit dd597fb
Showing 2 changed files with 243 additions and 1 deletion.
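
As background, cts(cbc(aes)) is CBC with ciphertext stealing in the CS3 arrangement: the final two output blocks are emitted as (last full ciphertext block, truncated block). The plain C sketch below is illustrative only and is not part of the commit; encrypt_block() and cbc_cs3_encrypt_tail() are placeholder names, not kernel APIs. It shows only the final-two-block step that the new SIMD routines perform.

#include <stdint.h>
#include <string.h>

#define BLK 16

/* Stand-in for one AES block encryption; any block cipher would do here. */
typedef void (*blk_enc_t)(uint8_t dst[BLK], const uint8_t src[BLK],
                          const void *key);

/*
 * Encrypt the last 16 + d bytes (1 <= d <= 16) of a CBC-CS3 message:
 * one ordinary CBC step over the last full plaintext block, then a
 * second step over the zero-padded d-byte tail, with the two resulting
 * blocks written out as (full block, first d bytes of the other).
 */
static void cbc_cs3_encrypt_tail(uint8_t *out, const uint8_t *in, size_t d,
                                 const uint8_t iv[BLK],
                                 blk_enc_t encrypt_block, const void *key)
{
        uint8_t xored[BLK], cprev[BLK], padded[BLK] = { 0 };
        size_t i;

        for (i = 0; i < BLK; i++)               /* CBC step: P(n-1) ^ IV */
                xored[i] = in[i] ^ iv[i];
        encrypt_block(cprev, xored, key);       /* "penultimate" ciphertext */

        memcpy(padded, in + BLK, d);            /* zero-padded tail P(n) */
        for (i = 0; i < BLK; i++)
                padded[i] ^= cprev[i];
        encrypt_block(out, padded, key);        /* final full output block */

        memcpy(out + BLK, cprev, d);            /* stolen, truncated block */
}

The decrypt path reconstructs the penultimate ciphertext block from the stolen bytes before undoing the two CBC steps; in the new code that reassembly is done with the tbl/tbx permutes added to aes-modes.S.
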
arch/arm64/crypto/aes-glue.c: 165 additions, 0 deletions
@@ -15,6 +15,7 @@
#include <crypto/internal/hash.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <linux/module.h>
#include <linux/cpufeature.h>
#include <crypto/xts.h>
@@ -31,6 +32,8 @@
#define aes_ecb_decrypt ce_aes_ecb_decrypt
#define aes_cbc_encrypt ce_aes_cbc_encrypt
#define aes_cbc_decrypt ce_aes_cbc_decrypt
#define aes_cbc_cts_encrypt ce_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt ce_aes_cbc_cts_decrypt
#define aes_ctr_encrypt ce_aes_ctr_encrypt
#define aes_xts_encrypt ce_aes_xts_encrypt
#define aes_xts_decrypt ce_aes_xts_decrypt
@@ -45,6 +48,8 @@ MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
#define aes_ecb_decrypt neon_aes_ecb_decrypt
#define aes_cbc_encrypt neon_aes_cbc_encrypt
#define aes_cbc_decrypt neon_aes_cbc_decrypt
#define aes_cbc_cts_encrypt neon_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt neon_aes_cbc_cts_decrypt
#define aes_ctr_encrypt neon_aes_ctr_encrypt
#define aes_xts_encrypt neon_aes_xts_encrypt
#define aes_xts_decrypt neon_aes_xts_decrypt
@@ -73,6 +78,11 @@ asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 iv[]);

asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                    int rounds, int bytes, u8 const iv[]);
asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                    int rounds, int bytes, u8 const iv[]);

asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 ctr[]);

@@ -87,6 +97,12 @@ asmlinkage void aes_mac_update(u8 const in[], u32 const rk[], int rounds,
                               int blocks, u8 dg[], int enc_before,
                               int enc_after);

struct cts_cbc_req_ctx {
        struct scatterlist sg_src[2];
        struct scatterlist sg_dst[2];
        struct skcipher_request subreq;
};

struct crypto_aes_xts_ctx {
        struct crypto_aes_ctx key1;
        struct crypto_aes_ctx __aligned(8) key2;
@@ -209,6 +225,136 @@ static int cbc_decrypt(struct skcipher_request *req)
        return err;
}

static int cts_cbc_init_tfm(struct crypto_skcipher *tfm)
{
        crypto_skcipher_set_reqsize(tfm, sizeof(struct cts_cbc_req_ctx));
        return 0;
}

static int cts_cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
        int err, rounds = 6 + ctx->key_length / 4;
        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
        struct scatterlist *src = req->src, *dst = req->dst;
        struct skcipher_walk walk;

        skcipher_request_set_tfm(&rctx->subreq, tfm);

        if (req->cryptlen == AES_BLOCK_SIZE)
                cbc_blocks = 1;

        if (cbc_blocks > 0) {
                unsigned int blocks;

                skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
                                           cbc_blocks * AES_BLOCK_SIZE,
                                           req->iv);

                err = skcipher_walk_virt(&walk, &rctx->subreq, false);

                while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                        kernel_neon_begin();
                        aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                        ctx->key_enc, rounds, blocks, walk.iv);
                        kernel_neon_end();
                        err = skcipher_walk_done(&walk,
                                                 walk.nbytes % AES_BLOCK_SIZE);
                }
                if (err)
                        return err;

                if (req->cryptlen == AES_BLOCK_SIZE)
                        return 0;

                dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
                                             rctx->subreq.cryptlen);
                if (req->dst != req->src)
                        dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
                                               rctx->subreq.cryptlen);
        }

        /* handle ciphertext stealing */
        skcipher_request_set_crypt(&rctx->subreq, src, dst,
                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &rctx->subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                            ctx->key_enc, rounds, walk.nbytes, walk.iv);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static int cts_cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
        int err, rounds = 6 + ctx->key_length / 4;
        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
        struct scatterlist *src = req->src, *dst = req->dst;
        struct skcipher_walk walk;

        skcipher_request_set_tfm(&rctx->subreq, tfm);

        if (req->cryptlen == AES_BLOCK_SIZE)
                cbc_blocks = 1;

        if (cbc_blocks > 0) {
                unsigned int blocks;

                skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
                                           cbc_blocks * AES_BLOCK_SIZE,
                                           req->iv);

                err = skcipher_walk_virt(&walk, &rctx->subreq, false);

                while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                        kernel_neon_begin();
                        aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                        ctx->key_dec, rounds, blocks, walk.iv);
                        kernel_neon_end();
                        err = skcipher_walk_done(&walk,
                                                 walk.nbytes % AES_BLOCK_SIZE);
                }
                if (err)
                        return err;

                if (req->cryptlen == AES_BLOCK_SIZE)
                        return 0;

                dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
                                             rctx->subreq.cryptlen);
                if (req->dst != req->src)
                        dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
                                               rctx->subreq.cryptlen);
        }

        /* handle ciphertext stealing */
        skcipher_request_set_crypt(&rctx->subreq, src, dst,
                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &rctx->subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                            ctx->key_dec, rounds, walk.nbytes, walk.iv);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static int ctr_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
@@ -334,6 +480,25 @@ static struct skcipher_alg aes_algs[] = { {
        .setkey         = skcipher_aes_setkey,
        .encrypt        = cbc_encrypt,
        .decrypt        = cbc_decrypt,
}, {
        .base = {
                .cra_name               = "__cts(cbc(aes))",
                .cra_driver_name        = "__cts-cbc-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = 1,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .chunksize      = AES_BLOCK_SIZE,
        .walksize       = 2 * AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = cts_cbc_encrypt,
        .decrypt        = cts_cbc_decrypt,
        .init           = cts_cbc_init_tfm,
}, {
        .base = {
                .cra_name               = "__ctr(aes)",
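
The algorithms above are registered as internal ("__"-prefixed, CRYPTO_ALG_INTERNAL) implementations and are expected to be reached through SIMD helper wrappers set up elsewhere in this driver, so other kernel code simply requests "cts(cbc(aes))". For a 45-byte request, for example, cbc_blocks = DIV_ROUND_UP(45, 16) - 2 = 1, so 16 bytes go through the plain CBC loop and the remaining 29 bytes through the new CTS routine. Below is a hedged sketch of such a caller, using only standard crypto API calls; the function name, buffer sizes and error handling are illustrative, not taken from this commit.

#include <crypto/skcipher.h>
#include <linux/scatterlist.h>
#include <linux/err.h>

static int cts_cbc_aes_example(const u8 *key, unsigned int keylen,
                               u8 *buf, unsigned int len, u8 iv[16])
{
        struct crypto_skcipher *tfm;
        struct skcipher_request *req;
        struct scatterlist sg;
        DECLARE_CRYPTO_WAIT(wait);
        int err;

        tfm = crypto_alloc_skcipher("cts(cbc(aes))", 0, 0);
        if (IS_ERR(tfm))
                return PTR_ERR(tfm);

        err = crypto_skcipher_setkey(tfm, key, keylen);
        if (err)
                goto out_tfm;

        req = skcipher_request_alloc(tfm, GFP_KERNEL);
        if (!req) {
                err = -ENOMEM;
                goto out_tfm;
        }

        /* buf must not live on the stack, since it goes into a scatterlist */
        sg_init_one(&sg, buf, len);
        skcipher_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG |
                                      CRYPTO_TFM_REQ_MAY_SLEEP,
                                      crypto_req_done, &wait);
        skcipher_request_set_crypt(req, &sg, &sg, len, iv);

        /* Encrypt in place; len must be at least one AES block (16 bytes). */
        err = crypto_wait_req(crypto_skcipher_encrypt(req), &wait);

        skcipher_request_free(req);
out_tfm:
        crypto_free_skcipher(tfm);
        return err;
}
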
arch/arm64/crypto/aes-modes.S: 78 additions, 1 deletion
@@ -170,6 +170,84 @@ AES_ENTRY(aes_cbc_decrypt)
AES_ENDPROC(aes_cbc_decrypt)


        /*
         * aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
         *                     int rounds, int bytes, u8 const iv[])
         * aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
         *                     int rounds, int bytes, u8 const iv[])
         */

AES_ENTRY(aes_cbc_cts_encrypt)
        adr_l           x8, .Lcts_permute_table
        sub             x4, x4, #16
        add             x9, x8, #32
        add             x8, x8, x4
        sub             x9, x9, x4
        ld1             {v3.16b}, [x8]
        ld1             {v4.16b}, [x9]

        ld1             {v0.16b}, [x1], x4      /* overlapping loads */
        ld1             {v1.16b}, [x1]

        ld1             {v5.16b}, [x5]          /* get iv */
        enc_prepare     w3, x2, x6

        eor             v0.16b, v0.16b, v5.16b  /* xor with iv */
        tbl             v1.16b, {v1.16b}, v4.16b
        encrypt_block   v0, w3, x2, x6, w7

        eor             v1.16b, v1.16b, v0.16b
        tbl             v0.16b, {v0.16b}, v3.16b
        encrypt_block   v1, w3, x2, x6, w7

        add             x4, x0, x4
        st1             {v0.16b}, [x4]          /* overlapping stores */
        st1             {v1.16b}, [x0]
        ret
AES_ENDPROC(aes_cbc_cts_encrypt)

AES_ENTRY(aes_cbc_cts_decrypt)
        adr_l           x8, .Lcts_permute_table
        sub             x4, x4, #16
        add             x9, x8, #32
        add             x8, x8, x4
        sub             x9, x9, x4
        ld1             {v3.16b}, [x8]
        ld1             {v4.16b}, [x9]

        ld1             {v0.16b}, [x1], x4      /* overlapping loads */
        ld1             {v1.16b}, [x1]

        ld1             {v5.16b}, [x5]          /* get iv */
        dec_prepare     w3, x2, x6

        tbl             v2.16b, {v1.16b}, v4.16b
        decrypt_block   v0, w3, x2, x6, w7
        eor             v2.16b, v2.16b, v0.16b

        tbx             v0.16b, {v1.16b}, v4.16b
        tbl             v2.16b, {v2.16b}, v3.16b
        decrypt_block   v0, w3, x2, x6, w7
        eor             v0.16b, v0.16b, v5.16b  /* xor with iv */

        add             x4, x0, x4
        st1             {v2.16b}, [x4]          /* overlapping stores */
        st1             {v0.16b}, [x0]
        ret
AES_ENDPROC(aes_cbc_cts_decrypt)

        .section        ".rodata", "a"
        .align          6
.Lcts_permute_table:
        .byte           0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
        .byte           0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
        .byte           0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7
        .byte           0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf
        .byte           0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
        .byte           0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
        .previous


        /*
         * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
         *                 int blocks, u8 ctr[])
@@ -253,7 +331,6 @@ AES_ENTRY(aes_ctr_encrypt)
        ins             v4.d[0], x7
        b               .Lctrcarrydone
AES_ENDPROC(aes_ctr_encrypt)
        .ltorg


        /*
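
One way to read the permute-table trick above is the scalar C model below; it is an assumed illustration only, and tbl16(), zero_pad_tail() and the table copy are not kernel code. The AArch64 TBL instruction looks each output byte up in the source register, and any index outside 0-15 (the 0xff entries) yields zero, so loading the index vectors from .Lcts_permute_table at offsets d and 32 - d produces, in two lookups, the zero-padded final block and a shifted copy of the penultimate block whose first d bytes end up at the top of the vector, where the overlapping store needs them.

#include <stdint.h>

/* Copy of the 48-byte table added above; 0xff entries select "zero". */
static const uint8_t cts_permute_table[48] = {
        0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
        0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
        0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
        0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
        0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
        0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
};

/* Scalar model of TBL: out-of-range indices produce a zero byte. */
static void tbl16(uint8_t out[16], const uint8_t vec[16],
                  const uint8_t idx[16])
{
        for (int i = 0; i < 16; i++)
                out[i] = idx[i] < 16 ? vec[idx[i]] : 0;
}

/*
 * last16 holds the final 16 bytes of the 16 + d byte region (the second,
 * overlapping load in the asm). Indexing the table at 32 - d keeps only
 * the d tail bytes, left-aligned and zero-padded: the block fed to the
 * second encrypt_block call.
 */
static void zero_pad_tail(uint8_t out[16], const uint8_t last16[16], int d)
{
        tbl16(out, last16, &cts_permute_table[32 - d]);
}

The index vector loaded from offset d plays the mirror role on the other block, pushing its first d bytes to the end of the vector so that the two overlapping 16-byte stores leave exactly the (full block, d stolen bytes) layout in the output.
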
