crypto: arm64/aes-cts-cbc - factor out CBC en/decryption of a walk
authorArd Biesheuvel <ard.biesheuvel@linaro.org>
Mon, 19 Aug 2019 14:17:35 +0000 (17:17 +0300)
committerHerbert Xu <herbert@gondor.apana.org.au>
Fri, 30 Aug 2019 08:05:27 +0000 (18:05 +1000)
The plain CBC driver and the CTS one share some code that iterates over
a scatterwalk and invokes the CBC asm code to do the processing. The
upcoming ESSIV/CBC mode will clone that pattern for the third time, so
let's factor it out first.

Signed-off-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
arch/arm64/crypto/aes-glue.c

index 55d6d4838708da9cd2f135c60ebb758de4dac586..23abf335f1eebaf149843537013ae302ebf83cd6 100644 (file)
@@ -186,46 +186,64 @@ static int ecb_decrypt(struct skcipher_request *req)
        return err;
 }
 
-static int cbc_encrypt(struct skcipher_request *req)
+static int cbc_encrypt_walk(struct skcipher_request *req,
+                           struct skcipher_walk *walk)
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
-       int err, rounds = 6 + ctx->key_length / 4;
-       struct skcipher_walk walk;
+       int err = 0, rounds = 6 + ctx->key_length / 4;
        unsigned int blocks;
 
-       err = skcipher_walk_virt(&walk, req, false);
-
-       while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+       while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
-               aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                               ctx->key_enc, rounds, blocks, walk.iv);
+               aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
+                               ctx->key_enc, rounds, blocks, walk->iv);
                kernel_neon_end();
-               err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
+               err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
        }
        return err;
 }
 
-static int cbc_decrypt(struct skcipher_request *req)
+static int cbc_encrypt(struct skcipher_request *req)
 {
-       struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
-       struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
-       int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
-       unsigned int blocks;
+       int err;
 
        err = skcipher_walk_virt(&walk, req, false);
+       if (err)
+               return err;
+       return cbc_encrypt_walk(req, &walk);
+}
 
-       while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+static int cbc_decrypt_walk(struct skcipher_request *req,
+                           struct skcipher_walk *walk)
+{
+       struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+       struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+       int err = 0, rounds = 6 + ctx->key_length / 4;
+       unsigned int blocks;
+
+       while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
-               aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                               ctx->key_dec, rounds, blocks, walk.iv);
+               aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
+                               ctx->key_dec, rounds, blocks, walk->iv);
                kernel_neon_end();
-               err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
+               err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
        }
        return err;
 }
 
+static int cbc_decrypt(struct skcipher_request *req)
+{
+       struct skcipher_walk walk;
+       int err;
+
+       err = skcipher_walk_virt(&walk, req, false);
+       if (err)
+               return err;
+       return cbc_decrypt_walk(req, &walk);
+}
+
 static int cts_cbc_init_tfm(struct crypto_skcipher *tfm)
 {
        crypto_skcipher_set_reqsize(tfm, sizeof(struct cts_cbc_req_ctx));
@@ -251,22 +269,12 @@ static int cts_cbc_encrypt(struct skcipher_request *req)
        }
 
        if (cbc_blocks > 0) {
-               unsigned int blocks;
-
                skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
                                           cbc_blocks * AES_BLOCK_SIZE,
                                           req->iv);
 
-               err = skcipher_walk_virt(&walk, &rctx->subreq, false);
-
-               while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
-                       kernel_neon_begin();
-                       aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                                       ctx->key_enc, rounds, blocks, walk.iv);
-                       kernel_neon_end();
-                       err = skcipher_walk_done(&walk,
-                                                walk.nbytes % AES_BLOCK_SIZE);
-               }
+               err = skcipher_walk_virt(&walk, &rctx->subreq, false) ?:
+                     cbc_encrypt_walk(&rctx->subreq, &walk);
                if (err)
                        return err;
 
@@ -316,22 +324,12 @@ static int cts_cbc_decrypt(struct skcipher_request *req)
        }
 
        if (cbc_blocks > 0) {
-               unsigned int blocks;
-
                skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
                                           cbc_blocks * AES_BLOCK_SIZE,
                                           req->iv);
 
-               err = skcipher_walk_virt(&walk, &rctx->subreq, false);
-
-               while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
-                       kernel_neon_begin();
-                       aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                                       ctx->key_dec, rounds, blocks, walk.iv);
-                       kernel_neon_end();
-                       err = skcipher_walk_done(&walk,
-                                                walk.nbytes % AES_BLOCK_SIZE);
-               }
+               err = skcipher_walk_virt(&walk, &rctx->subreq, false) ?:
+                     cbc_decrypt_walk(&rctx->subreq, &walk);
                if (err)
                        return err;