Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Bug 1320326 - Make ssl3_ConsumeHandshakeNumber() return a SECStatus a…
…nd take a pointer argument r=mt

Differential Revision: https://nss-review.dev.mozaws.net/D97
  • Loading branch information
Tim Taubert committed Nov 28, 2016
1 parent bab54a9 commit 8045316
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 181 deletions.
165 changes: 84 additions & 81 deletions lib/ssl/ssl3con.c
Expand Up @@ -1088,10 +1088,11 @@ ssl_ClientReadVersion(sslSocket *ss, SSL3Opaque **b, unsigned int *len,
SSL3ProtocolVersion *version)
{
SSL3ProtocolVersion v;
PRInt32 temp;
PRUint32 temp;
SECStatus rv;

temp = ssl3_ConsumeHandshakeNumber(ss, 2, b, len);
if (temp < 0) {
rv = ssl3_ConsumeHandshakeNumber(ss, &temp, 2, b, len);
if (rv != SECSuccess) {
return SECFailure; /* alert has been sent */
}

Expand Down Expand Up @@ -4311,7 +4312,7 @@ ssl3_AppendHandshakeHeader(sslSocket *ss, SSL3HandshakeType t, PRUint32 length)
* override the generic error code by setting another.
*/
SECStatus
ssl3_ConsumeHandshake(sslSocket *ss, void *v, PRInt32 bytes, SSL3Opaque **b,
ssl3_ConsumeHandshake(sslSocket *ss, void *v, PRUint32 bytes, SSL3Opaque **b,
PRUint32 *length)
{
PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss));
Expand All @@ -4329,37 +4330,33 @@ ssl3_ConsumeHandshake(sslSocket *ss, void *v, PRInt32 bytes, SSL3Opaque **b,

/* Read up the next "bytes" number of bytes from the (decrypted) input
* stream "b" (which is *length bytes long), and interpret them as an
* integer in network byte order. Returns the received value.
* integer in network byte order. Sets *num to the received value.
* Reduces *length by bytes. Advances *b by bytes.
*
* Returns SECFailure (-1) on failure.
* This value is indistinguishable from the equivalent received value.
* Only positive numbers are to be received this way.
* Thus, the largest value that may be sent this way is 0x7fffffff.
* On error, an alert has been sent, and a generic error code has been set.
*/
PRInt32
ssl3_ConsumeHandshakeNumber(sslSocket *ss, PRInt32 bytes, SSL3Opaque **b,
PRUint32 *length)
SECStatus
ssl3_ConsumeHandshakeNumber(sslSocket *ss, PRUint32 *num, PRUint32 bytes,
SSL3Opaque **b, PRUint32 *length)
{
PRUint8 *buf = *b;
int i;
PRInt32 num = 0;

PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss));
PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
PORT_Assert(bytes <= sizeof num);

if ((PRUint32)bytes > *length) {
*num = 0;
if (bytes > *length || bytes > sizeof(*num)) {
return ssl3_DecodeError(ss);
}
PRINT_BUF(60, (ss, "consume bytes:", *b, bytes));

for (i = 0; i < bytes; i++)
num = (num << 8) + buf[i];
for (i = 0; i < bytes; i++) {
*num = (*num << 8) + buf[i];
}
*b += bytes;
*length -= bytes;
return num;
return SECSuccess;
}

/* Read in two values from the incoming decrypted byte stream "b", which is
Expand All @@ -4377,21 +4374,22 @@ ssl3_ConsumeHandshakeNumber(sslSocket *ss, PRInt32 bytes, SSL3Opaque **b,
* point to the values in the buffer **b.
*/
SECStatus
ssl3_ConsumeHandshakeVariable(sslSocket *ss, SECItem *i, PRInt32 bytes,
ssl3_ConsumeHandshakeVariable(sslSocket *ss, SECItem *i, PRUint32 bytes,
SSL3Opaque **b, PRUint32 *length)
{
PRInt32 count;
PRUint32 count;
SECStatus rv;

PORT_Assert(bytes <= 3);
i->len = 0;
i->data = NULL;
i->type = siBuffer;
count = ssl3_ConsumeHandshakeNumber(ss, bytes, b, length);
if (count < 0) { /* Can't test for SECSuccess here. */
rv = ssl3_ConsumeHandshakeNumber(ss, &count, bytes, b, length);
if (rv != SECSuccess) {
return SECFailure;
}
if (count > 0) {
if ((PRUint32)count > *length) {
if (count > *length) {
return ssl3_DecodeError(ss);
}
i->data = *b;
Expand Down Expand Up @@ -4662,10 +4660,11 @@ SECStatus
ssl_ConsumeSignatureScheme(sslSocket *ss, SSL3Opaque **b,
PRUint32 *length, SSLSignatureScheme *out)
{
PRInt32 tmp;
PRUint32 tmp;
SECStatus rv;

tmp = ssl3_ConsumeHandshakeNumber(ss, 2, b, length);
if (tmp < 0) {
rv = ssl3_ConsumeHandshakeNumber(ss, &tmp, 2, b, length);
if (rv != SECSuccess) {
return SECFailure; /* Error code set already. */
}
if (!ssl_IsSupportedSignatureScheme((SSLSignatureScheme)tmp)) {
Expand Down Expand Up @@ -6587,7 +6586,7 @@ ssl3_SetCipherSuite(sslSocket *ss, ssl3CipherSuite chosenSuite,
static SECStatus
ssl3_HandleServerHello(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
{
PRInt32 temp; /* allow for consume number failure */
PRUint32 temp;
PRBool suite_found = PR_FALSE;
int i;
int errCode = SSL_ERROR_RX_MALFORMED_SERVER_HELLO;
Expand Down Expand Up @@ -6702,8 +6701,8 @@ ssl3_HandleServerHello(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
}

/* find selected cipher suite in our list. */
temp = ssl3_ConsumeHandshakeNumber(ss, 2, &b, &length);
if (temp < 0) {
rv = ssl3_ConsumeHandshakeNumber(ss, &temp, 2, &b, &length);
if (rv != SECSuccess) {
goto loser; /* alert has been sent */
}
i = ssl3_config_match_init(ss);
Expand Down Expand Up @@ -6748,8 +6747,8 @@ ssl3_HandleServerHello(sslSocket *ss, SSL3Opaque *b, PRUint32 length)

if (ss->version < SSL_LIBRARY_VERSION_TLS_1_3) {
/* find selected compression method in our list. */
temp = ssl3_ConsumeHandshakeNumber(ss, 1, &b, &length);
if (temp < 0) {
rv = ssl3_ConsumeHandshakeNumber(ss, &temp, 1, &b, &length);
if (rv != SECSuccess) {
goto loser; /* alert has been sent */
}
suite_found = PR_FALSE;
Expand Down Expand Up @@ -7257,36 +7256,37 @@ SECStatus
ssl3_ParseCertificateRequestCAs(sslSocket *ss, SSL3Opaque **b, PRUint32 *length,
PLArenaPool *arena, CERTDistNames *ca_list)
{
PRInt32 remaining;
PRUint32 remaining;
int nnames = 0;
dnameNode *node;
SECStatus rv;
int i;

remaining = ssl3_ConsumeHandshakeNumber(ss, 2, b, length);
if (remaining < 0)
rv = ssl3_ConsumeHandshakeNumber(ss, &remaining, 2, b, length);
if (rv != SECSuccess)
return SECFailure; /* malformed, alert has been sent */

if ((PRUint32)remaining > *length)
if (remaining > *length)
goto alert_loser;

ca_list->head = node = PORT_ArenaZNew(arena, dnameNode);
if (node == NULL)
goto no_mem;

while (remaining > 0) {
PRInt32 len;
PRUint32 len;

if (remaining < 2)
goto alert_loser; /* malformed */

node->name.len = len = ssl3_ConsumeHandshakeNumber(ss, 2, b, length);
if (len <= 0)
rv = ssl3_ConsumeHandshakeNumber(ss, &len, 2, b, length);
if (rv != SECSuccess)
return SECFailure; /* malformed, alert has been sent */

remaining -= 2;
if (remaining < len)
if (len == 0 || remaining < len + 2)
goto alert_loser; /* malformed */

remaining -= 2;
node->name.len = len;
node->name.data = *b;
*b += len;
*length -= len;
Expand Down Expand Up @@ -7362,9 +7362,9 @@ ssl_ParseSignatureSchemes(const sslSocket *ss, PLArenaPool *arena,
}

for (; max; --max) {
PRInt32 tmp;
tmp = ssl3_ExtConsumeHandshakeNumber(ss, 2, &buf.data, &buf.len);
if (tmp < 0) {
PRUint32 tmp;
rv = ssl3_ExtConsumeHandshakeNumber(ss, &tmp, 2, &buf.data, &buf.len);
if (rv != SECSuccess) {
PORT_Assert(0);
PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
return SECFailure;
Expand Down Expand Up @@ -8228,7 +8228,7 @@ static SECStatus
ssl3_HandleClientHello(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
{
sslSessionID *sid = NULL;
PRInt32 tmp;
PRUint32 tmp;
unsigned int i;
SECStatus rv;
int errCode = SSL_ERROR_RX_MALFORMED_CLIENT_HELLO;
Expand Down Expand Up @@ -8288,8 +8288,8 @@ ssl3_HandleClientHello(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
dtls_RehandshakeCleanup(ss);
}

tmp = ssl3_ConsumeHandshakeNumber(ss, 2, &b, &length);
if (tmp < 0)
rv = ssl3_ConsumeHandshakeNumber(ss, &tmp, 2, &b, &length);
if (rv != SECSuccess)
goto loser; /* malformed, alert already sent */

/* Translate the version. */
Expand Down Expand Up @@ -8342,9 +8342,9 @@ ssl3_HandleClientHello(sslSocket *ss, SSL3Opaque *b, PRUint32 length)

if (length) {
/* Get length of hello extensions */
PRInt32 extension_length;
extension_length = ssl3_ConsumeHandshakeNumber(ss, 2, &b, &length);
if (extension_length < 0) {
PRUint32 extension_length;
rv = ssl3_ConsumeHandshakeNumber(ss, &extension_length, 2, &b, &length);
if (rv != SECSuccess) {
goto loser; /* alert already sent */
}
if (extension_length != length) {
Expand Down Expand Up @@ -9896,9 +9896,9 @@ ssl3_HandleRSAClientKeyExchange(sslSocket *ss,
enc_pms.len = length;

if (ss->ssl3.prSpec->version > SSL_LIBRARY_VERSION_3_0) { /* isTLS */
PRInt32 kLen;
kLen = ssl3_ConsumeHandshakeNumber(ss, 2, &enc_pms.data, &enc_pms.len);
if (kLen < 0) {
PRUint32 kLen;
rv = ssl3_ConsumeHandshakeNumber(ss, &kLen, 2, &enc_pms.data, &enc_pms.len);
if (rv != SECSuccess) {
PORT_SetError(SSL_ERROR_CLIENT_KEY_EXCHANGE_FAILURE);
return SECFailure;
}
Expand Down Expand Up @@ -10218,6 +10218,7 @@ ssl3_HandleNewSessionTicket(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
{
SECStatus rv;
SECItem ticketData;
PRUint32 temp;

SSL_TRC(3, ("%d: SSL3[%d]: handle session_ticket handshake",
SSL_GETPID(), ss->fd));
Expand All @@ -10244,8 +10245,13 @@ ssl3_HandleNewSessionTicket(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
PORT_SetError(SSL_ERROR_RX_MALFORMED_NEW_SESSION_TICKET);
return SECFailure;
}
ss->ssl3.hs.newSessionTicket.ticket_lifetime_hint =
(PRUint32)ssl3_ConsumeHandshakeNumber(ss, 4, &b, &length);

rv = ssl3_ConsumeHandshakeNumber(ss, &temp, 4, &b, &length);
if (rv != SECSuccess) {
PORT_SetError(SSL_ERROR_RX_MALFORMED_NEW_SESSION_TICKET);
return SECFailure;
}
ss->ssl3.hs.newSessionTicket.ticket_lifetime_hint = temp;

rv = ssl3_ConsumeHandshakeVariable(ss, &ticketData, 2, &b, &length);
if (rv != SECSuccess || length != 0) {
Expand Down Expand Up @@ -10540,21 +10546,20 @@ ssl3_HandleCertificateStatus(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
SECStatus
ssl_ReadCertificateStatus(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
{
PRInt32 status, len;
PRUint32 status, len;
SECStatus rv;

PORT_Assert(!ss->sec.isServer);

/* Consume the CertificateStatusType enum */
status = ssl3_ConsumeHandshakeNumber(ss, 1, &b, &length);
if (status != 1 /* ocsp */) {
ssl3_DecodeError(ss); /* sets error code */
return SECFailure;
rv = ssl3_ConsumeHandshakeNumber(ss, &status, 1, &b, &length);
if (rv != SECSuccess || status != 1 /* ocsp */) {
return ssl3_DecodeError(ss);
}

len = ssl3_ConsumeHandshakeNumber(ss, 3, &b, &length);
if (len != length) {
ssl3_DecodeError(ss); /* sets error code */
return SECFailure;
rv = ssl3_ConsumeHandshakeNumber(ss, &len, 3, &b, &length);
if (rv != SECSuccess || len != length) {
return ssl3_DecodeError(ss);
}

#define MAX_CERTSTATUS_LEN 0x1ffff /* 128k - 1 */
Expand Down Expand Up @@ -10611,8 +10616,8 @@ ssl3_CompleteHandleCertificate(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
{
ssl3CertNode *c;
ssl3CertNode *lastCert = NULL;
PRInt32 remaining = 0;
PRInt32 size;
PRUint32 remaining = 0;
PRUint32 size;
SECStatus rv;
PRBool isServer = ss->sec.isServer;
PRBool isTLS;
Expand All @@ -10628,10 +10633,10 @@ ssl3_CompleteHandleCertificate(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
** normal no_certificates message to maximize interoperability.
*/
if (length) {
remaining = ssl3_ConsumeHandshakeNumber(ss, 3, &b, &length);
if (remaining < 0)
rv = ssl3_ConsumeHandshakeNumber(ss, &remaining, 3, &b, &length);
if (rv != SECSuccess)
goto loser; /* fatal alert already sent by ConsumeHandshake. */
if ((PRUint32)remaining > length)
if (remaining > length)
goto decode_loser;
}

Expand Down Expand Up @@ -10662,15 +10667,14 @@ ssl3_CompleteHandleCertificate(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
}

/* First get the peer cert. */
remaining -= 3;
if (remaining < 0)
if (remaining < 3)
goto decode_loser;

size = ssl3_ConsumeHandshakeNumber(ss, 3, &b, &length);
if (size <= 0)
remaining -= 3;
rv = ssl3_ConsumeHandshakeNumber(ss, &size, 3, &b, &length);
if (rv != SECSuccess)
goto loser; /* fatal alert already sent by ConsumeHandshake. */

if (remaining < size)
if (size == 0 || remaining < size)
goto decode_loser;

certItem.data = b;
Expand All @@ -10690,15 +10694,14 @@ ssl3_CompleteHandleCertificate(sslSocket *ss, SSL3Opaque *b, PRUint32 length)

/* Now get all of the CA certs. */
while (remaining > 0) {
remaining -= 3;
if (remaining < 0)
if (remaining < 3)
goto decode_loser;

size = ssl3_ConsumeHandshakeNumber(ss, 3, &b, &length);
if (size <= 0)
remaining -= 3;
rv = ssl3_ConsumeHandshakeNumber(ss, &size, 3, &b, &length);
if (rv != SECSuccess)
goto loser; /* fatal alert already sent by ConsumeHandshake. */

if (remaining < size)
if (size == 0 || remaining < size)
goto decode_loser;

certItem.data = b;
Expand Down

0 comments on commit 8045316

Please sign in to comment.