diff --git a/sys/opencrypto/ktls.h b/sys/opencrypto/ktls.h index b97f589fecb4..503864f87ccc 100644 --- a/sys/opencrypto/ktls.h +++ b/sys/opencrypto/ktls.h @@ -55,5 +55,8 @@ int ktls_ocf_encrypt(struct ktls_ocf_encrypt_state *state, int ktls_ocf_decrypt(struct ktls_session *tls, const struct tls_record_layer *hdr, struct mbuf *m, uint64_t seqno, int *trailer_len); +int ktls_ocf_recrypt(struct ktls_session *tls, + const struct tls_record_layer *hdr, struct mbuf *m, uint64_t seqno); +bool ktls_ocf_recrypt_supported(struct ktls_session *tls); #endif /* !__OPENCRYPTO_KTLS_H__ */ diff --git a/sys/opencrypto/ktls_ocf.c b/sys/opencrypto/ktls_ocf.c index 3b330bf7061c..6347ca459646 100644 --- a/sys/opencrypto/ktls_ocf.c +++ b/sys/opencrypto/ktls_ocf.c @@ -44,6 +44,7 @@ __FBSDID("$FreeBSD$"); #include #include #include +#include #include #include @@ -53,6 +54,11 @@ struct ktls_ocf_sw { struct ktls_session *tls, struct mbuf *m, struct iovec *outiov, int outiovcnt); + /* Re-encrypt a received TLS record that is partially decrypted. */ + int (*recrypt)(struct ktls_session *tls, + const struct tls_record_layer *hdr, struct mbuf *m, + uint64_t seqno); + /* Decrypt a received TLS record. */ int (*decrypt)(struct ktls_session *tls, const struct tls_record_layer *hdr, struct mbuf *m, @@ -63,6 +69,7 @@ struct ktls_ocf_session { const struct ktls_ocf_sw *sw; crypto_session_t sid; crypto_session_t mac_sid; + crypto_session_t recrypt_sid; struct mtx lock; int mac_len; bool implicit_iv; @@ -109,6 +116,11 @@ SYSCTL_COUNTER_U64(_kern_ipc_tls_stats_ocf, OID_AUTO, tls12_gcm_encrypts, CTLFLAG_RD, &ocf_tls12_gcm_encrypts, "Total number of OCF TLS 1.2 GCM encryption operations"); +static COUNTER_U64_DEFINE_EARLY(ocf_tls12_gcm_recrypts); +SYSCTL_COUNTER_U64(_kern_ipc_tls_stats_ocf, OID_AUTO, tls12_gcm_recrypts, + CTLFLAG_RD, &ocf_tls12_gcm_recrypts, + "Total number of OCF TLS 1.2 GCM re-encryption operations"); + static COUNTER_U64_DEFINE_EARLY(ocf_tls12_chacha20_decrypts); SYSCTL_COUNTER_U64(_kern_ipc_tls_stats_ocf, OID_AUTO, tls12_chacha20_decrypts, CTLFLAG_RD, &ocf_tls12_chacha20_decrypts, @@ -129,6 +141,11 @@ SYSCTL_COUNTER_U64(_kern_ipc_tls_stats_ocf, OID_AUTO, tls13_gcm_encrypts, CTLFLAG_RD, &ocf_tls13_gcm_encrypts, "Total number of OCF TLS 1.3 GCM encryption operations"); +static COUNTER_U64_DEFINE_EARLY(ocf_tls13_gcm_recrypts); +SYSCTL_COUNTER_U64(_kern_ipc_tls_stats_ocf, OID_AUTO, tls13_gcm_recrypts, + CTLFLAG_RD, &ocf_tls13_gcm_recrypts, + "Total number of OCF TLS 1.3 GCM re-encryption operations"); + static COUNTER_U64_DEFINE_EARLY(ocf_tls13_chacha20_decrypts); SYSCTL_COUNTER_U64(_kern_ipc_tls_stats_ocf, OID_AUTO, tls13_chacha20_decrypts, CTLFLAG_RD, &ocf_tls13_chacha20_decrypts, @@ -549,8 +566,84 @@ ktls_ocf_tls12_aead_decrypt(struct ktls_session *tls, return (error); } +/* + * Reconstruct encrypted mbuf data in input buffer. + */ +static void +ktls_ocf_recrypt_fixup(struct mbuf *m, u_int skip, u_int len, char *buf) +{ + const char *src = buf; + u_int todo; + + while (skip >= m->m_len) { + skip -= m->m_len; + m = m->m_next; + } + + while (len > 0) { + todo = m->m_len - skip; + if (todo > len) + todo = len; + + if (m->m_flags & M_DECRYPTED) + memcpy(mtod(m, char *) + skip, src, todo); + src += todo; + len -= todo; + skip = 0; + m = m->m_next; + } +} + +static int +ktls_ocf_tls12_aead_recrypt(struct ktls_session *tls, + const struct tls_record_layer *hdr, struct mbuf *m, + uint64_t seqno) +{ + struct cryptop crp; + struct ktls_ocf_session *os; + char *buf; + u_int payload_len; + int error; + + os = tls->ocf_session; + + crypto_initreq(&crp, os->recrypt_sid); + + KASSERT(tls->params.cipher_algorithm == CRYPTO_AES_NIST_GCM_16, + ("%s: only AES-GCM is supported", __func__)); + + /* Setup the IV. */ + memcpy(crp.crp_iv, tls->params.iv, TLS_AEAD_GCM_LEN); + memcpy(crp.crp_iv + TLS_AEAD_GCM_LEN, hdr + 1, sizeof(uint64_t)); + be32enc(crp.crp_iv + AES_GCM_IV_LEN, 2); + + payload_len = ntohs(hdr->tls_length) - + (AES_GMAC_HASH_LEN + sizeof(uint64_t)); + crp.crp_op = CRYPTO_OP_ENCRYPT; + crp.crp_flags = CRYPTO_F_CBIMM | CRYPTO_F_IV_SEPARATE; + crypto_use_mbuf(&crp, m); + crp.crp_payload_start = tls->params.tls_hlen; + crp.crp_payload_length = payload_len; + + buf = malloc(payload_len, M_KTLS_OCF, M_WAITOK); + crypto_use_output_buf(&crp, buf, payload_len); + + counter_u64_add(ocf_tls12_gcm_recrypts, 1); + error = ktls_ocf_dispatch(os, &crp); + + crypto_destroyreq(&crp); + + if (error == 0) + ktls_ocf_recrypt_fixup(m, tls->params.tls_hlen, payload_len, + buf); + + free(buf, M_KTLS_OCF); + return (error); +} + static const struct ktls_ocf_sw ktls_ocf_tls12_aead_sw = { .encrypt = ktls_ocf_tls12_aead_encrypt, + .recrypt = ktls_ocf_tls12_aead_recrypt, .decrypt = ktls_ocf_tls12_aead_decrypt, }; @@ -681,8 +774,55 @@ ktls_ocf_tls13_aead_decrypt(struct ktls_session *tls, return (error); } +static int +ktls_ocf_tls13_aead_recrypt(struct ktls_session *tls, + const struct tls_record_layer *hdr, struct mbuf *m, + uint64_t seqno) +{ + struct cryptop crp; + struct ktls_ocf_session *os; + char *buf; + u_int payload_len; + int error; + + os = tls->ocf_session; + + crypto_initreq(&crp, os->recrypt_sid); + + KASSERT(tls->params.cipher_algorithm == CRYPTO_AES_NIST_GCM_16, + ("%s: only AES-GCM is supported", __func__)); + + /* Setup the IV. */ + memcpy(crp.crp_iv, tls->params.iv, tls->params.iv_len); + *(uint64_t *)(crp.crp_iv + 4) ^= htobe64(seqno); + be32enc(crp.crp_iv + 12, 2); + + payload_len = ntohs(hdr->tls_length) - AES_GMAC_HASH_LEN; + crp.crp_op = CRYPTO_OP_ENCRYPT; + crp.crp_flags = CRYPTO_F_CBIMM | CRYPTO_F_IV_SEPARATE; + crypto_use_mbuf(&crp, m); + crp.crp_payload_start = tls->params.tls_hlen; + crp.crp_payload_length = payload_len; + + buf = malloc(payload_len, M_KTLS_OCF, M_WAITOK); + crypto_use_output_buf(&crp, buf, payload_len); + + counter_u64_add(ocf_tls13_gcm_recrypts, 1); + error = ktls_ocf_dispatch(os, &crp); + + crypto_destroyreq(&crp); + + if (error == 0) + ktls_ocf_recrypt_fixup(m, tls->params.tls_hlen, payload_len, + buf); + + free(buf, M_KTLS_OCF); + return (error); +} + static const struct ktls_ocf_sw ktls_ocf_tls13_aead_sw = { .encrypt = ktls_ocf_tls13_aead_encrypt, + .recrypt = ktls_ocf_tls13_aead_recrypt, .decrypt = ktls_ocf_tls13_aead_decrypt, }; @@ -694,6 +834,7 @@ ktls_ocf_free(struct ktls_session *tls) os = tls->ocf_session; crypto_freesession(os->sid); crypto_freesession(os->mac_sid); + crypto_freesession(os->recrypt_sid); mtx_destroy(&os->lock); zfree(os, M_KTLS_OCF); } @@ -701,7 +842,7 @@ ktls_ocf_free(struct ktls_session *tls) int ktls_ocf_try(struct socket *so, struct ktls_session *tls, int direction) { - struct crypto_session_params csp, mac_csp; + struct crypto_session_params csp, mac_csp, recrypt_csp; struct ktls_ocf_session *os; int error, mac_len; @@ -709,6 +850,8 @@ ktls_ocf_try(struct socket *so, struct ktls_session *tls, int direction) memset(&mac_csp, 0, sizeof(mac_csp)); mac_csp.csp_mode = CSP_MODE_NONE; mac_len = 0; + memset(&recrypt_csp, 0, sizeof(mac_csp)); + recrypt_csp.csp_mode = CSP_MODE_NONE; switch (tls->params.cipher_algorithm) { case CRYPTO_AES_NIST_GCM_16: @@ -732,6 +875,13 @@ ktls_ocf_try(struct socket *so, struct ktls_session *tls, int direction) csp.csp_cipher_key = tls->params.cipher_key; csp.csp_cipher_klen = tls->params.cipher_key_len; csp.csp_ivlen = AES_GCM_IV_LEN; + + recrypt_csp.csp_flags |= CSP_F_SEPARATE_OUTPUT; + recrypt_csp.csp_mode = CSP_MODE_CIPHER; + recrypt_csp.csp_cipher_alg = CRYPTO_AES_ICM; + recrypt_csp.csp_cipher_key = tls->params.cipher_key; + recrypt_csp.csp_cipher_klen = tls->params.cipher_key_len; + recrypt_csp.csp_ivlen = AES_BLOCK_LEN; break; case CRYPTO_AES_CBC: switch (tls->params.cipher_key_len) { @@ -826,6 +976,16 @@ ktls_ocf_try(struct socket *so, struct ktls_session *tls, int direction) os->mac_len = mac_len; } + if (recrypt_csp.csp_mode != CSP_MODE_NONE) { + error = crypto_newsession(&os->recrypt_sid, &recrypt_csp, + CRYPTO_FLAG_HARDWARE | CRYPTO_FLAG_SOFTWARE); + if (error) { + crypto_freesession(os->sid); + free(os, M_KTLS_OCF); + return (error); + } + } + mtx_init(&os->lock, "ktls_ocf", NULL, MTX_DEF); tls->ocf_session = os; if (tls->params.cipher_algorithm == CRYPTO_AES_NIST_GCM_16 || @@ -870,3 +1030,17 @@ ktls_ocf_decrypt(struct ktls_session *tls, const struct tls_record_layer *hdr, { return (tls->ocf_session->sw->decrypt(tls, hdr, m, seqno, trailer_len)); } + +int +ktls_ocf_recrypt(struct ktls_session *tls, const struct tls_record_layer *hdr, + struct mbuf *m, uint64_t seqno) +{ + return (tls->ocf_session->sw->recrypt(tls, hdr, m, seqno)); +} + +bool +ktls_ocf_recrypt_supported(struct ktls_session *tls) +{ + return (tls->ocf_session->sw->recrypt != NULL && + tls->ocf_session->recrypt_sid != NULL); +}