armv8crypto: Extract GCM state into a structure

This makes it easier to refactor the GCM code to operate on
crypto_buffer_cursors rather than plain contiguous buffers, with the aim
of minimizing the amount of copying and zeroing done today.

No functional change intended.

Reviewed by:	jhb
MFC after:	1 week
Sponsored by:	Ampere Computing
Submitted by:	Klara, Inc.
Differential Revision:	https://reviews.freebsd.org/D28500
This commit is contained in:
Mark Johnston 2021-02-08 09:19:10 -05:00
parent 0dc7076037
commit 7509b677b4

View File

@ -234,6 +234,14 @@ armv8_aes_decrypt_xts(AES_key_t *data_schedule,
break; \
} while (0)
struct armv8_gcm_state {
__uint128_val_t EK0;
__uint128_val_t EKi;
__uint128_val_t Xi;
__uint128_val_t lenblock;
uint8_t aes_counter[AES_BLOCK_LEN];
};
void
armv8_aes_encrypt_gcm(AES_key_t *aes_key, size_t len,
const uint8_t *from, uint8_t *to,
@ -242,36 +250,34 @@ armv8_aes_encrypt_gcm(AES_key_t *aes_key, size_t len,
const uint8_t iv[static AES_GCM_IV_LEN],
const __uint128_val_t *Htable)
{
size_t i;
struct armv8_gcm_state s;
const uint64_t *from64;
uint64_t *to64;
uint8_t aes_counter[AES_BLOCK_LEN];
uint8_t block[AES_BLOCK_LEN];
size_t trailer;
__uint128_val_t EK0, EKi, Xi, lenblock;
size_t i, trailer;
bzero(&aes_counter, AES_BLOCK_LEN);
memcpy(aes_counter, iv, AES_GCM_IV_LEN);
bzero(&s.aes_counter, AES_BLOCK_LEN);
memcpy(s.aes_counter, iv, AES_GCM_IV_LEN);
/* Setup the counter */
aes_counter[AES_BLOCK_LEN - 1] = 1;
s.aes_counter[AES_BLOCK_LEN - 1] = 1;
/* EK0 for a final GMAC round */
aes_v8_encrypt(aes_counter, EK0.c, aes_key);
aes_v8_encrypt(s.aes_counter, s.EK0.c, aes_key);
/* GCM starts with 2 as counter, 1 is used for final xor of tag. */
aes_counter[AES_BLOCK_LEN - 1] = 2;
s.aes_counter[AES_BLOCK_LEN - 1] = 2;
memset(Xi.c, 0, sizeof(Xi.c));
memset(s.Xi.c, 0, sizeof(s.Xi.c));
trailer = authdatalen % AES_BLOCK_LEN;
if (authdatalen - trailer > 0) {
gcm_ghash_v8(Xi.u, Htable, authdata, authdatalen - trailer);
gcm_ghash_v8(s.Xi.u, Htable, authdata, authdatalen - trailer);
authdata += authdatalen - trailer;
}
if (trailer > 0 || authdatalen == 0) {
memset(block, 0, sizeof(block));
memcpy(block, authdata, trailer);
gcm_ghash_v8(Xi.u, Htable, block, AES_BLOCK_LEN);
gcm_ghash_v8(s.Xi.u, Htable, block, AES_BLOCK_LEN);
}
from64 = (const uint64_t*)from;
@ -279,11 +285,11 @@ armv8_aes_encrypt_gcm(AES_key_t *aes_key, size_t len,
trailer = len % AES_BLOCK_LEN;
for (i = 0; i < (len - trailer); i += AES_BLOCK_LEN) {
aes_v8_encrypt(aes_counter, EKi.c, aes_key);
AES_INC_COUNTER(aes_counter);
to64[0] = from64[0] ^ EKi.u[0];
to64[1] = from64[1] ^ EKi.u[1];
gcm_ghash_v8(Xi.u, Htable, (uint8_t*)to64, AES_BLOCK_LEN);
aes_v8_encrypt(s.aes_counter, s.EKi.c, aes_key);
AES_INC_COUNTER(s.aes_counter);
to64[0] = from64[0] ^ s.EKi.u[0];
to64[1] = from64[1] ^ s.EKi.u[1];
gcm_ghash_v8(s.Xi.u, Htable, (uint8_t*)to64, AES_BLOCK_LEN);
to64 += 2;
from64 += 2;
@ -293,31 +299,27 @@ armv8_aes_encrypt_gcm(AES_key_t *aes_key, size_t len,
from += (len - trailer);
if (trailer) {
aes_v8_encrypt(aes_counter, EKi.c, aes_key);
AES_INC_COUNTER(aes_counter);
aes_v8_encrypt(s.aes_counter, s.EKi.c, aes_key);
AES_INC_COUNTER(s.aes_counter);
memset(block, 0, sizeof(block));
for (i = 0; i < trailer; i++) {
block[i] = to[i] = from[i] ^ EKi.c[i];
block[i] = to[i] = from[i] ^ s.EKi.c[i];
}
gcm_ghash_v8(Xi.u, Htable, block, AES_BLOCK_LEN);
gcm_ghash_v8(s.Xi.u, Htable, block, AES_BLOCK_LEN);
}
/* Lengths block */
lenblock.u[0] = lenblock.u[1] = 0;
lenblock.d[1] = htobe32(authdatalen * 8);
lenblock.d[3] = htobe32(len * 8);
gcm_ghash_v8(Xi.u, Htable, lenblock.c, AES_BLOCK_LEN);
s.lenblock.u[0] = s.lenblock.u[1] = 0;
s.lenblock.d[1] = htobe32(authdatalen * 8);
s.lenblock.d[3] = htobe32(len * 8);
gcm_ghash_v8(s.Xi.u, Htable, s.lenblock.c, AES_BLOCK_LEN);
Xi.u[0] ^= EK0.u[0];
Xi.u[1] ^= EK0.u[1];
memcpy(tag, Xi.c, GMAC_DIGEST_LEN);
s.Xi.u[0] ^= s.EK0.u[0];
s.Xi.u[1] ^= s.EK0.u[1];
memcpy(tag, s.Xi.c, GMAC_DIGEST_LEN);
explicit_bzero(aes_counter, sizeof(aes_counter));
explicit_bzero(Xi.c, sizeof(Xi.c));
explicit_bzero(EK0.c, sizeof(EK0.c));
explicit_bzero(EKi.c, sizeof(EKi.c));
explicit_bzero(lenblock.c, sizeof(lenblock.c));
explicit_bzero(&s, sizeof(s));
}
int
@ -328,70 +330,68 @@ armv8_aes_decrypt_gcm(AES_key_t *aes_key, size_t len,
const uint8_t iv[static AES_GCM_IV_LEN],
const __uint128_val_t *Htable)
{
size_t i;
struct armv8_gcm_state s;
const uint64_t *from64;
uint64_t *to64;
uint8_t aes_counter[AES_BLOCK_LEN];
uint8_t block[AES_BLOCK_LEN];
size_t trailer;
__uint128_val_t EK0, EKi, Xi, lenblock;
size_t i, trailer;
int error;
error = 0;
bzero(&aes_counter, AES_BLOCK_LEN);
memcpy(aes_counter, iv, AES_GCM_IV_LEN);
bzero(&s.aes_counter, AES_BLOCK_LEN);
memcpy(s.aes_counter, iv, AES_GCM_IV_LEN);
/* Setup the counter */
aes_counter[AES_BLOCK_LEN - 1] = 1;
s.aes_counter[AES_BLOCK_LEN - 1] = 1;
/* EK0 for a final GMAC round */
aes_v8_encrypt(aes_counter, EK0.c, aes_key);
aes_v8_encrypt(s.aes_counter, s.EK0.c, aes_key);
memset(Xi.c, 0, sizeof(Xi.c));
memset(s.Xi.c, 0, sizeof(s.Xi.c));
trailer = authdatalen % AES_BLOCK_LEN;
if (authdatalen - trailer > 0) {
gcm_ghash_v8(Xi.u, Htable, authdata, authdatalen - trailer);
gcm_ghash_v8(s.Xi.u, Htable, authdata, authdatalen - trailer);
authdata += authdatalen - trailer;
}
if (trailer > 0 || authdatalen == 0) {
memset(block, 0, sizeof(block));
memcpy(block, authdata, trailer);
gcm_ghash_v8(Xi.u, Htable, block, AES_BLOCK_LEN);
gcm_ghash_v8(s.Xi.u, Htable, block, AES_BLOCK_LEN);
}
trailer = len % AES_BLOCK_LEN;
if (len - trailer > 0)
gcm_ghash_v8(Xi.u, Htable, from, len - trailer);
gcm_ghash_v8(s.Xi.u, Htable, from, len - trailer);
if (trailer > 0) {
memset(block, 0, sizeof(block));
memcpy(block, from + len - trailer, trailer);
gcm_ghash_v8(Xi.u, Htable, block, AES_BLOCK_LEN);
gcm_ghash_v8(s.Xi.u, Htable, block, AES_BLOCK_LEN);
}
/* Lengths block */
lenblock.u[0] = lenblock.u[1] = 0;
lenblock.d[1] = htobe32(authdatalen * 8);
lenblock.d[3] = htobe32(len * 8);
gcm_ghash_v8(Xi.u, Htable, lenblock.c, AES_BLOCK_LEN);
s.lenblock.u[0] = s.lenblock.u[1] = 0;
s.lenblock.d[1] = htobe32(authdatalen * 8);
s.lenblock.d[3] = htobe32(len * 8);
gcm_ghash_v8(s.Xi.u, Htable, s.lenblock.c, AES_BLOCK_LEN);
Xi.u[0] ^= EK0.u[0];
Xi.u[1] ^= EK0.u[1];
if (timingsafe_bcmp(tag, Xi.c, GMAC_DIGEST_LEN) != 0) {
s.Xi.u[0] ^= s.EK0.u[0];
s.Xi.u[1] ^= s.EK0.u[1];
if (timingsafe_bcmp(tag, s.Xi.c, GMAC_DIGEST_LEN) != 0) {
error = EBADMSG;
goto out;
}
/* GCM starts with 2 as counter, 1 is used for final xor of tag. */
aes_counter[AES_BLOCK_LEN - 1] = 2;
s.aes_counter[AES_BLOCK_LEN - 1] = 2;
from64 = (const uint64_t*)from;
to64 = (uint64_t*)to;
for (i = 0; i < (len - trailer); i += AES_BLOCK_LEN) {
aes_v8_encrypt(aes_counter, EKi.c, aes_key);
AES_INC_COUNTER(aes_counter);
to64[0] = from64[0] ^ EKi.u[0];
to64[1] = from64[1] ^ EKi.u[1];
aes_v8_encrypt(s.aes_counter, s.EKi.c, aes_key);
AES_INC_COUNTER(s.aes_counter);
to64[0] = from64[0] ^ s.EKi.u[0];
to64[1] = from64[1] ^ s.EKi.u[1];
to64 += 2;
from64 += 2;
}
@ -400,18 +400,13 @@ armv8_aes_decrypt_gcm(AES_key_t *aes_key, size_t len,
from += (len - trailer);
if (trailer) {
aes_v8_encrypt(aes_counter, EKi.c, aes_key);
AES_INC_COUNTER(aes_counter);
aes_v8_encrypt(s.aes_counter, s.EKi.c, aes_key);
AES_INC_COUNTER(s.aes_counter);
for (i = 0; i < trailer; i++)
to[i] = from[i] ^ EKi.c[i];
to[i] = from[i] ^ s.EKi.c[i];
}
out:
explicit_bzero(aes_counter, sizeof(aes_counter));
explicit_bzero(Xi.c, sizeof(Xi.c));
explicit_bzero(EK0.c, sizeof(EK0.c));
explicit_bzero(EKi.c, sizeof(EKi.c));
explicit_bzero(lenblock.c, sizeof(lenblock.c));
explicit_bzero(&s, sizeof(s));
return (error);
}