Skip to content

Commit

Permalink
Bug 1369606 - Add cipher suite to HelloRetryRequest, r=ekr
Browse files Browse the repository at this point in the history
--HG--
branch : NSS_TLS13_DRAFT19_BRANCH
extra : source : d161ffd0f9ffec9dc64d8b502ac9f892840ed290
extra : amend_source : 9b81dc20ba207fd25c7fddcc09231708eb3790e0
  • Loading branch information
martinthomson committed Jun 2, 2017
1 parent 0cd3574 commit 0a1bb86
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 85 deletions.
20 changes: 19 additions & 1 deletion gtests/ssl_gtest/ssl_hrr_unittest.cc
Expand Up @@ -187,6 +187,23 @@ TEST_P(TlsConnectTls13, RetryWithSameKeyShare) {
EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, client_->error_code());
}

// Stream because the server doesn't consume the alert and terminate.
TEST_F(TlsConnectStreamTls13, RetryWithDifferentCipherSuite) {
EnsureTlsSetup();
// Force a HelloRetryRequest.
static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
server_->ConfigNamedGroups(groups);
// Then switch out the default suite (TLS_AES_128_GCM_SHA256).
server_->SetPacketFilter(std::make_shared<SelectedCipherSuiteReplacer>(
TLS_CHACHA20_POLY1305_SHA256));

client_->ExpectSendAlert(kTlsAlertIllegalParameter);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
EXPECT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code());
EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code());
}

