Skip to content

Commit

Permalink
Implement RSA-PSS padding for TPMv2
Browse files Browse the repository at this point in the history
Now I can connect using TLSv1.3 using a TPMv2 RSA key. And my list of
"stuff I should never have had to do for myself in the application,
just to ask the crypto library to use the key that the user pointed
it at" has *really* jumped the shark now.

Signed-off-by: David Woodhouse <dwmw2@infradead.org>
  • Loading branch information
dwmw2 committed May 12, 2021
1 parent c6a3c9d commit ff36796
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 11 deletions.
5 changes: 3 additions & 2 deletions gnutls.h
Expand Up @@ -42,9 +42,10 @@ int tpm2_rsa_sign_hash_fn(gnutls_privkey_t key, gnutls_sign_algorithm_t algo,
int tpm2_ec_sign_hash_fn(gnutls_privkey_t key, gnutls_sign_algorithm_t algo,
void *_certinfo, unsigned int flags,
const gnutls_datum_t *data, gnutls_datum_t *sig);
int oc_pkcs1_pad(struct openconnect_info *vpninfo,
unsigned char *buf, int size, const gnutls_datum_t *data);
int oc_pad_rsasig(struct openconnect_info *vpninfo, gnutls_sign_algorithm_t algo,
unsigned char *buf, int size, const gnutls_datum_t *data, int keybits);
uint16_t tpm2_key_curve(struct openconnect_info *vpninfo, struct cert_info *certinfo);
int tpm2_rsa_key_bits(struct openconnect_info *vpninfo, struct cert_info *certinfo);

/* GnuTLS 3.6.0+ provides this. We have our own for older GnuTLS. There is
* also _gnutls_encode_ber_rs_raw() in some older versions, but there were
Expand Down
184 changes: 179 additions & 5 deletions gnutls_tpm2.c
Expand Up @@ -102,9 +102,15 @@ static int tpm2_ec_sign_fn(gnutls_privkey_t key, void *_certinfo,
#if GNUTLS_VERSION_NUMBER >= 0x030600
static int rsa_key_info(gnutls_privkey_t key, unsigned int flags, void *_certinfo)
{
struct cert_info *certinfo = _certinfo;
struct openconnect_info *vpninfo = certinfo->vpninfo;

if (flags & GNUTLS_PRIVKEY_INFO_PK_ALGO)
return GNUTLS_PK_RSA;

if (flags & GNUTLS_PRIVKEY_INFO_PK_ALGO_BITS)
return tpm2_rsa_key_bits(vpninfo, certinfo);

if (flags & GNUTLS_PRIVKEY_INFO_HAVE_SIGN_ALGO) {
gnutls_sign_algorithm_t algo = GNUTLS_FLAGS_TO_SIGN_ALGO(flags);
switch (algo) {
Expand All @@ -115,7 +121,18 @@ static int rsa_key_info(gnutls_privkey_t key, unsigned int flags, void *_certinf
case GNUTLS_SIGN_RSA_SHA512:
return 1;

case GNUTLS_SIGN_RSA_PSS_SHA256:
case GNUTLS_SIGN_RSA_PSS_RSAE_SHA256:
case GNUTLS_SIGN_RSA_PSS_SHA384:
case GNUTLS_SIGN_RSA_PSS_RSAE_SHA384:
case GNUTLS_SIGN_RSA_PSS_SHA512:
case GNUTLS_SIGN_RSA_PSS_RSAE_SHA512:
return 1;

default:
vpn_progress(vpninfo, PRG_DEBUG,
_("Not supporting EC sign algo %s\n"),
gnutls_sign_get_name(algo));
return 0;
}
}
Expand All @@ -130,16 +147,17 @@ static int rsa_key_info(gnutls_privkey_t key, unsigned int flags, void *_certinf
#if GNUTLS_VERSION_NUMBER >= 0x030400
static int ec_key_info(gnutls_privkey_t key, unsigned int flags, void *_certinfo)
{
struct cert_info *certinfo = _certinfo;
struct openconnect_info *vpninfo = certinfo->vpninfo;

if (flags & GNUTLS_PRIVKEY_INFO_PK_ALGO)
return GNUTLS_PK_EC;

#ifdef GNUTLS_PRIVKEY_INFO_HAVE_SIGN_ALGO
if (flags & GNUTLS_PRIVKEY_INFO_HAVE_SIGN_ALGO) {
struct cert_info *certinfo = _certinfo;
struct openconnect_info *vpninfo = certinfo->vpninfo;

uint16_t tpm2_curve = tpm2_key_curve(vpninfo, certinfo);
gnutls_sign_algorithm_t algo = GNUTLS_FLAGS_TO_SIGN_ALGO(flags);

switch (algo) {
case GNUTLS_SIGN_ECDSA_SHA1:
case GNUTLS_SIGN_ECDSA_SHA256:
Expand Down Expand Up @@ -393,8 +411,8 @@ int oc_gnutls_encode_rs_value(gnutls_datum_t *sig, const gnutls_datum_t *sig_r,

/* EMSA-PKCS1-v1_5 padding in accordance with RFC3447 §9.2 */
#define PKCS1_PAD_OVERHEAD 11
int oc_pkcs1_pad(struct openconnect_info *vpninfo,
unsigned char *buf, int size, const gnutls_datum_t *data)
static int oc_pkcs1_pad(struct openconnect_info *vpninfo,
unsigned char *buf, int size, const gnutls_datum_t *data)
{
if (data->size + PKCS1_PAD_OVERHEAD > size) {
vpn_progress(vpninfo, PRG_ERR,
Expand All @@ -411,4 +429,160 @@ int oc_pkcs1_pad(struct openconnect_info *vpninfo,

return 0;
}

#if GNUTLS_VERSION_NUMBER >= 0x030600
/* EMSA-PSS encoding in accordance with RFC3447 §9.1 */
static int oc_pss_mgf1_pad(struct openconnect_info *vpninfo, gnutls_digest_algorithm_t dig,
unsigned char *emBuf, int emLen, const gnutls_datum_t *mHash, int keybits)
{
gnutls_hash_hd_t hashctx = NULL;
int err = GNUTLS_E_PK_SIGN_FAILED;

/* The emBits for EMSA-PSS encoding is actually one *fewer* bit than
* the RSA modulus. As RFC3447 §8.1.1 points out, "the octet length
* of EM will be one less than k if modBits - 1 is divisible by 8
* and equal to k otherwise". Where k is the input emLen, which we
* thus need to adjust before using it as emLen for the following
* operations. Not that it matters much since I don't think the TPM
* can cope with RSA keys whose modulus isn't a multiple of 8 bits
* anyway. */
int msbits = (keybits - 1) & 7;
if (!msbits) {
*(emBuf++) = 0;
emLen--;
}

/* GnuTLS gives us a predigested mHash from which we create M' and
* continue the process. Can we infer all the PSS parameters from
* the digest size, including the salt size? Or does GnuTLS need
* a gnutls_privkey_import_ext5() which lets us have the params too?
* Better still, could GnuTLS just do this all for us and we only
* do a raw signature — really raw, unlike GNUTLS_SIGN_RSA_RAW
* which AIUI is actually padded. */
if (mHash->size > emLen - 2) {
vpn_progress(vpninfo, PRG_ERR,
_("PSS encoding failed; hash size %d too large for RSA key %d\n"),
mHash->size, emLen);
return GNUTLS_E_PK_SIGN_FAILED;
}

int sLen = mHash->size;
if (sLen + mHash->size > emLen - 2)
sLen = emLen - 2 - mHash->size;

char salt[SHA512_SIZE];
if (sLen) {
err = gnutls_rnd(GNUTLS_RND_NONCE, salt, sLen);
if (err)
goto out;
}

/* Hash M' (8 zeroes || mHash || salt) into its place in EM */
if ((err = gnutls_hash_init(&hashctx, dig)) ||
(err = gnutls_hash(hashctx, "\0\0\0\0\0\0\0\0", 8)) ||
(err = gnutls_hash(hashctx, mHash->data, mHash->size)) ||
(sLen && (err = gnutls_hash(hashctx, salt, sLen))))
goto out;

int maskedDBLen = emLen - mHash->size - 1;
gnutls_hash_output(hashctx, emBuf + maskedDBLen);

emBuf[emLen - 1] = 0xbc;

/* Now the MGF1 function as definsed in RFC3447 Appendix B, although
* it's somewhat easier to read in NIST SP 800-56B §7.2.2.2.
*
* We repeatedly hash (M' || C) where C is an incrementing 32-bit
* counter, so hash M' first and then use gnutls_hash_copy() each
* time to add C to the copy. */
err = gnutls_hash(hashctx, emBuf + maskedDBLen, mHash->size);
if (err)
goto out;

int mgflen = 0, mgf_count = 0;
while (mgflen < maskedDBLen) {
gnutls_hash_hd_t ctx2 = gnutls_hash_copy(hashctx);
if (!ctx2) {
err = GNUTLS_E_PK_SIGN_FAILED;
goto out;
}
uint32_t be_count = htonl(mgf_count++);
err = gnutls_hash(ctx2, &be_count, sizeof(be_count));
if (err) {
gnutls_hash_deinit(ctx2, NULL);
goto out;
}
if (mgflen + mHash->size <= maskedDBLen) {
gnutls_hash_deinit(ctx2, emBuf + mgflen);
mgflen += mHash->size;
} else {
char md[SHA512_SIZE];
gnutls_hash_deinit(ctx2, md);
memcpy(emBuf + mgflen, md, maskedDBLen - mgflen);
mgflen = maskedDBLen;
}
}

/* Back to EMSA-PSS-ENCODE step 10. The MGF result was directly placed
* into emBuf, so now XOR with DB, which is (zeroes || 0x01 || salt) */
int dst = maskedDBLen - 1;
while (sLen--)
emBuf[dst--] ^= salt[sLen];
emBuf[dst] ^= 0x01;

/* Now mask out the high bits. In the case where msbits is zero, we
* skipped the entire first byte so do nothing. */
if (msbits)
emBuf[0] &= 0xFF >> (8 - msbits);

err = 0;
out:
if (hashctx)
gnutls_hash_deinit(hashctx, NULL);

return err;
}
#endif

int oc_pad_rsasig(struct openconnect_info *vpninfo, gnutls_sign_algorithm_t algo,
unsigned char *buf, int size, const gnutls_datum_t *data, int keybits)
{
switch(algo) {
case GNUTLS_SIGN_UNKNOWN:
case GNUTLS_SIGN_RSA_SHA1:
case GNUTLS_SIGN_RSA_SHA256:
case GNUTLS_SIGN_RSA_SHA384:
case GNUTLS_SIGN_RSA_SHA512:
return oc_pkcs1_pad(vpninfo, buf, size, data);

#if GNUTLS_VERSION_NUMBER >= 0x030600
/* Really PKCS#1.5 padding, yes. */
case GNUTLS_SIGN_RSA_RAW:
return oc_pkcs1_pad(vpninfo, buf, size, data);

case GNUTLS_SIGN_RSA_PSS_SHA256:
case GNUTLS_SIGN_RSA_PSS_RSAE_SHA256:
if (data->size != SHA256_SIZE)
return GNUTLS_E_PK_SIGN_FAILED;
return oc_pss_mgf1_pad(vpninfo, GNUTLS_DIG_SHA256, buf, size, data, keybits);

case GNUTLS_SIGN_RSA_PSS_SHA384:
case GNUTLS_SIGN_RSA_PSS_RSAE_SHA384:
if (data->size != SHA384_SIZE)
return GNUTLS_E_PK_SIGN_FAILED;
return oc_pss_mgf1_pad(vpninfo, GNUTLS_DIG_SHA384, buf, size, data, keybits);

case GNUTLS_SIGN_RSA_PSS_SHA512:
case GNUTLS_SIGN_RSA_PSS_RSAE_SHA512:
if (data->size != SHA512_SIZE)
return GNUTLS_E_PK_SIGN_FAILED;
return oc_pss_mgf1_pad(vpninfo, GNUTLS_DIG_SHA512, buf, size, data, keybits);
#endif /* 3.6.0+ */
default:
vpn_progress(vpninfo, PRG_ERR,
_("TPMv2 RSA sign called for unknown algorithm %s\n"),
gnutls_sign_get_name(algo));
return GNUTLS_E_PK_SIGN_FAILED;
}
}
#endif /* HAVE_TSS2 */
12 changes: 9 additions & 3 deletions gnutls_tpm2_esys.c
Expand Up @@ -409,12 +409,13 @@ int tpm2_rsa_sign_hash_fn(gnutls_privkey_t key, gnutls_sign_algorithm_t algo,
TSS2_RC r;

vpn_progress(vpninfo, PRG_DEBUG,
_("TPM2 RSA sign function called for %d bytes.\n"),
data->size);
_("TPM2 RSA sign function called for %d bytes, algo %s\n"),
data->size, gnutls_sign_get_name(algo));

digest.size = certinfo->tpm2->pub.publicArea.unique.rsa.size;

if (oc_pkcs1_pad(vpninfo, digest.buffer, digest.size, data))
if (oc_pad_rsasig(vpninfo, algo, digest.buffer, digest.size, data,
certinfo->tpm2->pub.publicArea.parameters.rsaDetail.keyBits))
return GNUTLS_E_PK_SIGN_FAILED;

if (init_tpm2_key(&ectx, &key_handle, vpninfo, certinfo))
Expand Down Expand Up @@ -608,6 +609,11 @@ uint16_t tpm2_key_curve(struct openconnect_info *vpninfo, struct cert_info *cert
return certinfo->tpm2->pub.publicArea.parameters.eccDetail.curveID;
}

int tpm2_rsa_key_bits(struct openconnect_info *vpninfo, struct cert_info *certinfo)
{
return certinfo->tpm2->pub.publicArea.parameters.rsaDetail.keyBits;
}

void release_tpm2_ctx(struct openconnect_info *vpninfo, struct cert_info *certinfo)
{
if (certinfo->tpm2) {
Expand Down
8 changes: 7 additions & 1 deletion gnutls_tpm2_ibm.c
Expand Up @@ -344,7 +344,8 @@ int tpm2_rsa_sign_hash_fn(gnutls_privkey_t key, gnutls_sign_algorithm_t algo,

in.cipherText.t.size = certinfo->tpm2->pub.publicArea.unique.rsa.t.size;

if (oc_pkcs1_pad(vpninfo, in.cipherText.t.buffer, in.cipherText.t.size, data))
if (oc_pad_rsasig(vpninfo, algo, in.cipherText.t.buffer, in.cipherText.t.size, data,
certinfo->tpm2->pub.publicArea.parameters.rsaDetail.keyBits))
return GNUTLS_E_PK_SIGN_FAILED;

in.inScheme.scheme = TPM_ALG_NULL;
Expand Down Expand Up @@ -558,6 +559,11 @@ uint16_t tpm2_key_curve(struct openconnect_info *vpninfo, struct cert_info *cert
return certinfo->tpm2->pub.publicArea.parameters.eccDetail.curveID;
}

int tpm2_rsa_key_bits(struct openconnect_info *vpninfo, struct cert_info *certinfo)
{
return certinfo->tpm2->pub.publicArea.parameters.rsaDetail.keyBits;
}

void release_tpm2_ctx(struct openconnect_info *vpninfo, struct cert_info *certinfo)
{
if (certinfo->tpm2) {
Expand Down

0 comments on commit ff36796

Please sign in to comment.