Commit c61b1607 authored by Ard Biesheuvel's avatar Ard Biesheuvel Committed by Herbert Xu
Browse files

crypto: arm/aes-ce - implement ciphertext stealing for XTS



Update the AES-XTS implementation based on AES instructions so that it
can deal with inputs whose size is not a multiple of the cipher block
size. This is part of the original XTS specification, but was never
implemented before in the Linux kernel.

Signed-off-by: default avatarArd Biesheuvel <ard.biesheuvel@linaro.org>
Signed-off-by: default avatarHerbert Xu <herbert@gondor.apana.org.au>
parent 67cfa5d3
Loading
Loading
Loading
Loading
+92 −11
Original line number Diff line number Diff line
@@ -369,9 +369,9 @@ ENDPROC(ce_aes_ctr_encrypt)

	/*
	 * aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[], int rounds,
	 *		   int blocks, u8 iv[], u32 const rk2[], int first)
	 *		   int bytes, u8 iv[], u32 const rk2[], int first)
	 * aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[], int rounds,
	 *		   int blocks, u8 iv[], u32 const rk2[], int first)
	 *		   int bytes, u8 iv[], u32 const rk2[], int first)
	 */

	.macro		next_tweak, out, in, const, tmp
@@ -414,7 +414,7 @@ ENTRY(ce_aes_xts_encrypt)
.Lxtsencloop4x:
	next_tweak	q4, q4, q15, q10
.Lxtsenc4x:
	subs		r4, r4, #4
	subs		r4, r4, #64
	bmi		.Lxtsenc1x
	vld1.8		{q0-q1}, [r1]!		@ get 4 pt blocks
	vld1.8		{q2-q3}, [r1]!
@@ -434,24 +434,58 @@ ENTRY(ce_aes_xts_encrypt)
	vst1.8		{q2-q3}, [r0]!
	vmov		q4, q7
	teq		r4, #0
	beq		.Lxtsencout
	beq		.Lxtsencret
	b		.Lxtsencloop4x
.Lxtsenc1x:
	adds		r4, r4, #4
	adds		r4, r4, #64
	beq		.Lxtsencout
	subs		r4, r4, #16
	bmi		.LxtsencctsNx
.Lxtsencloop:
	vld1.8		{q0}, [r1]!
.Lxtsencctsout:
	veor		q0, q0, q4
	bl		aes_encrypt
	veor		q0, q0, q4
	vst1.8		{q0}, [r0]!
	subs		r4, r4, #1
	teq		r4, #0
	beq		.Lxtsencout
	subs		r4, r4, #16
	next_tweak	q4, q4, q15, q6
	bmi		.Lxtsenccts
	vst1.8		{q0}, [r0]!
	b		.Lxtsencloop
.Lxtsencout:
	vst1.8		{q0}, [r0]
.Lxtsencret:
	vst1.8		{q4}, [r5]
	pop		{r4-r6, pc}

.LxtsencctsNx:
	vmov		q0, q3
	sub		r0, r0, #16
.Lxtsenccts:
	movw		ip, :lower16:.Lcts_permute_table
	movt		ip, :upper16:.Lcts_permute_table

	add		r1, r1, r4		@ rewind input pointer
	add		r4, r4, #16		@ # bytes in final block
	add		lr, ip, #32
	add		ip, ip, r4
	sub		lr, lr, r4
	add		r4, r0, r4		@ output address of final block

	vld1.8		{q1}, [r1]		@ load final partial block
	vld1.8		{q2}, [ip]
	vld1.8		{q3}, [lr]

	vtbl.8		d4, {d0-d1}, d4
	vtbl.8		d5, {d0-d1}, d5
	vtbx.8		d0, {d2-d3}, d6
	vtbx.8		d1, {d2-d3}, d7

	vst1.8		{q2}, [r4]		@ overlapping stores
	mov		r4, #0
	b		.Lxtsencctsout
ENDPROC(ce_aes_xts_encrypt)


@@ -462,13 +496,17 @@ ENTRY(ce_aes_xts_decrypt)
	prepare_key	r2, r3
	vmov		q4, q0

	/* subtract 16 bytes if we are doing CTS */
	tst		r4, #0xf
	subne		r4, r4, #0x10

	teq		r6, #0			@ start of a block?
	bne		.Lxtsdec4x

.Lxtsdecloop4x:
	next_tweak	q4, q4, q15, q10
.Lxtsdec4x:
	subs		r4, r4, #4
	subs		r4, r4, #64
	bmi		.Lxtsdec1x
	vld1.8		{q0-q1}, [r1]!		@ get 4 ct blocks
	vld1.8		{q2-q3}, [r1]!
@@ -491,22 +529,55 @@ ENTRY(ce_aes_xts_decrypt)
	beq		.Lxtsdecout
	b		.Lxtsdecloop4x
.Lxtsdec1x:
	adds		r4, r4, #4
	adds		r4, r4, #64
	beq		.Lxtsdecout
	subs		r4, r4, #16
