#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
+#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>
int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
int rounds, int blocks, u8 iv[]);
+asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
+ u32 const rk1[], int rounds, int bytes,
+ u32 const rk2[], u8 iv[], int first);
+asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
+ u32 const rk1[], int rounds, int bytes,
+ u32 const rk2[], u8 iv[], int first);
struct aesbs_ctx {
u8 rk[13 * (8 * AES_BLOCK_SIZE) + 32];
struct aesbs_xts_ctx {
struct aesbs_ctx key;
u32 twkey[AES_MAX_KEYLENGTH_U32];
+ struct crypto_aes_ctx cts;
};
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
return err;
key_len /= 2;
+ err = aes_expandkey(&ctx->cts, in_key, key_len);
+ if (err)
+ return err;
+
err = aes_expandkey(&rk, in_key + key_len, key_len);
if (err)
return err;
return ctr_encrypt(req);
}
-static int __xts_crypt(struct skcipher_request *req,
+static int __xts_crypt(struct skcipher_request *req, bool encrypt,
void (*fn)(u8 out[], u8 const in[], u8 const rk[],
int rounds, int blocks, u8 iv[]))
{
struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
+ int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
+ struct scatterlist sg_src[2], sg_dst[2];
+ struct skcipher_request subreq;
+ struct scatterlist *src, *dst;
struct skcipher_walk walk;
- int err;
+ int nbytes, err;
+ int first = 1;
+ u8 *out, *in;
+
+ if (req->cryptlen < AES_BLOCK_SIZE)
+ return -EINVAL;
+
+ /* ensure that the cts tail is covered by a single step */
+ if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
+ int xts_blocks = DIV_ROUND_UP(req->cryptlen,
+ AES_BLOCK_SIZE) - 2;
+
+ skcipher_request_set_tfm(&subreq, tfm);
+ skcipher_request_set_callback(&subreq,
+ skcipher_request_flags(req),
+ NULL, NULL);
+ skcipher_request_set_crypt(&subreq, req->src, req->dst,
+ xts_blocks * AES_BLOCK_SIZE,
+ req->iv);
+ req = &subreq;
+ } else {
+ tail = 0;
+ }
err = skcipher_walk_virt(&walk, req, false);
if (err)
return err;
- kernel_neon_begin();
- neon_aes_ecb_encrypt(walk.iv, walk.iv, ctx->twkey, ctx->key.rounds, 1);
- kernel_neon_end();
-
while (walk.nbytes >= AES_BLOCK_SIZE) {
unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
- if (walk.nbytes < walk.total)
+ if (walk.nbytes < walk.total || walk.nbytes % AES_BLOCK_SIZE)
blocks = round_down(blocks,
walk.stride / AES_BLOCK_SIZE);
+ out = walk.dst.virt.addr;
+ in = walk.src.virt.addr;
+ nbytes = walk.nbytes;
+
kernel_neon_begin();
- fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
- ctx->key.rounds, blocks, walk.iv);
+ if (likely(blocks > 6)) { /* plain NEON is faster otherwise */
+ if (first)
+ neon_aes_ecb_encrypt(walk.iv, walk.iv,
+ ctx->twkey,
+ ctx->key.rounds, 1);
+ first = 0;
+
+ fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
+ walk.iv);
+
+ out += blocks * AES_BLOCK_SIZE;
+ in += blocks * AES_BLOCK_SIZE;
+ nbytes -= blocks * AES_BLOCK_SIZE;
+ }
+
+ if (walk.nbytes == walk.total && nbytes > 0)
+ goto xts_tail;
+
kernel_neon_end();
- err = skcipher_walk_done(&walk,
- walk.nbytes - blocks * AES_BLOCK_SIZE);
+ skcipher_walk_done(&walk, nbytes);
}
- return err;
+
+ if (err || likely(!tail))
+ return err;
+
+ /* handle ciphertext stealing */
+ dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
+ if (req->dst != req->src)
+ dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
+
+ skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
+ req->iv);
+
+ err = skcipher_walk_virt(&walk, req, false);
+ if (err)
+ return err;
+
+ out = walk.dst.virt.addr;
+ in = walk.src.virt.addr;
+ nbytes = walk.nbytes;
+
+ kernel_neon_begin();
+xts_tail:
+ if (encrypt)
+ neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
+ nbytes, ctx->twkey, walk.iv, first ?: 2);
+ else
+ neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
+ nbytes, ctx->twkey, walk.iv, first ?: 2);
+ kernel_neon_end();
+
+ return skcipher_walk_done(&walk, 0);
}
static int xts_encrypt(struct skcipher_request *req)
{
- return __xts_crypt(req, aesbs_xts_encrypt);
+ return __xts_crypt(req, true, aesbs_xts_encrypt);
}
static int xts_decrypt(struct skcipher_request *req)
{
- return __xts_crypt(req, aesbs_xts_decrypt);
+ return __xts_crypt(req, false, aesbs_xts_decrypt);
}
static struct skcipher_alg aes_algs[] = { {