Commit 7c9d65c4 authored by Ard Biesheuvel's avatar Ard Biesheuvel Committed by Herbert Xu
Browse files

crypto: arm64/aes-cts-cbc - move request context data to the stack



Since the CTS-CBC code completes synchronously, there is no point in
keeping part of the scratch data it uses in the request context, so
move it to the stack instead.

Signed-off-by: default avatarArd Biesheuvel <ard.biesheuvel@linaro.org>
Signed-off-by: default avatarHerbert Xu <herbert@gondor.apana.org.au>
parent 0cfd507c
Loading
Loading
Loading
Loading
+26 −35
Original line number Diff line number Diff line
@@ -107,12 +107,6 @@ asmlinkage void aes_mac_update(u8 const in[], u32 const rk[], int rounds,
			       int blocks, u8 dg[], int enc_before,
			       int enc_after);

struct cts_cbc_req_ctx {
	struct scatterlist sg_src[2];
	struct scatterlist sg_dst[2];
	struct skcipher_request subreq;
};

struct crypto_aes_xts_ctx {
	struct crypto_aes_ctx key1;
	struct crypto_aes_ctx __aligned(8) key2;
@@ -292,23 +286,20 @@ static int __maybe_unused cbc_decrypt(struct skcipher_request *req)
	return cbc_decrypt_walk(req, &walk);
}

static int cts_cbc_init_tfm(struct crypto_skcipher *tfm)
{
	crypto_skcipher_set_reqsize(tfm, sizeof(struct cts_cbc_req_ctx));
	return 0;
}

static int cts_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
	int err, rounds = 6 + ctx->key_length / 4;
	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
	struct scatterlist *src = req->src, *dst = req->dst;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;

	skcipher_request_set_tfm(&rctx->subreq, tfm);
	skcipher_request_set_tfm(&subreq, tfm);
	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
				      NULL, NULL);

	if (req->cryptlen <= AES_BLOCK_SIZE) {
		if (req->cryptlen < AES_BLOCK_SIZE)
@@ -317,31 +308,30 @@ static int cts_cbc_encrypt(struct skcipher_request *req)
	}

	if (cbc_blocks > 0) {
		skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   cbc_blocks * AES_BLOCK_SIZE,
					   req->iv);

		err = skcipher_walk_virt(&walk, &rctx->subreq, false) ?:
		      cbc_encrypt_walk(&rctx->subreq, &walk);
		err = skcipher_walk_virt(&walk, &subreq, false) ?:
		      cbc_encrypt_walk(&subreq, &walk);
		if (err)
			return err;

		if (req->cryptlen == AES_BLOCK_SIZE)
			return 0;

		dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
					     rctx->subreq.cryptlen);
		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
					       rctx->subreq.cryptlen);
			dst = scatterwalk_ffwd(sg_dst, req->dst,
					       subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	skcipher_request_set_crypt(&rctx->subreq, src, dst,
	skcipher_request_set_crypt(&subreq, src, dst,
				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &rctx->subreq, false);
	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

@@ -357,13 +347,16 @@ static int cts_cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
	int err, rounds = 6 + ctx->key_length / 4;
	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
	struct scatterlist *src = req->src, *dst = req->dst;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;

	skcipher_request_set_tfm(&rctx->subreq, tfm);
	skcipher_request_set_tfm(&subreq, tfm);
	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
				      NULL, NULL);

	if (req->cryptlen <= AES_BLOCK_SIZE) {
		if (req->cryptlen < AES_BLOCK_SIZE)
@@ -372,31 +365,30 @@ static int cts_cbc_decrypt(struct skcipher_request *req)
	}

	if (cbc_blocks > 0) {
		skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   cbc_blocks * AES_BLOCK_SIZE,
					   req->iv);

		err = skcipher_walk_virt(&walk, &rctx->subreq, false) ?:
		      cbc_decrypt_walk(&rctx->subreq, &walk);
		err = skcipher_walk_virt(&walk, &subreq, false) ?:
		      cbc_decrypt_walk(&subreq, &walk);
		if (err)
			return err;

		if (req->cryptlen == AES_BLOCK_SIZE)
			return 0;

		dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
					     rctx->subreq.cryptlen);
		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
					       rctx->subreq.cryptlen);
			dst = scatterwalk_ffwd(sg_dst, req->dst,
					       subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	skcipher_request_set_crypt(&rctx->subreq, src, dst,
	skcipher_request_set_crypt(&subreq, src, dst,
				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &rctx->subreq, false);
	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

@@ -673,7 +665,6 @@ static struct skcipher_alg aes_algs[] = { {
	.setkey		= skcipher_aes_setkey,
	.encrypt	= cts_cbc_encrypt,
	.decrypt	= cts_cbc_decrypt,
	.init		= cts_cbc_init_tfm,
}, {
	.base = {
		.cra_name		= "__essiv(cbc(aes),sha256)",