// This tests that the second attempt at sending a ClientHello (after receiving
// a HelloRetryRequest) is correctly retransmitted.
TEST_F(TlsConnectDatagram13, DropClientSecondFlightWithHelloRetry) {
Expand Down Expand Up @@ -276,9 +293,10 @@ class HelloRetryRequestAgentTest : public TlsAgentTestClient {
void MakeCannedHrr(const uint8_t* body, size_t len, DataBuffer* hrr_record,
uint32_t seq_num = 0) const {
DataBuffer hrr_data;
hrr_data.Allocate(len + 4);
hrr_data.Allocate(len + 6);
size_t i = 0;
i = hrr_data.Write(i, 0x7f00 | TLS_1_3_DRAFT_VERSION, 2);
i = hrr_data.Write(i, TLS_AES_128_GCM_SHA256, 2);
i = hrr_data.Write(i, static_cast<uint32_t>(len), 2);
if (len) {
hrr_data.Write(i, body, len);
Expand Down
30 changes: 0 additions & 30 deletions gtests/ssl_gtest/ssl_resumption_unittest.cc
Expand Up @@ -437,36 +437,6 @@ TEST_P(TlsConnectGeneric, TestResumeServerDifferentCipher) {
CheckKeys();
}

class SelectedCipherSuiteReplacer : public TlsHandshakeFilter {
public:
SelectedCipherSuiteReplacer(uint16_t suite) : cipher_suite_(suite) {}

protected:
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output) override {
if (header.handshake_type() != kTlsHandshakeServerHello) {
return KEEP;
}

*output = input;
uint32_t temp = 0;
EXPECT_TRUE(input.Read(0, 2, &temp));
// Cipher suite is after version(2) and random(32).
size_t pos = 34;
if (temp < SSL_LIBRARY_VERSION_TLS_1_3) {
// In old versions, we have to skip a session_id too.
EXPECT_TRUE(input.Read(pos, 1, &temp));
pos += 1 + temp;
}
output->Write(pos, static_cast<uint32_t>(cipher_suite_), 2);
return CHANGE;
}

private:
uint16_t cipher_suite_;
};

// Test that the client doesn't tolerate the server picking a different cipher
// suite for resumption.
TEST_P(TlsConnectStream, TestResumptionOverrideCipher) {
Expand Down
24 changes: 22 additions & 2 deletions gtests/ssl_gtest/tls_filter.cc
Expand Up @@ -432,8 +432,7 @@ bool FindServerHelloExtensions(TlsParser* parser, const TlsVersioned& header) {

static bool FindHelloRetryExtensions(TlsParser* parser,
const TlsVersioned& header) {
// TODO for -19 add cipher suite
if (!parser->Skip(2)) { // version
if (!parser->Skip(4)) { // version (2) + cipher suite (2)
return false;
}
return true;
Expand Down Expand Up @@ -647,4 +646,25 @@ PacketFilter::Action TlsInspectorClientHelloVersionSetter::FilterHandshake(
return KEEP;
}

PacketFilter::Action SelectedCipherSuiteReplacer::FilterHandshake(
const HandshakeHeader& header, const DataBuffer& input,
DataBuffer* output) {
if (header.handshake_type() != kTlsHandshakeServerHello) {
return KEEP;
}

*output = input;
uint32_t temp = 0;
EXPECT_TRUE(input.Read(0, 2, &temp));
// Cipher suite is after version(2) and random(32).
size_t pos = 34;
if (temp < SSL_LIBRARY_VERSION_TLS_1_3) {
// In old versions, we have to skip a session_id too.
EXPECT_TRUE(input.Read(pos, 1, &temp));
pos += 1 + temp;
}
output->Write(pos, static_cast<uint32_t>(cipher_suite_), 2);
return CHANGE;
}

} // namespace nss_test
13 changes: 13 additions & 0 deletions gtests/ssl_gtest/tls_filter.h
Expand Up @@ -411,6 +411,19 @@ class TlsLastByteDamager : public TlsHandshakeFilter {
uint8_t type_;
};

class SelectedCipherSuiteReplacer : public TlsHandshakeFilter {
public:
SelectedCipherSuiteReplacer(uint16_t suite) : cipher_suite_(suite) {}

protected:
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output) override;

private:
uint16_t cipher_suite_;
};

} // namespace nss_test

#endif
116 changes: 67 additions & 49 deletions lib/ssl/ssl3con.c
Expand Up @@ -6573,11 +6573,9 @@ ssl3_SendCertificateVerify(sslSocket *ss, SECKEYPrivateKey *privKey)
/* Once a cipher suite has been selected, make sure that the necessary secondary
* information is properly set. */
SECStatus
ssl3_SetCipherSuite(sslSocket *ss, ssl3CipherSuite chosenSuite,
PRBool initHashes)
ssl3_SetupCipherSuite(sslSocket *ss, PRBool initHashes)
{
ss->ssl3.hs.cipher_suite = chosenSuite;
ss->ssl3.hs.suite_def = ssl_LookupCipherSuiteDef(chosenSuite);
ss->ssl3.hs.suite_def = ssl_LookupCipherSuiteDef(ss->ssl3.hs.cipher_suite);
if (!ss->ssl3.hs.suite_def) {
PORT_Assert(0);
PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
Expand All @@ -6594,6 +6592,59 @@ ssl3_SetCipherSuite(sslSocket *ss, ssl3CipherSuite chosenSuite,
return ssl3_InitHandshakeHashes(ss);
}

SECStatus
ssl_ClientConsumeCipherSuite(sslSocket *ss, SSL3ProtocolVersion version,
PRUint8 **b, unsigned int *length)
{
PRUint32 temp;
int i;
SECStatus rv;

/* Find the selected cipher suite in our list. */
rv = ssl3_ConsumeHandshakeNumber(ss, &temp, 2, b, length);
if (rv != SECSuccess) {
return SECFailure; /* alert has been sent */
}

i = ssl3_config_match_init(ss);
PORT_Assert(i > 0);
if (i <= 0) {
return SECFailure;
}
for (i = 0; i < ssl_V3_SUITES_IMPLEMENTED; i++) {
ssl3CipherSuiteCfg *suite = &ss->cipherSuites[i];
if (temp == suite->cipher_suite) {
SSLVersionRange vrange = { version, version };
if (!config_match(suite, ss->ssl3.policy, &vrange, ss)) {
/* config_match already checks whether the cipher suite is
* acceptable for the version, but the check is repeated here
* in order to give a more precise error code. */
if (!ssl3_CipherSuiteAllowedForVersionRange(temp, &vrange)) {
PORT_SetError(SSL_ERROR_CIPHER_DISALLOWED_FOR_VERSION);
} else {
PORT_SetError(SSL_ERROR_NO_CYPHER_OVERLAP);
}
return SECFailure;
}
break;
}
}
if (i >= ssl_V3_SUITES_IMPLEMENTED) {
PORT_SetError(SSL_ERROR_NO_CYPHER_OVERLAP);
return SECFailure;
}

/* Don't let the server change its mind. */
if (ss->ssl3.hs.helloRetry && temp != ss->ssl3.hs.cipher_suite) {
(void)SSL3_SendAlert(ss, alert_fatal, illegal_parameter);
PORT_SetError(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
return SECFailure;
}

ss->ssl3.hs.cipher_suite = (ssl3CipherSuite)temp;
return SECSuccess;
}

/* Called from ssl3_HandleHandshakeMessage() when it has deciphered a complete
* ssl3 ServerHello message.
* Caller must hold Handshake and RecvBuf locks.
Expand All @@ -6602,7 +6653,6 @@ static SECStatus
ssl3_HandleServerHello(sslSocket *ss, PRUint8 *b, PRUint32 length)
{
PRUint32 temp;
PRBool suite_found = PR_FALSE;
int i;
int errCode = SSL_ERROR_RX_MALFORMED_SERVER_HELLO;
SECStatus rv;
Expand Down Expand Up @@ -6725,68 +6775,35 @@ ssl3_HandleServerHello(sslSocket *ss, PRUint8 *b, PRUint32 length)
}
}

/* find selected cipher suite in our list. */
rv = ssl3_ConsumeHandshakeNumber(ss, &temp, 2, &b, &length);
rv = ssl_ClientConsumeCipherSuite(ss, ss->version, &b, &length);
if (rv != SECSuccess) {
goto loser; /* alert has been sent */
}
i = ssl3_config_match_init(ss);
PORT_Assert(i > 0);
if (i <= 0) {
errCode = PORT_GetError();
goto loser;
}
for (i = 0; i < ssl_V3_SUITES_IMPLEMENTED; i++) {
ssl3CipherSuiteCfg *suite = &ss->cipherSuites[i];
if (temp == suite->cipher_suite) {
SSLVersionRange vrange = { ss->version, ss->version };
if (!config_match(suite, ss->ssl3.policy, &vrange, ss)) {
/* config_match already checks whether the cipher suite is
* acceptable for the version, but the check is repeated here
* in order to give a more precise error code. */
if (!ssl3_CipherSuiteAllowedForVersionRange(temp, &vrange)) {
desc = handshake_failure;
errCode = SSL_ERROR_CIPHER_DISALLOWED_FOR_VERSION;
goto alert_loser;
}

break; /* failure */
}

suite_found = PR_TRUE;
break; /* success */
}
}
if (!suite_found) {
desc = handshake_failure;
errCode = SSL_ERROR_NO_CYPHER_OVERLAP;
goto alert_loser;
}

rv = ssl3_SetCipherSuite(ss, (ssl3CipherSuite)temp, PR_TRUE);
rv = ssl3_SetupCipherSuite(ss, PR_TRUE);
if (rv != SECSuccess) {
desc = internal_error;
errCode = PORT_GetError();
goto alert_loser;
goto loser;
}

if (ss->version < SSL_LIBRARY_VERSION_TLS_1_3) {
PRBool found = PR_FALSE;
/* find selected compression method in our list. */
rv = ssl3_ConsumeHandshakeNumber(ss, &temp, 1, &b, &length);
if (rv != SECSuccess) {
goto loser; /* alert has been sent */
}
suite_found = PR_FALSE;
for (i = 0; i < ssl_compression_method_count; i++) {
if (temp == ssl_compression_methods[i]) {
if (!ssl_CompressionEnabled(ss, ssl_compression_methods[i])) {
break; /* failure */
}
suite_found = PR_TRUE;
found = PR_TRUE;
break; /* success */
}
}
if (!suite_found) {
if (!found) {
desc = handshake_failure;
errCode = SSL_ERROR_NO_COMPRESSION_OVERLAP;
goto alert_loser;
Expand Down Expand Up @@ -8058,7 +8075,8 @@ ssl3_NegotiateCipherSuite(sslSocket *ss, const SECItem *suites,
for (i = 0; i + 1 < suites->len; i += 2) {
PRUint16 suite_i = (suites->data[i] << 8) | suites->data[i + 1];
if (suite_i == suite->cipher_suite) {
return ssl3_SetCipherSuite(ss, suite_i, initHashes);
ss->ssl3.hs.cipher_suite = suite_i;
return ssl3_SetupCipherSuite(ss, initHashes);
}
}
}
Expand Down Expand Up @@ -8723,7 +8741,8 @@ ssl3_HandleClientHelloPart2(sslSocket *ss,
for (i = 0; i + 1 < suites->len; i += 2) {
PRUint16 suite_i = (suites->data[i] << 8) | suites->data[i + 1];
if (suite_i == suite->cipher_suite) {
rv = ssl3_SetCipherSuite(ss, suite_i, PR_TRUE);
ss->ssl3.hs.cipher_suite = suite_i;
rv = ssl3_SetupCipherSuite(ss, PR_TRUE);
if (rv != SECSuccess) {
desc = internal_error;
errCode = PORT_GetError();
Expand Down Expand Up @@ -9170,8 +9189,6 @@ ssl3_HandleV2ClientHello(sslSocket *ss, unsigned char *buffer, int length,
**
** NOTE: This suite selection algorithm should be the same as the one in
** ssl3_HandleClientHello().
**
** See the comments about export cipher suites in ssl3_HandleClientHello().
*/
for (j = 0; j < ssl_V3_SUITES_IMPLEMENTED; j++) {
ssl3CipherSuiteCfg *suite = &ss->cipherSuites[j];
Expand All @@ -9182,7 +9199,8 @@ ssl3_HandleV2ClientHello(sslSocket *ss, unsigned char *buffer, int length,
for (i = 0; i + 2 < suite_length; i += 3) {
PRUint32 suite_i = (suites[i] << 16) | (suites[i + 1] << 8) | suites[i + 2];
if (suite_i == suite->cipher_suite) {
rv = ssl3_SetCipherSuite(ss, suite_i, PR_TRUE);
ss->ssl3.hs.cipher_suite = suite_i;
rv = ssl3_SetupCipherSuite(ss, PR_TRUE);
if (rv != SECSuccess) {
desc = internal_error;
errCode = PORT_GetError();
Expand Down
7 changes: 5 additions & 2 deletions lib/ssl/sslimpl.h
Expand Up @@ -1628,6 +1628,10 @@ extern SECStatus ssl_ClientReadVersion(sslSocket *ss, PRUint8 **b,
extern SECStatus ssl3_NegotiateVersion(sslSocket *ss,
SSL3ProtocolVersion peerVersion,
PRBool allowLargerPeerVersion);
extern SECStatus ssl_ClientConsumeCipherSuite(sslSocket *ss,
SSL3ProtocolVersion version,
PRUint8 **b,
unsigned int *length);

extern SECStatus ssl_GetPeerInfo(sslSocket *ss);

Expand Down Expand Up @@ -1830,8 +1834,7 @@ SECOidTag ssl3_HashTypeToOID(SSLHashType hashType);
SSLHashType ssl_SignatureSchemeToHashType(SSLSignatureScheme scheme);
KeyType ssl_SignatureSchemeToKeyType(SSLSignatureScheme scheme);

SECStatus ssl3_SetCipherSuite(sslSocket *ss, ssl3CipherSuite chosenSuite,
PRBool initHashes);
SECStatus ssl3_SetupCipherSuite(sslSocket *ss, PRBool initHashes);

/* Pull in TLS 1.3 functions */
#include "tls13con.h"
Expand Down
15 changes: 14 additions & 1 deletion lib/ssl/tls13con.c
Expand Up @@ -477,7 +477,8 @@ tls13_SetupClientHello(sslSocket *ss)
return SECFailure;
}

rv = ssl3_SetCipherSuite(ss, ss->sec.ci.sid->u.ssl3.cipherSuite, PR_FALSE);
ss->ssl3.hs.cipher_suite = ss->sec.ci.sid->u.ssl3.cipherSuite;
rv = ssl3_SetupCipherSuite(ss, PR_FALSE);
if (rv != SECSuccess) {
FATAL_ERROR(ss, PORT_GetError(), internal_error);
return SECFailure;
Expand Down Expand Up @@ -1518,6 +1519,7 @@ tls13_SendHelloRetryRequest(sslSocket *ss, const sslNamedGroupDef *selectedGroup
ssl_GetXmitBufLock(ss);
rv = ssl3_AppendHandshakeHeader(ss, hello_retry_request,
2 + /* version */
2 + /* cipher suite */
2 + /* extension length */
2 + /* group extension id */
2 + /* group extension length */
Expand All @@ -1534,6 +1536,12 @@ tls13_SendHelloRetryRequest(sslSocket *ss, const sslNamedGroupDef *selectedGroup
goto loser;
}

rv = ssl3_AppendHandshakeNumber(ss, ss->ssl3.hs.cipher_suite, 2);
if (rv != SECSuccess) {
FATAL_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE, internal_error);
goto loser;
}

/* Length of extensions. */
rv = ssl3_AppendHandshakeNumber(ss, 2 + 2 + 2, 2);
if (rv != SECSuccess) {
Expand Down Expand Up @@ -1744,6 +1752,11 @@ tls13_HandleHelloRetryRequest(sslSocket *ss, PRUint8 *b, PRUint32 length)
return SECFailure;
}

rv = ssl_ClientConsumeCipherSuite(ss, version, &b, &length);
if (rv != SECSuccess) {
return SECFailure; /* error code already set */
}

/* Extensions. */
rv = ssl3_ConsumeHandshakeNumber(ss, &tmp, 2, &b, &length);
if (rv != SECSuccess) {
Expand Down

0 comments on commit 0a1bb86

Please sign in to comment.