Skip to content

Commit

Permalink
Bug 1348720 - Move TlsAlertRecorder to some specific alert tests, r=t…
Browse files Browse the repository at this point in the history
…taubert

--HG--
extra : amend_source : a446bacace8010589bd412d58030d1414dade315
extra : histedit_source : 6457247f1a47b201a08e91814ffdd9dfb47bbdf9
  • Loading branch information
martinthomson committed Mar 23, 2017
1 parent 49e0d7f commit 092d015
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 77 deletions.
11 changes: 0 additions & 11 deletions gtests/ssl_gtest/ssl_gather_unittest.cc
Expand Up @@ -16,17 +16,11 @@ class GatherV2ClientHelloTest : public TlsConnectTestBase {
void ConnectExpectMalformedClientHello(const DataBuffer &data) {
EnsureTlsSetup();
server_->ExpectSendAlert(kTlsAlertIllegalParameter);
auto alert_recorder = std::make_shared<TlsAlertRecorder>();
server_->SetPacketFilter(alert_recorder);

client_->SendDirect(data);
server_->StartConnect();
server_->Handshake();
ASSERT_TRUE_WAIT(
(server_->error_code() == SSL_ERROR_RX_MALFORMED_CLIENT_HELLO), 2000);

EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
EXPECT_EQ(kTlsAlertIllegalParameter, alert_recorder->description());
}
};

Expand Down Expand Up @@ -56,16 +50,11 @@ TEST_F(TlsConnectTest, GatherExcessiveV3Record) {

EnsureTlsSetup();
server_->ExpectSendAlert(kTlsAlertRecordOverflow);
auto alert_recorder = std::make_shared<TlsAlertRecorder>();
server_->SetPacketFilter(alert_recorder);
client_->SendDirect(buffer);
server_->StartConnect();
server_->Handshake();
ASSERT_TRUE_WAIT((server_->error_code() == SSL_ERROR_RX_RECORD_TOO_LONG),
2000);

EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
EXPECT_EQ(kTlsAlertRecordOverflow, alert_recorder->description());
}

// Gather a 3-byte v2 header, with a fragment length of 2.
Expand Down
88 changes: 87 additions & 1 deletion gtests/ssl_gtest/ssl_loopback_unittest.cc
Expand Up @@ -39,7 +39,7 @@ TEST_P(TlsConnectGeneric, ConnectEcdsa) {
CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa);
}

TEST_P(TlsConnectGenericPre13, CipherSuiteMismatch) {
TEST_P(TlsConnectGeneric, CipherSuiteMismatch) {
EnsureTlsSetup();
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256);
Expand All @@ -53,6 +53,92 @@ TEST_P(TlsConnectGenericPre13, CipherSuiteMismatch) {
server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
}

class TlsAlertRecorder : public TlsRecordFilter {
public:
TlsAlertRecorder() : level_(255), description_(255) {}

PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
DataBuffer* output) override {
if (level_ != 255) { // Already captured.
return KEEP;
}
if (header.content_type() != kTlsAlertType) {
return KEEP;
}

std::cerr << "Alert: " << input << std::endl;

TlsParser parser(input);
EXPECT_TRUE(parser.Read(&level_));
EXPECT_TRUE(parser.Read(&description_));
return KEEP;
}

uint8_t level() const { return level_; }
uint8_t description() const { return description_; }

private:
uint8_t level_;
uint8_t description_;
};

class HelloTruncator : public TlsHandshakeFilter {
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output) override {
if (header.handshake_type() != kTlsHandshakeClientHello &&
header.handshake_type() != kTlsHandshakeServerHello) {
return KEEP;
}
output->Assign(input.data(), input.len() - 1);
return CHANGE;
}
};

// Verify that when NSS reports that an alert is sent, it is actually sent.
TEST_P(TlsConnectGeneric, CaptureAlertServer) {
client_->SetPacketFilter(std::make_shared<HelloTruncator>());
auto alert_recorder = std::make_shared<TlsAlertRecorder>();
server_->SetPacketFilter(alert_recorder);

ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
EXPECT_EQ(kTlsAlertIllegalParameter, alert_recorder->description());
}

TEST_P(TlsConnectGenericPre13, CaptureAlertClient) {
server_->SetPacketFilter(std::make_shared<HelloTruncator>());
auto alert_recorder = std::make_shared<TlsAlertRecorder>();
client_->SetPacketFilter(alert_recorder);

ConnectExpectAlert(client_, kTlsAlertDecodeError);
EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description());
}

// In TLS 1.3, the server can't read the client alert.
TEST_P(TlsConnectTls13, CaptureAlertClient) {
server_->SetPacketFilter(std::make_shared<HelloTruncator>());
auto alert_recorder = std::make_shared<TlsAlertRecorder>();
client_->SetPacketFilter(alert_recorder);

server_->StartConnect();
client_->StartConnect();

client_->Handshake();
client_->ExpectSendAlert(kTlsAlertDecodeError);
server_->Handshake();
client_->Handshake();
if (mode_ == STREAM) {
// DTLS just drops the alert it can't decrypt.
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
}
server_->Handshake();
EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description());
}

