crypto: arm64/aes-neonbs - implement ciphertext stealing for XTS
author		Ard Biesheuvel <ard.biesheuvel@linaro.org>
		Tue, 3 Sep 2019 16:43:34 +0000 (09:43 -0700)
committer	Herbert Xu <herbert@gondor.apana.org.au>
		Mon, 9 Sep 2019 07:35:39 +0000 (17:35 +1000)
Update the AES-XTS implementation based on NEON instructions so that it
can deal with inputs whose size is not a multiple of the cipher block
size. This is part of the original XTS specification, but was never
implemented before in the Linux kernel.

Since the bit-slicing driver is only faster if it can operate on at
least 7 blocks of input at the same time, let's reuse the alternate
path we are adding for CTS to process any data tail whose size is
not a multiple of 128 bytes.
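
To make the splitting concrete, the standalone sketch below (illustrative
only; xts_split() and main() are not part of the patch) mirrors the size
arithmetic performed by __xts_crypt() in aes-neonbs-glue.c: when the length
modulo the 128-byte walk stride leaves a partial block, the bulk pass is
shortened by two blocks so that the ciphertext-stealing tail (the last full
block plus the partial block) lands in a single walk step.

/*
 * Standalone sketch, for illustration only: xts_split() and main() are not
 * part of the patch.  They mirror the size arithmetic of __xts_crypt().
 */
#include <stdio.h>

#define AES_BLOCK_SIZE		16
#define DIV_ROUND_UP(n, d)	(((n) + (d) - 1) / (d))

static void xts_split(int cryptlen, int *bulk, int *cts)
{
	/* the bit-sliced code walks in 8-block (128 byte) strides */
	int tail = cryptlen % (8 * AES_BLOCK_SIZE);

	if (tail > 0 && tail < AES_BLOCK_SIZE) {
		/*
		 * The partial block and the full block preceding it must be
		 * presented to the CTS code in a single walk step, so the
		 * bulk pass is shortened by two blocks.
		 */
		int xts_blocks = DIV_ROUND_UP(cryptlen, AES_BLOCK_SIZE) - 2;

		*bulk = xts_blocks * AES_BLOCK_SIZE;
		*cts = cryptlen - *bulk;	/* one full block + partial */
	} else {
		/*
		 * No split needed; whatever the final bulk pass leaves over
		 * is handed to the plain NEON helpers at the xts_tail label.
		 */
		*bulk = cryptlen;
		*cts = 0;
	}
}

int main(void)
{
	int bulk, cts;

	xts_split(130, &bulk, &cts);	/* 130 bytes -> bulk = 112, cts = 18 */
	printf("bulk=%d cts=%d\n", bulk, cts);
	return 0;
}

Splitting at the 128-byte stride keeps the bit-sliced 8-block path on the
common case, while the comparatively rare tail goes through the plain NEON
helpers, which handle ciphertext stealing on the final block pair.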

Signed-off-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
arch/arm64/crypto/aes-ce.S
arch/arm64/crypto/aes-glue.c
arch/arm64/crypto/aes-modes.S
arch/arm64/crypto/aes-neon.S
arch/arm64/crypto/aes-neonbs-glue.c

diff --git a/arch/arm64/crypto/aes-ce.S b/arch/arm64/crypto/aes-ce.S
index 00bd2885feaa5a17fd85f29ee71f5bbae955a038..c132c49c89a8c4427fd2a252ddd4874a5ad44dfa 100644
--- a/arch/arm64/crypto/aes-ce.S
+++ b/arch/arm64/crypto/aes-ce.S
@@ -21,6 +21,9 @@
        .macro          xts_reload_mask, tmp
        .endm
 
+       .macro          xts_cts_skip_tw, reg, lbl
+       .endm
+
        /* preload all round keys */
        .macro          load_round_keys, rounds, rk
        cmp             \rounds, #12
diff --git a/arch/arm64/crypto/aes-glue.c b/arch/arm64/crypto/aes-glue.c
index 23ee7c85c0b784ff4e8b445ee8826f5037049365..aa57dc639f77fbef95a1dc3c042fdaa58487c44d 100644
--- a/arch/arm64/crypto/aes-glue.c
+++ b/arch/arm64/crypto/aes-glue.c
@@ -1071,5 +1071,7 @@ module_cpu_feature_match(AES, aes_init);
 module_init(aes_init);
 EXPORT_SYMBOL(neon_aes_ecb_encrypt);
 EXPORT_SYMBOL(neon_aes_cbc_encrypt);
+EXPORT_SYMBOL(neon_aes_xts_encrypt);
+EXPORT_SYMBOL(neon_aes_xts_decrypt);
 #endif
 module_exit(aes_exit);
diff --git a/arch/arm64/crypto/aes-modes.S b/arch/arm64/crypto/aes-modes.S
index f2c2ba739f36a0d76551e862cdf8ce5f61058f4a..131618389f1fda7fd3744ce0584a33ce5cc20388 100644
--- a/arch/arm64/crypto/aes-modes.S
+++ b/arch/arm64/crypto/aes-modes.S
@@ -442,6 +442,7 @@ AES_ENTRY(aes_xts_encrypt)
        cbz             w7, .Lxtsencnotfirst
 
        enc_prepare     w3, x5, x8
+       xts_cts_skip_tw w7, .LxtsencNx
        encrypt_block   v4, w3, x5, x8, w7              /* first tweak */
        enc_switch_key  w3, x2, x8
        b               .LxtsencNx
@@ -530,10 +531,12 @@ AES_ENTRY(aes_xts_decrypt)
 
        ld1             {v4.16b}, [x6]
        xts_load_mask   v8
+       xts_cts_skip_tw w7, .Lxtsdecskiptw
        cbz             w7, .Lxtsdecnotfirst
 
        enc_prepare     w3, x5, x8
        encrypt_block   v4, w3, x5, x8, w7              /* first tweak */
+.Lxtsdecskiptw:
        dec_prepare     w3, x2, x8
        b               .LxtsdecNx
 
diff --git a/arch/arm64/crypto/aes-neon.S b/arch/arm64/crypto/aes-neon.S
index 0cac5df6c901f24e58941c237550b4c44a2ea412..22d9b110cf78b1838ea1ef5c7643a9817e737756 100644
--- a/arch/arm64/crypto/aes-neon.S
+++ b/arch/arm64/crypto/aes-neon.S
        xts_load_mask   \tmp
        .endm
 
+       /* special case for the neon-bs driver calling into this one for CTS */
+       .macro          xts_cts_skip_tw, reg, lbl
+       tbnz            \reg, #1, \lbl
+       .endm
+
        /* multiply by polynomial 'x' in GF(2^8) */
        .macro          mul_by_x, out, in, temp, const
        sshr            \temp, \in, #7
diff --git a/arch/arm64/crypto/aes-neonbs-glue.c b/arch/arm64/crypto/aes-neonbs-glue.c
index bafd2ebef8f1ededfd4e136dfba6b55ecfc7c226..ea873b8904c49c6bc5e2a11795494e947cd2b5cf 100644
--- a/arch/arm64/crypto/aes-neonbs-glue.c
+++ b/arch/arm64/crypto/aes-neonbs-glue.c
@@ -11,6 +11,7 @@
 #include <crypto/ctr.h>
 #include <crypto/internal/simd.h>
 #include <crypto/internal/skcipher.h>
+#include <crypto/scatterwalk.h>
 #include <crypto/xts.h>
 #include <linux/module.h>
 
@@ -45,6 +46,12 @@ asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                     int rounds, int blocks);
 asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                     int rounds, int blocks, u8 iv[]);
+asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
+                                    u32 const rk1[], int rounds, int bytes,
+                                    u32 const rk2[], u8 iv[], int first);
+asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
+                                    u32 const rk1[], int rounds, int bytes,
+                                    u32 const rk2[], u8 iv[], int first);
 
 struct aesbs_ctx {
        u8      rk[13 * (8 * AES_BLOCK_SIZE) + 32];
@@ -64,6 +71,7 @@ struct aesbs_ctr_ctx {
 struct aesbs_xts_ctx {
        struct aesbs_ctx        key;
        u32                     twkey[AES_MAX_KEYLENGTH_U32];
+       struct crypto_aes_ctx   cts;
 };
 
 static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
@@ -270,6 +278,10 @@ static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                return err;
 
        key_len /= 2;
+       err = aes_expandkey(&ctx->cts, in_key, key_len);
+       if (err)
+               return err;
+
        err = aes_expandkey(&rk, in_key + key_len, key_len);
        if (err)
                return err;
@@ -302,48 +314,119 @@ static int ctr_encrypt_sync(struct skcipher_request *req)
        return ctr_encrypt(req);
 }
 
-static int __xts_crypt(struct skcipher_request *req,
+static int __xts_crypt(struct skcipher_request *req, bool encrypt,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]))
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
+       int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
+       struct scatterlist sg_src[2], sg_dst[2];
+       struct skcipher_request subreq;
+       struct scatterlist *src, *dst;
        struct skcipher_walk walk;
-       int err;
+       int nbytes, err;
+       int first = 1;
+       u8 *out, *in;
+
+       if (req->cryptlen < AES_BLOCK_SIZE)
+               return -EINVAL;
+
+       /* ensure that the cts tail is covered by a single step */
+       if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
+               int xts_blocks = DIV_ROUND_UP(req->cryptlen,
+                                             AES_BLOCK_SIZE) - 2;
+
+               skcipher_request_set_tfm(&subreq, tfm);
+               skcipher_request_set_callback(&subreq,
+                                             skcipher_request_flags(req),
+                                             NULL, NULL);
+               skcipher_request_set_crypt(&subreq, req->src, req->dst,
+                                          xts_blocks * AES_BLOCK_SIZE,
+                                          req->iv);
+               req = &subreq;
+       } else {
+               tail = 0;
+       }
 
        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;
 
-       kernel_neon_begin();
-       neon_aes_ecb_encrypt(walk.iv, walk.iv, ctx->twkey, ctx->key.rounds, 1);
-       kernel_neon_end();
-
        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
 
-               if (walk.nbytes < walk.total)
+               if (walk.nbytes < walk.total || walk.nbytes % AES_BLOCK_SIZE)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);
 
+               out = walk.dst.virt.addr;
+               in = walk.src.virt.addr;
+               nbytes = walk.nbytes;
+
                kernel_neon_begin();
-               fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
-                  ctx->key.rounds, blocks, walk.iv);
+               if (likely(blocks > 6)) { /* plain NEON is faster otherwise */
+                       if (first)
+                               neon_aes_ecb_encrypt(walk.iv, walk.iv,
+                                                    ctx->twkey,
+                                                    ctx->key.rounds, 1);
+                       first = 0;
+
+                       fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
+                          walk.iv);
+
+                       out += blocks * AES_BLOCK_SIZE;
+                       in += blocks * AES_BLOCK_SIZE;
+                       nbytes -= blocks * AES_BLOCK_SIZE;
+               }
+
+               if (walk.nbytes == walk.total && nbytes > 0)
+                       goto xts_tail;
+
                kernel_neon_end();
-               err = skcipher_walk_done(&walk,
-                                        walk.nbytes - blocks * AES_BLOCK_SIZE);
+               err = skcipher_walk_done(&walk, nbytes);
        }
-       return err;
+
+       if (err || likely(!tail))
+               return err;
+
+       /* handle ciphertext stealing */
+       dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
+       if (req->dst != req->src)
+               dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
+
+       skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
+                                  req->iv);
+
+       err = skcipher_walk_virt(&walk, req, false);
+       if (err)
+               return err;
+
+       out = walk.dst.virt.addr;
+       in = walk.src.virt.addr;
+       nbytes = walk.nbytes;
+
+       kernel_neon_begin();
+xts_tail:
+       if (encrypt)
+               neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
+                                    nbytes, ctx->twkey, walk.iv, first ?: 2);
+       else
+               neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
+                                    nbytes, ctx->twkey, walk.iv, first ?: 2);
+       kernel_neon_end();
+
+       return skcipher_walk_done(&walk, 0);
 }
 
 static int xts_encrypt(struct skcipher_request *req)
 {
-       return __xts_crypt(req, aesbs_xts_encrypt);
+       return __xts_crypt(req, true, aesbs_xts_encrypt);
 }
 
 static int xts_decrypt(struct skcipher_request *req)
 {
-       return __xts_crypt(req, aesbs_xts_decrypt);
+       return __xts_crypt(req, false, aesbs_xts_decrypt);
 }
 
 static struct skcipher_alg aes_algs[] = { {