Commit bd199f7f authored by Martin Thomson's avatar Martin Thomson

Bug 1086145 - Improving handshake test coverage, r=wtc

--HG--
extra : rebase_source : ff9160ed9fd8c30942150392bb75bd9260134328
parent c1a8d94b
......@@ -21,11 +21,11 @@ You should be able to run the unit tests manually as:
ssl_gtest -d ${SSLGTESTDIR}
Where $SSLGTESTDIR the directory created by ./all.sh or a manually
created directory with a database containing a certificate called
server (with its private keys)
Where $SSLGTESTDIR is a directory with a database containing:
- an RSA certificate called server (with its private key)
- an ECDSA certificate called ecdsa (with its private key)
A directory like this is created by ./all.sh and can be found
in a directory named something like
There is a very trivial set of tests that demonstrate some
of the features.
tests_results/security/${hostname}.${NUMBER}/ssl_gtests
......@@ -9,6 +9,7 @@ MODULE = nss
CPPSRCS = \
ssl_loopback_unittest.cc \
ssl_extension_unittest.cc \
ssl_skip_unittest.cc \
ssl_gtest.cc \
test_io.cc \
tls_agent.cc \
......
......@@ -268,8 +268,8 @@ class TlsExtensionTestBase : public TlsConnectTestBase {
client_->SetPacketFilter(filter);
}
ConnectExpectFail();
ASSERT_EQ(kTlsAlertFatal, alert_recorder->level());
ASSERT_EQ(alert, alert_recorder->description());
EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
EXPECT_EQ(alert, alert_recorder->description());
}
void ServerHelloErrorTest(PacketFilter* filter,
......@@ -280,8 +280,8 @@ class TlsExtensionTestBase : public TlsConnectTestBase {
server_->SetPacketFilter(filter);
}
ConnectExpectFail();
ASSERT_EQ(kTlsAlertFatal, alert_recorder->level());
ASSERT_EQ(alert, alert_recorder->description());
EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
EXPECT_EQ(alert, alert_recorder->description());
}
static void InitSimpleSni(DataBuffer* extension) {
......@@ -494,7 +494,7 @@ TEST_P(TlsExtensionTest12Plus, DISABLED_SignatureAlgorithmsSigUnsupported) {
}
TEST_P(TlsExtensionTestGeneric, SupportedCurvesShort) {
EnableSomeECDHECiphers();
EnableSomeEcdheCiphers();
const uint8_t val[] = { 0x00, 0x01, 0x00 };
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(new TlsExtensionReplacer(ssl_elliptic_curves_xtn,
......@@ -502,7 +502,7 @@ TEST_P(TlsExtensionTestGeneric, SupportedCurvesShort) {
}
TEST_P(TlsExtensionTestGeneric, SupportedCurvesBadLength) {
EnableSomeECDHECiphers();
EnableSomeEcdheCiphers();
const uint8_t val[] = { 0x09, 0x99, 0x00, 0x00 };
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(new TlsExtensionReplacer(ssl_elliptic_curves_xtn,
......@@ -510,7 +510,7 @@ TEST_P(TlsExtensionTestGeneric, SupportedCurvesBadLength) {
}
TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) {
EnableSomeECDHECiphers();
EnableSomeEcdheCiphers();
const uint8_t val[] = { 0x00, 0x02, 0x00, 0x00, 0x00 };
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(new TlsExtensionReplacer(ssl_elliptic_curves_xtn,
......@@ -518,7 +518,7 @@ TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) {
}
TEST_P(TlsExtensionTestGeneric, SupportedPointsEmpty) {
EnableSomeECDHECiphers();
EnableSomeEcdheCiphers();
const uint8_t val[] = { 0x00 };
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(new TlsExtensionReplacer(ssl_ec_point_formats_xtn,
......@@ -526,7 +526,7 @@ TEST_P(TlsExtensionTestGeneric, SupportedPointsEmpty) {
}
TEST_P(TlsExtensionTestGeneric, SupportedPointsBadLength) {
EnableSomeECDHECiphers();
EnableSomeEcdheCiphers();
const uint8_t val[] = { 0x99, 0x00, 0x00 };
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(new TlsExtensionReplacer(ssl_ec_point_formats_xtn,
......@@ -534,7 +534,7 @@ TEST_P(TlsExtensionTestGeneric, SupportedPointsBadLength) {
}
TEST_P(TlsExtensionTestGeneric, SupportedPointsTrailingData) {
EnableSomeECDHECiphers();
EnableSomeEcdheCiphers();
const uint8_t val[] = { 0x01, 0x00, 0x00 };
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(new TlsExtensionReplacer(ssl_ec_point_formats_xtn,
......
......@@ -15,7 +15,7 @@
namespace nss_test {
class TlsServerKeyExchangeECDHE {
class TlsServerKeyExchangeEcdhe {
public:
bool Parse(const DataBuffer& buffer) {
TlsParser parser(buffer);
......@@ -45,13 +45,14 @@ TEST_P(TlsConnectGeneric, SetupOnly) {}
TEST_P(TlsConnectGeneric, Connect) {
Connect();
client_->CheckVersion(std::get<1>(GetParam()));
client_->CheckAuthType(ssl_auth_rsa);
}
TEST_P(TlsConnectGeneric, ConnectResumed) {
ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID);
Connect();
Reset();
ResetRsa();
Connect();
CheckResumption(RESUME_SESSIONID);
}
......@@ -59,7 +60,7 @@ TEST_P(TlsConnectGeneric, ConnectResumed) {
TEST_P(TlsConnectGeneric, ConnectClientCacheDisabled) {
ConfigureSessionCache(RESUME_NONE, RESUME_SESSIONID);
Connect();
Reset();
ResetRsa();
Connect();
CheckResumption(RESUME_NONE);
}
......@@ -67,7 +68,7 @@ TEST_P(TlsConnectGeneric, ConnectClientCacheDisabled) {
TEST_P(TlsConnectGeneric, ConnectServerCacheDisabled) {
ConfigureSessionCache(RESUME_SESSIONID, RESUME_NONE);
Connect();
Reset();
ResetRsa();
Connect();
CheckResumption(RESUME_NONE);
}
......@@ -75,7 +76,7 @@ TEST_P(TlsConnectGeneric, ConnectServerCacheDisabled) {
TEST_P(TlsConnectGeneric, ConnectSessionCacheDisabled) {
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
Connect();
Reset();
ResetRsa();
Connect();
CheckResumption(RESUME_NONE);
}
......@@ -85,7 +86,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeSupportBoth) {
ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
Connect();
Reset();
ResetRsa();
ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
Connect();
CheckResumption(RESUME_TICKET);
......@@ -97,7 +98,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeClientTicketServerBoth) {
ConfigureSessionCache(RESUME_TICKET, RESUME_BOTH);
Connect();
Reset();
ResetRsa();
ConfigureSessionCache(RESUME_TICKET, RESUME_BOTH);
Connect();
CheckResumption(RESUME_NONE);
......@@ -108,7 +109,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeClientBothTicketServerTicket) {
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
Connect();
Reset();
ResetRsa();
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
Connect();
CheckResumption(RESUME_TICKET);
......@@ -120,7 +121,7 @@ TEST_P(TlsConnectGeneric, ConnectClientServerTicketOnly) {
ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET);
Connect();
Reset();
ResetRsa();
ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET);
Connect();
CheckResumption(RESUME_NONE);
......@@ -130,7 +131,7 @@ TEST_P(TlsConnectGeneric, ConnectClientBothServerNone) {
ConfigureSessionCache(RESUME_BOTH, RESUME_NONE);
Connect();
Reset();
ResetRsa();
ConfigureSessionCache(RESUME_BOTH, RESUME_NONE);
Connect();
CheckResumption(RESUME_NONE);
......@@ -140,35 +141,12 @@ TEST_P(TlsConnectGeneric, ConnectClientNoneServerBoth) {
ConfigureSessionCache(RESUME_NONE, RESUME_BOTH);
Connect();
Reset();
ResetRsa();
ConfigureSessionCache(RESUME_NONE, RESUME_BOTH);
Connect();
CheckResumption(RESUME_NONE);
}
TEST_P(TlsConnectGeneric, ConnectTLS_1_1_Only) {
EnsureTlsSetup();
client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
SSL_LIBRARY_VERSION_TLS_1_1);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
SSL_LIBRARY_VERSION_TLS_1_1);
Connect();
client_->CheckVersion(SSL_LIBRARY_VERSION_TLS_1_1);
}
TEST_P(TlsConnectGeneric, ConnectTLS_1_2_Only) {
EnsureTlsSetup();
client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
SSL_LIBRARY_VERSION_TLS_1_2);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
SSL_LIBRARY_VERSION_TLS_1_2);
Connect();
client_->CheckVersion(SSL_LIBRARY_VERSION_TLS_1_2);
}
TEST_P(TlsConnectGeneric, ResumeWithHigherVersion) {
EnsureTlsSetup();
ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID);
......@@ -178,7 +156,7 @@ TEST_P(TlsConnectGeneric, ResumeWithHigherVersion) {
SSL_LIBRARY_VERSION_TLS_1_1);
Connect();
Reset();
ResetRsa();
EnsureTlsSetup();
client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
SSL_LIBRARY_VERSION_TLS_1_2);
......@@ -196,65 +174,72 @@ TEST_P(TlsConnectGeneric, ConnectAlpn) {
server_->CheckAlpn(SSL_NEXT_PROTO_NEGOTIATED, "a");
}
TEST_P(TlsConnectGeneric, ConnectEcdsa) {
ResetEcdsa();
Connect();
client_->CheckVersion(std::get<1>(GetParam()));
client_->CheckAuthType(ssl_auth_ecdsa);
}
TEST_P(TlsConnectDatagram, ConnectSrtp) {
EnableSrtp();
Connect();
CheckSrtp();
}
TEST_P(TlsConnectStream, ConnectECDHE) {
EnableSomeECDHECiphers();
TEST_P(TlsConnectStream, ConnectEcdhe) {
EnableSomeEcdheCiphers();
Connect();
client_->CheckKEAType(ssl_kea_ecdh);
}
TEST_P(TlsConnectStream, ConnectECDHETwiceReuseKey) {
EnableSomeECDHECiphers();
TEST_P(TlsConnectStream, ConnectEcdheTwiceReuseKey) {
EnableSomeEcdheCiphers();
TlsInspectorRecordHandshakeMessage* i1 =
new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
server_->SetPacketFilter(i1);
Connect();
client_->CheckKEAType(ssl_kea_ecdh);
TlsServerKeyExchangeECDHE dhe1;
ASSERT_TRUE(dhe1.Parse(i1->buffer()));
TlsServerKeyExchangeEcdhe dhe1;
EXPECT_TRUE(dhe1.Parse(i1->buffer()));
// Restart
Reset();
ResetRsa();
TlsInspectorRecordHandshakeMessage* i2 =
new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
server_->SetPacketFilter(i2);
EnableSomeECDHECiphers();
EnableSomeEcdheCiphers();
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
Connect();
client_->CheckKEAType(ssl_kea_ecdh);
TlsServerKeyExchangeECDHE dhe2;
ASSERT_TRUE(dhe2.Parse(i2->buffer()));
TlsServerKeyExchangeEcdhe dhe2;
EXPECT_TRUE(dhe2.Parse(i2->buffer()));
// Make sure they are the same.
ASSERT_EQ(dhe1.public_key_.len(), dhe2.public_key_.len());
ASSERT_TRUE(!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(),
EXPECT_EQ(dhe1.public_key_.len(), dhe2.public_key_.len());
EXPECT_TRUE(!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(),
dhe1.public_key_.len()));
}
TEST_P(TlsConnectStream, ConnectECDHETwiceNewKey) {
EnableSomeECDHECiphers();
TEST_P(TlsConnectStream, ConnectEcdheTwiceNewKey) {
EnableSomeEcdheCiphers();
SECStatus rv =
SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
ASSERT_EQ(SECSuccess, rv);
EXPECT_EQ(SECSuccess, rv);
TlsInspectorRecordHandshakeMessage* i1 =
new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
server_->SetPacketFilter(i1);
Connect();
client_->CheckKEAType(ssl_kea_ecdh);
TlsServerKeyExchangeECDHE dhe1;
ASSERT_TRUE(dhe1.Parse(i1->buffer()));
TlsServerKeyExchangeEcdhe dhe1;
EXPECT_TRUE(dhe1.Parse(i1->buffer()));
// Restart
Reset();
EnableSomeECDHECiphers();
ResetRsa();
EnableSomeEcdheCiphers();
rv = SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
ASSERT_EQ(SECSuccess, rv);
EXPECT_EQ(SECSuccess, rv);
TlsInspectorRecordHandshakeMessage* i2 =
new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
server_->SetPacketFilter(i2);
......@@ -262,11 +247,11 @@ TEST_P(TlsConnectStream, ConnectECDHETwiceNewKey) {
Connect();
client_->CheckKEAType(ssl_kea_ecdh);
TlsServerKeyExchangeECDHE dhe2;
ASSERT_TRUE(dhe2.Parse(i2->buffer()));
TlsServerKeyExchangeEcdhe dhe2;
EXPECT_TRUE(dhe2.Parse(i2->buffer()));
// Make sure they are different.
ASSERT_FALSE((dhe1.public_key_.len() == dhe2.public_key_.len()) &&
EXPECT_FALSE((dhe1.public_key_.len() == dhe2.public_key_.len()) &&
(!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(),
dhe1.public_key_.len())));
}
......
/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this file,
* You can obtain one at http://mozilla.org/MPL/2.0/. */
#include "sslerr.h"
#include "tls_parser.h"
#include "tls_filter.h"
#include "tls_connect.h"
/*
* The tests in this file test that the TLS state machine is robust against
* attacks that alter the order of handshake messages.
*
* See <https://www.smacktls.com/smack.pdf> for a description of the problems
* that this sort of attack can enable.
*/
namespace nss_test {
class TlsHandshakeSkipFilter : public TlsRecordFilter {
public:
// A TLS record filter that skips handshake messages of the identified type.
TlsHandshakeSkipFilter(uint8_t handshake_type)
: handshake_type_(handshake_type),
skipped_(false) {}
protected:
// Takes a record; if it is a handshake record, it removes the first handshake
// message that is of handshake_type_ type.
virtual bool FilterRecord(uint8_t content_type, uint16_t version,
const DataBuffer& input, DataBuffer* output) {
if (content_type != kTlsHandshakeType) {
return false;
}
size_t output_offset = 0U;
output->Allocate(input.len());
TlsParser parser(input);
while (parser.remaining()) {
size_t start = parser.consumed();
uint8_t handshake_type;
if (!parser.Read(&handshake_type)) {
return false;
}
uint32_t length;
if (!TlsHandshakeFilter::ReadLength(&parser, version, &length)) {
return false;
}
if (!parser.Skip(length)) {
return false;
}
if (skipped_ || handshake_type != handshake_type_) {
size_t entire_length = parser.consumed() - start;
output->Write(output_offset, input.data() + start,
entire_length);
// DTLS sequence numbers need to be rewritten
if (skipped_ && IsDtls(version)) {
output->data()[start + 5] -= 1;
}
output_offset += entire_length;
} else {
std::cerr << "Dropping handshake: "
<< static_cast<unsigned>(handshake_type_) << std::endl;
// We only need to report that the output contains changed data if we
// drop a handshake message. But once we've skipped one message, we
// have to modify all subsequent handshake messages so that they include
// the correct DTLS sequence numbers.
skipped_ = true;
}
}
output->Truncate(output_offset);
return skipped_;
}
private:
// The type of handshake message to drop.
uint8_t handshake_type_;
// Whether this filter has ever skipped a handshake message. Track this so
// that sequence numbers on DTLS handshake messages can be rewritten in
// subsequent calls.
bool skipped_;
};
class TlsSkipTest
: public TlsConnectTestBase,
public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
protected:
TlsSkipTest()
: TlsConnectTestBase(TlsConnectTestBase::ToMode(std::get<0>(GetParam())),
std::get<1>(GetParam())) {}
void ServerSkipTest(PacketFilter* filter,
uint8_t alert = kTlsAlertUnexpectedMessage) {
auto alert_recorder = new TlsAlertRecorder();
client_->SetPacketFilter(alert_recorder);
if (filter) {
server_->SetPacketFilter(filter);
}
ConnectExpectFail();
EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
EXPECT_EQ(alert, alert_recorder->description());
}
};
TEST_P(TlsSkipTest, SkipCertificate) {
ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipCertificateEcdhe) {
EnableSomeEcdheCiphers();
ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
}
TEST_P(TlsSkipTest, SkipCertificateEcdsa) {
ResetEcdsa();
ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
}
TEST_P(TlsSkipTest, SkipServerKeyExchange) {
// Have to enable some ephemeral suites, or ServerKeyExchange doesn't appear.
EnableSomeEcdheCiphers();
ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) {
ResetEcdsa();
ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipCertAndKeyExch) {
auto chain = new ChainedPacketFilter();
chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange));
ServerSkipTest(chain);
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) {
ResetEcdsa();
auto chain = new ChainedPacketFilter();
chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange));
ServerSkipTest(chain);
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
INSTANTIATE_TEST_CASE_P(SkipTls10, TlsSkipTest,
::testing::Combine(
TlsConnectTestBase::kTlsModesStream,
TlsConnectTestBase::kTlsV10));
INSTANTIATE_TEST_CASE_P(SkipVariants, TlsSkipTest,
::testing::Combine(
TlsConnectTestBase::kTlsModesAll,
TlsConnectTestBase::kTlsV11V12));
} // namespace nss_test
......@@ -42,7 +42,7 @@ bool TlsAgent::EnsureTlsSetup() {
EXPECT_NE(nullptr, priv);
if (!priv) return false; // Leak cert.
SECStatus rv = SSL_ConfigSecureServer(ssl_fd_, cert, priv, kt_rsa);
SECStatus rv = SSL_ConfigSecureServer(ssl_fd_, cert, priv, kea_);
EXPECT_EQ(SECSuccess, rv);
if (rv != SECSuccess) return false; // Leak cert and key.
......@@ -71,40 +71,42 @@ bool TlsAgent::EnsureTlsSetup() {
}
void TlsAgent::StartConnect() {
ASSERT_TRUE(EnsureTlsSetup());
EXPECT_TRUE(EnsureTlsSetup());
SECStatus rv;
rv = SSL_ResetHandshake(ssl_fd_, role_ == SERVER ? PR_TRUE : PR_FALSE);
ASSERT_EQ(SECSuccess, rv);
EXPECT_EQ(SECSuccess, rv);
SetState(CONNECTING);
}
void TlsAgent::EnableSomeECDHECiphers() {
ASSERT_TRUE(EnsureTlsSetup());
void TlsAgent::EnableSomeEcdheCiphers() {
EXPECT_TRUE(EnsureTlsSetup());
const uint32_t EnabledCiphers[] = {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA};
const uint32_t EcdheCiphers[] = {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA};
for (size_t i = 0; i < PR_ARRAY_SIZE(EnabledCiphers); ++i) {
SECStatus rv = SSL_CipherPrefSet(ssl_fd_, EnabledCiphers[i], PR_TRUE);
ASSERT_EQ(SECSuccess, rv);
for (size_t i = 0; i < PR_ARRAY_SIZE(EcdheCiphers); ++i) {
SECStatus rv = SSL_CipherPrefSet(ssl_fd_, EcdheCiphers[i], PR_TRUE);
EXPECT_EQ(SECSuccess, rv);
}
}
void TlsAgent::SetSessionTicketsEnabled(bool en) {
ASSERT_TRUE(EnsureTlsSetup());
EXPECT_TRUE(EnsureTlsSetup());
SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS,
en ? PR_TRUE : PR_FALSE);
ASSERT_EQ(SECSuccess, rv);
EXPECT_EQ(SECSuccess, rv);
}
void TlsAgent::SetSessionCacheEnabled(bool en) {
ASSERT_TRUE(EnsureTlsSetup());
EXPECT_TRUE(EnsureTlsSetup());
SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE,
en ? PR_FALSE : PR_TRUE);
ASSERT_EQ(SECSuccess, rv);
EXPECT_EQ(SECSuccess, rv);
}
void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) {
......@@ -113,25 +115,30 @@ void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) {
if (ssl_fd_) {
SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_);
ASSERT_EQ(SECSuccess, rv);
EXPECT_EQ(SECSuccess, rv);
}
}
void TlsAgent::CheckKEAType(SSLKEAType type) const {
ASSERT_EQ(CONNECTED, state_);
ASSERT_EQ(type, csinfo_.keaType);
EXPECT_EQ(CONNECTED, state_);
EXPECT_EQ(type, csinfo_.keaType);
}
void TlsAgent::CheckAuthType(SSLAuthType type) const {
EXPECT_EQ(CONNECTED, state_);
EXPECT_EQ(type, csinfo_.authAlgorithm);
}
void TlsAgent::CheckVersion(uint16_t version) const {
ASSERT_EQ(CONNECTED, state_);
ASSERT_EQ(version, info_.protocolVersion);
EXPECT_EQ(CONNECTED, state_);
EXPECT_EQ(version, info_.protocolVersion);
}
void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) {
ASSERT_TRUE(EnsureTlsSetup());
EXPECT_TRUE(EnsureTlsSetup());
ASSERT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_ENABLE_ALPN, PR_TRUE));
ASSERT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd_, val, len));
EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_ENABLE_ALPN, PR_TRUE));
EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd_, val, len));
}
void TlsAgent::CheckAlpn(SSLNextProtoState expected_state,
......@@ -142,37 +149,41 @@ void TlsAgent::CheckAlpn(SSLNextProtoState expected_state,
SECStatus rv = SSL_GetNextProto(ssl_fd_, &state,
reinterpret_cast<unsigned char*>(chosen),
&chosen_len, sizeof(chosen));
ASSERT_EQ(SECSuccess, rv);
ASSERT_EQ(expected_state, state);
ASSERT_EQ(expected, std::string(chosen, chosen_len));
EXPECT_EQ(SECSuccess, rv);
EXPECT_EQ(expected_state, state);
EXPECT_EQ(expected, std::string(chosen, chosen_len));
}
void TlsAgent::EnableSrtp() {
ASSERT_TRUE(EnsureTlsSetup());
EXPECT_TRUE(EnsureTlsSetup());
const uint16_t ciphers[] = {
SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32
};
ASSERT_EQ(SECSuccess, SSL_SetSRTPCiphers(ssl_fd_, ciphers,
EXPECT_EQ(SECSuccess, SSL_SetSRTPCiphers(ssl_fd_, ciphers,
PR_ARRAY_SIZE(ciphers)));
}
void TlsAgent::CheckSrtp() {
uint16_t actual;
ASSERT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd_, &actual));
ASSERT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual);
EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd_, &actual));
EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual);
}
void TlsAgent::CheckErrorCode(int32_t expected) const {
EXPECT_EQ(ERROR, state_);
EXPECT_EQ(expected, error_code_);
}
void TlsAgent::Handshake() {
SECStatus rv = SSL_ForceHandshake(ssl_fd_);
if (rv == SECSuccess) {
LOG("Handshake success");
SECStatus rv = SSL_GetChannelInfo(ssl_fd_, &info_, sizeof(info_));
ASSERT_EQ(SECSuccess, rv);
EXPECT_EQ(SECSuccess, rv);
rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_));
ASSERT_EQ(SECSuccess, rv);
EXPECT_EQ(SECSuccess, rv);
SetState(CONNECTED);
return;
......@@ -192,25 +203,26 @@ void TlsAgent::Handshake() {
case SSL_ERROR_RX_MALFORMED_HANDSHAKE:
default:
LOG("Handshake failed with error " << err);
error_code_ = err;
SetState(ERROR);
return;
}
}
void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
ASSERT_TRUE(EnsureTlsSetup());
EXPECT_TRUE(EnsureTlsSetup());
SECStatus rv = SSL_OptionSet(ssl_fd_,
SSL_NO_CACHE,
mode & RESUME_SESSIONID ?
PR_FALSE : PR_TRUE);
ASSERT_EQ(SECSuccess, rv);
EXPECT_EQ(SECSuccess, rv);
rv = SSL_OptionSet(ssl_fd_,
SSL_ENABLE_SESSION_TICKETS,
mode & RESUME_TICKET ?
PR_TRUE : PR_FALSE);
ASSERT_EQ(SECSuccess, rv);
EXPECT_EQ(SECSuccess, rv);
}
......
......@@ -33,14 +33,16 @@ class TlsAgent : public PollTarget {
enum Role { CLIENT, SERVER };
enum State { INIT, CONNECTING, CONNECTED, ERROR };
TlsAgent(const std::string& name, Role role, Mode mode)
TlsAgent(const std::string& name, Role role, Mode mode, SSLKEAType kea)
: name_(name),
mode_(mode),
kea_(kea),
pr_fd_(nullptr),
adapter_(nullptr),
ssl_fd_(nullptr),
role_(role),
state_(INIT) {
state_(INIT),
error_code_(0) {
memset(&info_, 0, sizeof(info_));
memset(&csinfo_, 0, sizeof(csinfo_));
SECStatus rv = SSL_VersionRangeGetDefault(mode_ == STREAM ?
......@@ -78,10 +80,11 @@ class TlsAgent : public PollTarget {
void StartConnect();
void CheckKEAType(SSLKEAType type) const;
void CheckAuthType(SSLAuthType type) const;