Skip to content

Commit

Permalink
Bug 1201704 - Abort on unexpected handshake hash computation. r=mt
Browse files Browse the repository at this point in the history
  • Loading branch information
ekr committed Sep 9, 2015
1 parent 9711aed commit faba93a
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 3 deletions.
1 change: 1 addition & 0 deletions external_tests/ssl_gtest/manifest.mn
Expand Up @@ -12,6 +12,7 @@ CSRCS = \
$(NULL)

CPPSRCS = \
ssl_agent_unittest.cc \
ssl_loopback_unittest.cc \
ssl_extension_unittest.cc \
ssl_prf_unittest.cc \
Expand Down
58 changes: 58 additions & 0 deletions external_tests/ssl_gtest/ssl_agent_unittest.cc
@@ -0,0 +1,58 @@
/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this file,
* You can obtain one at http://mozilla.org/MPL/2.0/. */

#include "ssl.h"
#include "sslerr.h"
#include "sslproto.h"

#include <memory>

#include "databuffer.h"
#include "tls_agent.h"
#include "tls_connect.h"
#include "tls_parser.h"

namespace nss_test {

void MakeTrivialHandshakeMessage(uint8_t hs_type, size_t hs_len,
DataBuffer* out) {
size_t total_len = 5 + 4 + hs_len;

out->Allocate(total_len);

size_t index = 0;
out->Write(index, kTlsHandshakeType, 1); ++index; // Content Type
out->Write(index, 3, 1); ++index; // Version high
out->Write(index, 1, 1); ++index; // Version low
out->Write(index, 4 + hs_len, 2); index += 2; // Length

out->Write(index, hs_type, 1); ++index; // Handshake record type.
out->Write(index, hs_len, 3); index += 3; // Handshake length
for (; index < total_len; ++index) {
out->Write(index, 1, 1);
}
}

TEST_P(TlsAgentTest, EarlyFinished) {
DataBuffer buffer;
MakeTrivialHandshakeMessage(kTlsHandshakeFinished, 0, &buffer);
ProcessMessage(buffer, TlsAgent::STATE_ERROR,
SSL_ERROR_RX_UNEXPECTED_FINISHED);
}

TEST_P(TlsAgentTest, EarlyCertificateVerify) {
DataBuffer buffer;
MakeTrivialHandshakeMessage(kTlsHandshakeCertificateVerify, 0, &buffer);
ProcessMessage(buffer, TlsAgent::STATE_ERROR,
SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
}

INSTANTIATE_TEST_CASE_P(AgentTests, TlsAgentTest,
::testing::Combine(
TlsAgentTestBase::kTlsRolesAll,
TlsConnectTestBase::kTlsModesStream));

} // namespace nss_test
35 changes: 35 additions & 0 deletions external_tests/ssl_gtest/tls_agent.cc
Expand Up @@ -17,6 +17,7 @@

namespace nss_test {


const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"};

TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode, SSLKEAType kea)
Expand Down Expand Up @@ -548,5 +549,39 @@ void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
EXPECT_EQ(SECSuccess, rv);
}

static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"};
::testing::internal::ParamGenerator<std::string>
TlsAgentTestBase::kTlsRolesAll = ::testing::ValuesIn(kTlsRolesAllArr);

void TlsAgentTestBase::Init() {
agent_ = new TlsAgent(
role_ == TlsAgent::CLIENT ? "client" : "server",
role_, mode_, kea_);
agent_->Init();
fd_ = DummyPrSocket::CreateFD("dummy", mode_);
agent_->adapter()->SetPeer(
DummyPrSocket::GetAdapter(fd_));
agent_->StartConnect();
}

void TlsAgentTestBase::EnsureInit() {
if (!agent_) {
Init();
}
}

void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer,
TlsAgent::State expected_state,
int32_t error_code) {
EnsureInit();
agent_->adapter()->PacketReceived(buffer);
agent_->Handshake();

ASSERT_EQ(expected_state, agent_->state());

if (expected_state == TlsAgent::STATE_ERROR) {
ASSERT_EQ(error_code, agent_->error_code());
}
}

} // namespace nss_test
52 changes: 52 additions & 0 deletions external_tests/ssl_gtest/tls_agent.h
Expand Up @@ -102,6 +102,7 @@ class TlsAgent : public PollTarget {
const char* state_str(State state) const { return states[state]; }

PRFileDesc* ssl_fd() { return ssl_fd_; }
DummyPrSocket* adapter() { return adapter_; }

uint16_t min_version() const { return vrange_.min; }
uint16_t max_version() const { return vrange_.max; }
Expand Down Expand Up @@ -239,6 +240,57 @@ class TlsAgent : public PollTarget {
bool expected_read_error_;
};

class TlsAgentTestBase : public ::testing::Test {
public:
static ::testing::internal::ParamGenerator<std::string> kTlsRolesAll;

TlsAgentTestBase(TlsAgent::Role role,
Mode mode) : agent_(nullptr),
fd_(nullptr),
role_(role),
mode_(mode),
kea_(ssl_kea_rsa) {}
~TlsAgentTestBase() {
delete agent_;
if (fd_) {
PR_Close(fd_);
}
}

static inline TlsAgent::Role ToRole(const std::string& str) {
return str == "CLIENT" ? TlsAgent::CLIENT : TlsAgent::SERVER;
}

static inline Mode ToMode(const std::string& str) {
return str == "TLS" ? STREAM : DGRAM;
}

void Init();

protected:
void EnsureInit();
void ProcessMessage(const DataBuffer& buffer,
TlsAgent::State expected_state,
int32_t error_code = 0);


TlsAgent* agent_;
PRFileDesc* fd_;
TlsAgent::Role role_;
Mode mode_;
SSLKEAType kea_;
};

class TlsAgentTest :
public TlsAgentTestBase,
public ::testing::WithParamInterface
<std::tuple<std::string,std::string>> {
public:
TlsAgentTest() :
TlsAgentTestBase(ToRole(std::get<0>(GetParam())),
ToMode(std::get<1>(GetParam()))) {}
};

} // namespace nss_test