.Lxtsdecloop:
	vld1.8		{q0}, [r1]!
	bmi		.Lxtsdeccts
.Lxtsdecctsout:
	veor		q0, q0, q4
	add		ip, r2, #32		@ 3rd round key
	bl		aes_decrypt
	veor		q0, q0, q4
	vst1.8		{q0}, [r0]!
	subs		r4, r4, #1
	teq		r4, #0
	beq		.Lxtsdecout
	subs		r4, r4, #16
	next_tweak	q4, q4, q15, q6
	b		.Lxtsdecloop
.Lxtsdecout:
	vst1.8		{q4}, [r5]
	pop		{r4-r6, pc}

.Lxtsdeccts:
	movw		ip, :lower16:.Lcts_permute_table
	movt		ip, :upper16:.Lcts_permute_table

	add		r1, r1, r4		@ rewind input pointer
	add		r4, r4, #16		@ # bytes in final block
	add		lr, ip, #32
	add		ip, ip, r4
	sub		lr, lr, r4
	add		r4, r0, r4		@ output address of final block

	next_tweak	q5, q4, q15, q6

	vld1.8		{q1}, [r1]		@ load final partial block
	vld1.8		{q2}, [ip]
	vld1.8		{q3}, [lr]

	veor		q0, q0, q5
	bl		aes_decrypt
	veor		q0, q0, q5

	vtbl.8		d4, {d0-d1}, d4
	vtbl.8		d5, {d0-d1}, d5
	vtbx.8		d0, {d2-d3}, d6
	vtbx.8		d1, {d2-d3}, d7

	vst1.8		{q2}, [r4]		@ overlapping stores
	mov		r4, #0
	b		.Lxtsdecctsout
ENDPROC(ce_aes_xts_decrypt)

	/*
@@ -532,3 +603,13 @@ ENTRY(ce_aes_invert)
	vst1.32		{q0}, [r0]
	bx		lr
ENDPROC(ce_aes_invert)

	.section	".rodata", "a"
	.align		6
.Lcts_permute_table:
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
	.byte		 0x0,  0x1,  0x2,  0x3,  0x4,  0x5,  0x6,  0x7
	.byte		 0x8,  0x9,  0xa,  0xb,  0xc,  0xd,  0xe,  0xf
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+116 −12
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <linux/cpufeature.h>
#include <linux/module.h>
#include <crypto/xts.h>
@@ -39,10 +40,10 @@ asmlinkage void ce_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				   int rounds, int blocks, u8 ctr[]);

asmlinkage void ce_aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
				   int rounds, int blocks, u8 iv[],
				   int rounds, int bytes, u8 iv[],
				   u32 const rk2[], int first);
asmlinkage void ce_aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
				   int rounds, int blocks, u8 iv[],
				   int rounds, int bytes, u8 iv[],
				   u32 const rk2[], int first);

struct aes_block {
@@ -317,20 +318,71 @@ static int xts_encrypt(struct skcipher_request *req)
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, first, rounds = num_rounds(&ctx->key1);
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;
	unsigned int blocks;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	err = skcipher_walk_virt(&walk, req, false);

	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		skcipher_walk_abort(&walk);

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
		err = skcipher_walk_virt(&walk, req, false);
	} else {
		tail = 0;
	}

	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
		int nbytes = walk.nbytes;

		if (walk.nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
		kernel_neon_begin();
		ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				   ctx->key1.key_enc, rounds, blocks, walk.iv,
				   ctx->key1.key_enc, rounds, nbytes, walk.iv,
				   ctx->key2.key_enc, first);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

	if (err || likely(!tail))
		return err;

	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	kernel_neon_begin();
	ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
			   ctx->key1.key_enc, rounds, walk.nbytes, walk.iv,
			   ctx->key2.key_enc, first);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}

static int xts_decrypt(struct skcipher_request *req)
@@ -338,20 +390,71 @@ static int xts_decrypt(struct skcipher_request *req)
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, first, rounds = num_rounds(&ctx->key1);
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;
	unsigned int blocks;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	err = skcipher_walk_virt(&walk, req, false);

	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		skcipher_walk_abort(&walk);

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
		err = skcipher_walk_virt(&walk, req, false);
	} else {
		tail = 0;
	}

	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
		int nbytes = walk.nbytes;

		if (walk.nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
		kernel_neon_begin();
		ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				   ctx->key1.key_dec, rounds, blocks, walk.iv,
				   ctx->key1.key_dec, rounds, nbytes, walk.iv,
				   ctx->key2.key_enc, first);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

	if (err || likely(!tail))
		return err;

	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	kernel_neon_begin();
	ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
			   ctx->key1.key_dec, rounds, walk.nbytes, walk.iv,
			   ctx->key2.key_enc, first);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}

static struct skcipher_alg aes_algs[] = { {
@@ -426,6 +529,7 @@ static struct skcipher_alg aes_algs[] = { {
	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.walksize		= 2 * AES_BLOCK_SIZE,
	.setkey			= xts_set_key,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,