TEST_P(TlsConnectGenericPre13, ConnectFalseStart) {
client_->EnableFalseStart();
Connect();
Expand Down
8 changes: 2 additions & 6 deletions gtests/ssl_gtest/ssl_skip_unittest.cc
Expand Up @@ -87,12 +87,8 @@ class TlsSkipTest

void ServerSkipTest(std::shared_ptr<PacketFilter> filter,
uint8_t alert = kTlsAlertUnexpectedMessage) {
auto alert_recorder = std::make_shared<TlsAlertRecorder>();
client_->SetPacketFilter(alert_recorder);
server_->SetPacketFilter(filter);
ConnectExpectAlert(client_, alert);
EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
EXPECT_EQ(alert, alert_recorder->description());
}
};

Expand Down Expand Up @@ -130,6 +126,8 @@ class Tls13SkipTest : public TlsConnectTestBase,

server_->CheckErrorCode(error);
ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state());

client_->Handshake(); // Make sure to consume the alert the server sends.
}
};

Expand Down Expand Up @@ -218,7 +216,6 @@ TEST_P(Tls13SkipTest, SkipClientCertificate) {
ClientSkipTest(
std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate),
SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
client_->Handshake(); // Make sure to consume the alert.
}

TEST_P(Tls13SkipTest, SkipClientCertificateVerify) {
Expand All @@ -228,7 +225,6 @@ TEST_P(Tls13SkipTest, SkipClientCertificateVerify) {
ClientSkipTest(
std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify),
SSL_ERROR_RX_UNEXPECTED_FINISHED);
client_->Handshake(); // Make sure to consume the alert.
}

INSTANTIATE_TEST_CASE_P(SkipTls10, TlsSkipTest,
Expand Down
16 changes: 0 additions & 16 deletions gtests/ssl_gtest/ssl_staticrsa_unittest.cc
Expand Up @@ -52,11 +52,7 @@ TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusCKE) {
kTlsHandshakeClientKeyExchange,
DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange)));
client_->SetPacketFilter(i1);
auto alert_recorder = std::make_shared<TlsAlertRecorder>();
server_->SetPacketFilter(alert_recorder);
ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description());
}

// Test that a PMS with a bogus version number is handled correctly.
Expand All @@ -65,11 +61,7 @@ TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) {
EnableOnlyStaticRsaCiphers();
client_->SetPacketFilter(
std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
auto alert_recorder = std::make_shared<TlsAlertRecorder>();
server_->SetPacketFilter(alert_recorder);
ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description());
}

// Test that a PMS with a bogus version number is ignored when
Expand All @@ -91,11 +83,7 @@ TEST_P(TlsConnectStreamPre13, ConnectExtendedMasterSecretStaticRSABogusCKE) {
kTlsHandshakeClientKeyExchange,
DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange)));
client_->SetPacketFilter(inspect);
auto alert_recorder = std::make_shared<TlsAlertRecorder>();
server_->SetPacketFilter(alert_recorder);
ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description());
}

// This test is stream so we can catch the bad_record_mac alert.
Expand All @@ -105,11 +93,7 @@ TEST_P(TlsConnectStreamPre13,
EnableExtendedMasterSecret();
client_->SetPacketFilter(
std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
auto alert_recorder = std::make_shared<TlsAlertRecorder>();
server_->SetPacketFilter(alert_recorder);
ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description());
}

TEST_P(TlsConnectStreamPre13,
Expand Down
25 changes: 0 additions & 25 deletions gtests/ssl_gtest/tls_filter.cc
Expand Up @@ -369,31 +369,6 @@ PacketFilter::Action TlsConversationRecorder::FilterRecord(
return KEEP;
}

PacketFilter::Action TlsAlertRecorder::FilterRecord(
const TlsRecordHeader& header, const DataBuffer& input,
DataBuffer* output) {
if (level_ == kTlsAlertFatal) { // already fatal
return KEEP;
}
if (header.content_type() != kTlsAlertType) {
return KEEP;
}

std::cerr << "Alert: " << input << std::endl;

TlsParser parser(input);
uint8_t lvl;
if (!parser.Read(&lvl)) {
return KEEP;
}
if (lvl == kTlsAlertWarning) { // not strong enough
return KEEP;
}
level_ = lvl;
(void)parser.Read(&description_);
return KEEP;
}

PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input,
DataBuffer* output) {
DataBuffer in(input);
Expand Down
18 changes: 0 additions & 18 deletions gtests/ssl_gtest/tls_filter.h
Expand Up @@ -232,24 +232,6 @@ class TlsConversationRecorder : public TlsRecordFilter {
DataBuffer& buffer_;
};

// Records an alert. If an alert has already been recorded, it won't save the
// new alert unless the old alert is a warning and the new one is fatal.
class TlsAlertRecorder : public TlsRecordFilter {
public:
TlsAlertRecorder() : level_(255), description_(255) {}

virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
DataBuffer* output);

uint8_t level() const { return level_; }
uint8_t description() const { return description_; }

private:
uint8_t level_;
uint8_t description_;
};

// Runs multiple packet filters in series.
class ChainedPacketFilter : public PacketFilter {
public:
Expand Down

0 comments on commit 092d015

Please sign in to comment.