Commit 143d2647 authored by Ard Biesheuvel's avatar Ard Biesheuvel Committed by Herbert Xu
Browse files

crypto: arm/aes-ce - implement ciphertext stealing for CBC



Instead of relying on the CTS template to wrap the accelerated CBC
skcipher, implement the ciphertext stealing part directly.

Signed-off-by: default avatarArd Biesheuvel <ard.biesheuvel@linaro.org>
Signed-off-by: default avatarHerbert Xu <herbert@gondor.apana.org.au>
parent 2ed8b790
Loading
Loading
Loading
Loading
+85 −0
Original line number Diff line number Diff line
@@ -284,6 +284,91 @@ ENTRY(ce_aes_cbc_decrypt)
	pop		{r4-r6, pc}
ENDPROC(ce_aes_cbc_decrypt)


	/*
	 * ce_aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
	 *			  int rounds, int bytes, u8 const iv[])
	 * ce_aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
	 *			  int rounds, int bytes, u8 const iv[])
	 */

ENTRY(ce_aes_cbc_cts_encrypt)
	push		{r4-r6, lr}
	ldrd		r4, r5, [sp, #16]

	movw		ip, :lower16:.Lcts_permute_table
	movt		ip, :upper16:.Lcts_permute_table
	sub		r4, r4, #16
	add		lr, ip, #32
	add		ip, ip, r4
	sub		lr, lr, r4
	vld1.8		{q5}, [ip]
	vld1.8		{q6}, [lr]

	add		ip, r1, r4
	vld1.8		{q0}, [r1]			@ overlapping loads
	vld1.8		{q3}, [ip]

	vld1.8		{q1}, [r5]			@ get iv
	prepare_key	r2, r3

	veor		q0, q0, q1			@ xor with iv
	bl		aes_encrypt

	vtbl.8		d4, {d0-d1}, d10
	vtbl.8		d5, {d0-d1}, d11
	vtbl.8		d2, {d6-d7}, d12
	vtbl.8		d3, {d6-d7}, d13

	veor		q0, q0, q1
	bl		aes_encrypt

	add		r4, r0, r4
	vst1.8		{q2}, [r4]			@ overlapping stores
	vst1.8		{q0}, [r0]

	pop		{r4-r6, pc}
ENDPROC(ce_aes_cbc_cts_encrypt)

ENTRY(ce_aes_cbc_cts_decrypt)
	push		{r4-r6, lr}
	ldrd		r4, r5, [sp, #16]

	movw		ip, :lower16:.Lcts_permute_table
	movt		ip, :upper16:.Lcts_permute_table
	sub		r4, r4, #16
	add		lr, ip, #32
	add		ip, ip, r4
	sub		lr, lr, r4
	vld1.8		{q5}, [ip]
	vld1.8		{q6}, [lr]

	add		ip, r1, r4
	vld1.8		{q0}, [r1]			@ overlapping loads
	vld1.8		{q1}, [ip]

	vld1.8		{q3}, [r5]			@ get iv
	prepare_key	r2, r3

	bl		aes_decrypt

	vtbl.8		d4, {d0-d1}, d10
	vtbl.8		d5, {d0-d1}, d11
	vtbx.8		d0, {d2-d3}, d12
	vtbx.8		d1, {d2-d3}, d13

	veor		q1, q1, q2
	bl		aes_decrypt
	veor		q0, q0, q3			@ xor with iv

	add		r4, r0, r4
	vst1.8		{q1}, [r4]			@ overlapping stores
	vst1.8		{q0}, [r0]

	pop		{r4-r6, pc}
ENDPROC(ce_aes_cbc_cts_decrypt)


	/*
	 * aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[], int rounds,
	 *		   int blocks, u8 ctr[])
+171 −17
Original line number Diff line number Diff line
@@ -35,6 +35,10 @@ asmlinkage void ce_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				   int rounds, int blocks, u8 iv[]);
asmlinkage void ce_aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
				   int rounds, int blocks, u8 iv[]);
asmlinkage void ce_aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
				   int rounds, int bytes, u8 const iv[]);
asmlinkage void ce_aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
				   int rounds, int bytes, u8 const iv[]);

asmlinkage void ce_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				   int rounds, int blocks, u8 ctr[]);
@@ -210,46 +214,180 @@ static int ecb_decrypt(struct skcipher_request *req)
	return err;
}

static int cbc_encrypt(struct skcipher_request *req)
static int cbc_encrypt_walk(struct skcipher_request *req,
			    struct skcipher_walk *walk)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int blocks;
	int err = 0;

	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
		kernel_neon_begin();
		ce_aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
				   ctx->key_enc, num_rounds(ctx), blocks,
				   walk->iv);
		kernel_neon_end();
		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

static int cbc_encrypt(struct skcipher_request *req)
{
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;
	return cbc_encrypt_walk(req, &walk);
}

	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
static int cbc_decrypt_walk(struct skcipher_request *req,
			    struct skcipher_walk *walk)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	unsigned int blocks;
	int err = 0;

	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
		kernel_neon_begin();
		ce_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				   ctx->key_enc, num_rounds(ctx), blocks,
				   walk.iv);
		ce_aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
				   ctx->key_dec, num_rounds(ctx), blocks,
				   walk->iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

static int cbc_decrypt(struct skcipher_request *req)
{
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;
	return cbc_decrypt_walk(req, &walk);
}

static int cts_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
	struct scatterlist *src = req->src, *dst = req->dst;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;
	unsigned int blocks;
	int err;

	err = skcipher_walk_virt(&walk, req, false);
	skcipher_request_set_tfm(&subreq, tfm);
	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
				      NULL, NULL);

	if (req->cryptlen <= AES_BLOCK_SIZE) {
		if (req->cryptlen < AES_BLOCK_SIZE)
			return -EINVAL;
		cbc_blocks = 1;
	}

	if (cbc_blocks > 0) {
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   cbc_blocks * AES_BLOCK_SIZE,
					   req->iv);

		err = skcipher_walk_virt(&walk, &subreq, false) ?:
		      cbc_encrypt_walk(&subreq, &walk);
		if (err)
			return err;

		if (req->cryptlen == AES_BLOCK_SIZE)
			return 0;

		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(sg_dst, req->dst,
					       subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	skcipher_request_set_crypt(&subreq, src, dst,
				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
	kernel_neon_begin();
		ce_aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				   ctx->key_dec, num_rounds(ctx), blocks,
	ce_aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
			       ctx->key_enc, num_rounds(ctx), walk.nbytes,
			       walk.iv);
	kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);

	return skcipher_walk_done(&walk, 0);
}

static int cts_cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
	struct scatterlist *src = req->src, *dst = req->dst;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;
	int err;

	skcipher_request_set_tfm(&subreq, tfm);
	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
				      NULL, NULL);

	if (req->cryptlen <= AES_BLOCK_SIZE) {
		if (req->cryptlen < AES_BLOCK_SIZE)
			return -EINVAL;
		cbc_blocks = 1;
	}

	if (cbc_blocks > 0) {
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   cbc_blocks * AES_BLOCK_SIZE,
					   req->iv);

		err = skcipher_walk_virt(&walk, &subreq, false) ?:
		      cbc_decrypt_walk(&subreq, &walk);
		if (err)
			return err;

		if (req->cryptlen == AES_BLOCK_SIZE)
			return 0;

		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(sg_dst, req->dst,
					       subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	skcipher_request_set_crypt(&subreq, src, dst,
				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	kernel_neon_begin();
	ce_aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
			       ctx->key_dec, num_rounds(ctx), walk.nbytes,
			       walk.iv);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}

static int ctr_encrypt(struct skcipher_request *req)
@@ -486,6 +624,22 @@ static struct skcipher_alg aes_algs[] = { {
	.setkey			= ce_aes_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "__cts(cbc(aes))",
	.base.cra_driver_name	= "__cts-cbc-aes-ce",
	.base.cra_priority	= 300,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.walksize		= 2 * AES_BLOCK_SIZE,
	.setkey			= ce_aes_setkey,
	.encrypt		= cts_cbc_encrypt,
	.decrypt		= cts_cbc_decrypt,
}, {
	.base.cra_name		= "__ctr(aes)",
	.base.cra_driver_name	= "__ctr-aes-ce",