#endif
2 changes: 2 additions & 0 deletions external_tests/ssl_gtest/tls_parser.h
Expand Up @@ -27,7 +27,9 @@ const uint8_t kTlsHandshakeClientHello = 1;
const uint8_t kTlsHandshakeServerHello = 2;
const uint8_t kTlsHandshakeCertificate = 11;
const uint8_t kTlsHandshakeServerKeyExchange = 12;
const uint8_t kTlsHandshakeCertificateVerify = 15;
const uint8_t kTlsHandshakeClientKeyExchange = 16;
const uint8_t kTlsHandshakeFinished = 20;

const uint8_t kTlsAlertWarning = 1;
const uint8_t kTlsAlertFatal = 2;
Expand Down
30 changes: 27 additions & 3 deletions lib/ssl/ssl3con.c
Expand Up @@ -4687,6 +4687,11 @@ ssl3_ComputeHandshakeHashes(sslSocket * ss,
SSL3Opaque sha_inner[MAX_MAC_LENGTH];

PORT_Assert( ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss) );
if (ss->ssl3.hs.hashType == handshake_hash_unknown) {
PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
return SECFailure;
}

hashes->hashAlg = ssl_hash_none;

#ifndef NO_PKCS11_BYPASS
Expand Down Expand Up @@ -9511,6 +9516,13 @@ ssl3_HandleCertificateVerify(sslSocket *ss, SSL3Opaque *b, PRUint32 length,
goto alert_loser;
}

if (!hashes) {
PORT_Assert(0);
desc = internal_error;
errCode = SEC_ERROR_LIBRARY_FAILURE;
goto alert_loser;
}

if (isTLS12) {
rv = ssl3_ConsumeSignatureAndHashAlgorithm(ss, &b, &length,
&sigAndHash);
Expand Down Expand Up @@ -11171,6 +11183,13 @@ ssl3_HandleFinished(sslSocket *ss, SSL3Opaque *b, PRUint32 length,
return SECFailure;
}

if (!hashes) {
PORT_Assert(0);
SSL3_SendAlert(ss, alert_fatal, internal_error);
PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
return SECFailure;
}

isTLS = (PRBool)(ss->ssl3.crSpec->version > SSL_LIBRARY_VERSION_3_0);
if (isTLS) {
TLSFinished tlsFinished;
Expand Down Expand Up @@ -11396,6 +11415,7 @@ ssl3_HandleHandshakeMessage(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
SECStatus rv = SECSuccess;
SSL3HandshakeType type = ss->ssl3.hs.msg_type;
SSL3Hashes hashes; /* computed hashes are put here. */
SSL3Hashes *hashesPtr = NULL; /* Set when hashes are computed */
PRUint8 hdr[4];
PRUint8 dtlsData[8];

Expand All @@ -11406,7 +11426,8 @@ ssl3_HandleHandshakeMessage(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
* current message.
*/
ssl_GetSpecReadLock(ss); /************************************/
if((type == finished) || (type == certificate_verify)) {
if(((type == finished) && (ss->ssl3.hs.ws == wait_finished)) ||
((type == certificate_verify) && (ss->ssl3.hs.ws == wait_cert_verify))) {
SSL3Sender sender = (SSL3Sender)0;
ssl3CipherSpec *rSpec = ss->ssl3.prSpec;

Expand All @@ -11415,6 +11436,9 @@ ssl3_HandleHandshakeMessage(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
rSpec = ss->ssl3.crSpec;
}
rv = ssl3_ComputeHandshakeHashes(ss, rSpec, &hashes, sender);
if (rv == SECSuccess) {
hashesPtr = &hashes;
}
}
ssl_ReleaseSpecReadLock(ss); /************************************/
if (rv != SECSuccess) {
Expand Down Expand Up @@ -11565,7 +11589,7 @@ ssl3_HandleHandshakeMessage(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
PORT_SetError(SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
return SECFailure;
}
rv = ssl3_HandleCertificateVerify(ss, b, length, &hashes);
rv = ssl3_HandleCertificateVerify(ss, b, length, hashesPtr);
break;
case client_key_exchange:
if (!ss->sec.isServer) {
Expand All @@ -11584,7 +11608,7 @@ ssl3_HandleHandshakeMessage(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
rv = ssl3_HandleNewSessionTicket(ss, b, length);
break;
case finished:
rv = ssl3_HandleFinished(ss, b, length, &hashes);
rv = ssl3_HandleFinished(ss, b, length, hashesPtr);
break;
default:
(void)SSL3_SendAlert(ss, alert_fatal, unexpected_message);
Expand Down

0 comments on commit faba93a

Please sign in to comment.