crypto: arm/aes-ce - implement ciphertext stealing for CBC
authorArd Biesheuvel <ard.biesheuvel@linaro.org>
Tue, 3 Sep 2019 16:43:37 +0000 (09:43 -0700)
committerHerbert Xu <herbert@gondor.apana.org.au>
Mon, 9 Sep 2019 07:35:39 +0000 (17:35 +1000)
Instead of relying on the CTS template to wrap the accelerated CBC
skcipher, implement the ciphertext stealing part directly.

Signed-off-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
arch/arm/crypto/aes-ce-core.S
arch/arm/crypto/aes-ce-glue.c

index 763e51604ab668142caa04bfaa8b6307fda71970..b978cdf133af60f6f49675dda9b565dc04e47e32 100644 (file)
@@ -284,6 +284,91 @@ ENTRY(ce_aes_cbc_decrypt)
        pop             {r4-r6, pc}
 ENDPROC(ce_aes_cbc_decrypt)
 
+
+       /*
+        * ce_aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
+        *                        int rounds, int bytes, u8 const iv[])
+        * ce_aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
+        *                        int rounds, int bytes, u8 const iv[])
+        */
+
+ENTRY(ce_aes_cbc_cts_encrypt)
+       push            {r4-r6, lr}
+       ldrd            r4, r5, [sp, #16]
+
+       movw            ip, :lower16:.Lcts_permute_table
+       movt            ip, :upper16:.Lcts_permute_table
+       sub             r4, r4, #16
+       add             lr, ip, #32
+       add             ip, ip, r4
+       sub             lr, lr, r4
+       vld1.8          {q5}, [ip]
+       vld1.8          {q6}, [lr]
+
+       add             ip, r1, r4
+       vld1.8          {q0}, [r1]                      @ overlapping loads
+       vld1.8          {q3}, [ip]
+
+       vld1.8          {q1}, [r5]                      @ get iv
+       prepare_key     r2, r3
+
+       veor            q0, q0, q1                      @ xor with iv
+       bl              aes_encrypt
+
+       vtbl.8          d4, {d0-d1}, d10
+       vtbl.8          d5, {d0-d1}, d11
+       vtbl.8          d2, {d6-d7}, d12
+       vtbl.8          d3, {d6-d7}, d13
+
+       veor            q0, q0, q1
+       bl              aes_encrypt
+
+       add             r4, r0, r4
+       vst1.8          {q2}, [r4]                      @ overlapping stores
+       vst1.8          {q0}, [r0]
+
+       pop             {r4-r6, pc}
+ENDPROC(ce_aes_cbc_cts_encrypt)
+
+ENTRY(ce_aes_cbc_cts_decrypt)
+       push            {r4-r6, lr}
+       ldrd            r4, r5, [sp, #16]
+
+       movw            ip, :lower16:.Lcts_permute_table
+       movt            ip, :upper16:.Lcts_permute_table
+       sub             r4, r4, #16
+       add             lr, ip, #32
+       add             ip, ip, r4
+       sub             lr, lr, r4
+       vld1.8          {q5}, [ip]
+       vld1.8          {q6}, [lr]
+
+       add             ip, r1, r4
+       vld1.8          {q0}, [r1]                      @ overlapping loads
+       vld1.8          {q1}, [ip]
+
+       vld1.8          {q3}, [r5]                      @ get iv
+       prepare_key     r2, r3
+
+       bl              aes_decrypt
+
+       vtbl.8          d4, {d0-d1}, d10
+       vtbl.8          d5, {d0-d1}, d11
+       vtbx.8          d0, {d2-d3}, d12
+       vtbx.8          d1, {d2-d3}, d13
+
+       veor            q1, q1, q2
+       bl              aes_decrypt
+       veor            q0, q0, q3                      @ xor with iv
+
+       add             r4, r0, r4
+       vst1.8          {q1}, [r4]                      @ overlapping stores
+       vst1.8          {q0}, [r0]
+
+       pop             {r4-r6, pc}
+ENDPROC(ce_aes_cbc_cts_decrypt)
+
+
        /*
         * aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[], int rounds,
         *                 int blocks, u8 ctr[])
index c215792a2494a26faaef9eab59a676505a3c6d5f..cdb1a07e7ad0a03a00b7c1d99158874e9095c5ff 100644 (file)
@@ -35,6 +35,10 @@ asmlinkage void ce_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                   int rounds, int blocks, u8 iv[]);
 asmlinkage void ce_aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                   int rounds, int blocks, u8 iv[]);
+asmlinkage void ce_aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
+                                  int rounds, int bytes, u8 const iv[]);
+asmlinkage void ce_aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
+                                  int rounds, int bytes, u8 const iv[]);
 
 asmlinkage void ce_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                   int rounds, int blocks, u8 ctr[]);
@@ -210,48 +214,182 @@ static int ecb_decrypt(struct skcipher_request *req)
        return err;
 }
 
-static int cbc_encrypt(struct skcipher_request *req)
+static int cbc_encrypt_walk(struct skcipher_request *req,
+                           struct skcipher_walk *walk)
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
-       struct skcipher_walk walk;
        unsigned int blocks;
-       int err;
+       int err = 0;
 
-       err = skcipher_walk_virt(&walk, req, false);
-
-       while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+       while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
-               ce_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+               ce_aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
                                   ctx->key_enc, num_rounds(ctx), blocks,
-                                  walk.iv);
+                                  walk->iv);
                kernel_neon_end();
-               err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
+               err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
        }
        return err;
 }
 
-static int cbc_decrypt(struct skcipher_request *req)
+static int cbc_encrypt(struct skcipher_request *req)
 {
-       struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
-       struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
-       unsigned int blocks;
        int err;
 
        err = skcipher_walk_virt(&walk, req, false);
+       if (err)
+               return err;
+       return cbc_encrypt_walk(req, &walk);
+}
 
-       while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+static int cbc_decrypt_walk(struct skcipher_request *req,
+                           struct skcipher_walk *walk)
+{
+       struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+       struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+       unsigned int blocks;
+       int err = 0;
+
+       while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
-               ce_aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+               ce_aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
                                   ctx->key_dec, num_rounds(ctx), blocks,
-                                  walk.iv);
+                                  walk->iv);
                kernel_neon_end();
-               err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
+               err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
        }
        return err;
 }
 
+static int cbc_decrypt(struct skcipher_request *req)
+{
+       struct skcipher_walk walk;
+       int err;
+
+       err = skcipher_walk_virt(&walk, req, false);
+       if (err)
+               return err;
+       return cbc_decrypt_walk(req, &walk);
+}
+
+static int cts_cbc_encrypt(struct skcipher_request *req)
+{
+       struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+       struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+       int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
+       struct scatterlist *src = req->src, *dst = req->dst;
+       struct scatterlist sg_src[2], sg_dst[2];
+       struct skcipher_request subreq;
+       struct skcipher_walk walk;
+       int err;
+
+       skcipher_request_set_tfm(&subreq, tfm);
+       skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
+                                     NULL, NULL);
+
+       if (req->cryptlen <= AES_BLOCK_SIZE) {
+               if (req->cryptlen < AES_BLOCK_SIZE)
+                       return -EINVAL;
+               cbc_blocks = 1;
+       }
+
+       if (cbc_blocks > 0) {
+               skcipher_request_set_crypt(&subreq, req->src, req->dst,
+                                          cbc_blocks * AES_BLOCK_SIZE,
+                                          req->iv);
+
+               err = skcipher_walk_virt(&walk, &subreq, false) ?:
+                     cbc_encrypt_walk(&subreq, &walk);
+               if (err)
+                       return err;
+
+               if (req->cryptlen == AES_BLOCK_SIZE)
+                       return 0;
+
+               dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
+               if (req->dst != req->src)
+                       dst = scatterwalk_ffwd(sg_dst, req->dst,
+                                              subreq.cryptlen);
+       }
+
+       /* handle ciphertext stealing */
+       skcipher_request_set_crypt(&subreq, src, dst,
+                                  req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
+                                  req->iv);
+
+       err = skcipher_walk_virt(&walk, &subreq, false);
+       if (err)
+               return err;
+
+       kernel_neon_begin();
+       ce_aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+                              ctx->key_enc, num_rounds(ctx), walk.nbytes,
+                              walk.iv);
+       kernel_neon_end();
+
+       return skcipher_walk_done(&walk, 0);
+}
+
+static int cts_cbc_decrypt(struct skcipher_request *req)
+{
+       struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+       struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+       int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
+       struct scatterlist *src = req->src, *dst = req->dst;
+       struct scatterlist sg_src[2], sg_dst[2];
+       struct skcipher_request subreq;
+       struct skcipher_walk walk;
+       int err;
+
+       skcipher_request_set_tfm(&subreq, tfm);
+       skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
+                                     NULL, NULL);
+
+       if (req->cryptlen <= AES_BLOCK_SIZE) {
+               if (req->cryptlen < AES_BLOCK_SIZE)
+                       return -EINVAL;
+               cbc_blocks = 1;
+       }
+
+       if (cbc_blocks > 0) {
+               skcipher_request_set_crypt(&subreq, req->src, req->dst,
+                                          cbc_blocks * AES_BLOCK_SIZE,
+                                          req->iv);
+
+               err = skcipher_walk_virt(&walk, &subreq, false) ?:
+                     cbc_decrypt_walk(&subreq, &walk);
+               if (err)
+                       return err;
+
+               if (req->cryptlen == AES_BLOCK_SIZE)
+                       return 0;
+
+               dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
+               if (req->dst != req->src)
+                       dst = scatterwalk_ffwd(sg_dst, req->dst,
+                                              subreq.cryptlen);
+       }
+
+       /* handle ciphertext stealing */
+       skcipher_request_set_crypt(&subreq, src, dst,
+                                  req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
+                                  req->iv);
+
+       err = skcipher_walk_virt(&walk, &subreq, false);
+       if (err)
+               return err;
+
+       kernel_neon_begin();
+       ce_aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+                              ctx->key_dec, num_rounds(ctx), walk.nbytes,
+                              walk.iv);
+       kernel_neon_end();
+
+       return skcipher_walk_done(&walk, 0);
+}
+
 static int ctr_encrypt(struct skcipher_request *req)
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
@@ -486,6 +624,22 @@ static struct skcipher_alg aes_algs[] = { {
        .setkey                 = ce_aes_setkey,
        .encrypt                = cbc_encrypt,
        .decrypt                = cbc_decrypt,
+}, {
+       .base.cra_name          = "__cts(cbc(aes))",
+       .base.cra_driver_name   = "__cts-cbc-aes-ce",
+       .base.cra_priority      = 300,
+       .base.cra_flags         = CRYPTO_ALG_INTERNAL,
+       .base.cra_blocksize     = AES_BLOCK_SIZE,
+       .base.cra_ctxsize       = sizeof(struct crypto_aes_ctx),
+       .base.cra_module        = THIS_MODULE,
+
+       .min_keysize            = AES_MIN_KEY_SIZE,
+       .max_keysize            = AES_MAX_KEY_SIZE,
+       .ivsize                 = AES_BLOCK_SIZE,
+       .walksize               = 2 * AES_BLOCK_SIZE,
+       .setkey                 = ce_aes_setkey,
+       .encrypt                = cts_cbc_encrypt,
+       .decrypt                = cts_cbc_decrypt,
 }, {
        .base.cra_name          = "__ctr(aes)",
        .base.cra_driver_name   = "__ctr-aes-ce",