Skip to content

Commit

Permalink
Bug 1385199 - Improve decrypting filter, r=ekr
Browse files Browse the repository at this point in the history
--HG--
branch : NSS_TLS13_DRAFT19_BRANCH
extra : rebase_source : 4b3d406fbcb002791f4118520661ae66758bdcbe
extra : amend_source : 7a0eaf2f0bdb407f73f8263855c07413c865f6a6
  • Loading branch information
martinthomson committed Jul 28, 2017
1 parent 50c65c0 commit cb58afa
Show file tree
Hide file tree
Showing 11 changed files with 172 additions and 24 deletions.
1 change: 0 additions & 1 deletion gtests/ssl_gtest/ssl_0rtt_unittest.cc
Expand Up @@ -119,7 +119,6 @@ class TlsZeroRttReplayTest : public TlsConnectTls13 {
auto early_data_ext =
std::make_shared<TlsExtensionCapture>(ssl_tls13_early_data_xtn);
server_->SetPacketFilter(early_data_ext);
early_data_ext->EnableDecryption();

// Finally, replay the ClientHello and force the server to consume it. Stop
// after the server sends its first flight; the client will not be able to
Expand Down
1 change: 0 additions & 1 deletion gtests/ssl_gtest/ssl_custext_unittest.cc
Expand Up @@ -329,7 +329,6 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionServerToClientEE) {
auto capture = std::make_shared<TlsExtensionCapture>(extension_code);
capture->SetHandshakeTypes({kTlsHandshakeEncryptedExtensions});
server_->SetTlsRecordFilter(capture);
capture->EnableDecryption();

Connect();

Expand Down
2 changes: 0 additions & 2 deletions gtests/ssl_gtest/ssl_damage_unittest.cc
Expand Up @@ -79,7 +79,6 @@ TEST_P(TlsConnectTls13, DamageServerSignature) {
auto filter =
std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify);
server_->SetTlsRecordFilter(filter);
filter->EnableDecryption();
client_->ExpectSendAlert(kTlsAlertDecryptError);
// The server can't read the client's alert, so it also sends an alert.
if (variant_ == ssl_variant_stream) {
Expand All @@ -100,7 +99,6 @@ TEST_P(TlsConnectGeneric, DamageClientSignature) {
std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify);
client_->SetTlsRecordFilter(filter);
server_->ExpectSendAlert(kTlsAlertDecryptError);
filter->EnableDecryption();
// Do these handshakes by hand to avoid race condition on
// the client processing the server's alert.
client_->StartConnect();
Expand Down
1 change: 0 additions & 1 deletion gtests/ssl_gtest/ssl_extension_unittest.cc
Expand Up @@ -1009,7 +1009,6 @@ class TlsBogusExtensionTest : public TlsConnectTestBase,
std::make_shared<TlsExtensionAppender>(message, extension, empty);
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
server_->SetTlsRecordFilter(filter);
filter->EnableDecryption();
} else {
server_->SetPacketFilter(filter);
}
Expand Down
2 changes: 1 addition & 1 deletion gtests/ssl_gtest/ssl_fragment_unittest.cc
Expand Up @@ -82,7 +82,7 @@ class RecordFragmenter : public PacketFilter {
while (parser.remaining()) {
TlsRecordHeader header;
DataBuffer record;
if (!header.Parse(&parser, &record)) {
if (!header.Parse(0, &parser, &record)) {
ADD_FAILURE() << "bad record header";
return false;
}
Expand Down
99 changes: 99 additions & 0 deletions gtests/ssl_gtest/ssl_loopback_unittest.cc
Expand Up @@ -189,6 +189,105 @@ TEST_P(TlsConnectGeneric, ConnectSendReceive) {
SendReceive();
}

class SaveTlsRecord : public TlsRecordFilter {
public:
SaveTlsRecord(size_t index) : index_(index), count_(0), contents_() {}

const DataBuffer& contents() const { return contents_; }

protected:
PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& data,
DataBuffer* changed) override {
if (count_++ == index_) {
contents_ = data;
}
return KEEP;
}

private:
const size_t index_;
size_t count_;
DataBuffer contents_;
};

// Check that decrypting filters work and can read any record.
// This test (currently) only works in TLS 1.3 where we can decrypt.
TEST_F(TlsConnectStreamTls13, DecryptRecordClient) {
EnsureTlsSetup();
// 0 = ClientHello, 1 = Finished, 2 = SendReceive, 3 = SendBuffer
auto saved = std::make_shared<SaveTlsRecord>(3);
client_->SetTlsRecordFilter(saved);
Connect();
SendReceive();

static const uint8_t data[] = {0xde, 0xad, 0xdc};
DataBuffer buf(data, sizeof(data));
client_->SendBuffer(buf);
EXPECT_EQ(buf, saved->contents());
}

TEST_F(TlsConnectStreamTls13, DecryptRecordServer) {
EnsureTlsSetup();
// Disable tickets so that we are sure to not get NewSessionTicket.
EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
SSL_ENABLE_SESSION_TICKETS, PR_FALSE));
// 0 = ServerHello, 1 = other handshake, 2 = SendReceive, 3 = SendBuffer
auto saved = std::make_shared<SaveTlsRecord>(3);
server_->SetTlsRecordFilter(saved);
Connect();
SendReceive();

static const uint8_t data[] = {0xde, 0xad, 0xd5};
DataBuffer buf(data, sizeof(data));
server_->SendBuffer(buf);
EXPECT_EQ(buf, saved->contents());
}

class DropTlsRecord : public TlsRecordFilter {
public:
DropTlsRecord(size_t index) : index_(index), count_(0) {}

protected:
PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& data,
DataBuffer* changed) override {
if (count_++ == index_) {
return DROP;
}
return KEEP;
}

private:
const size_t index_;
size_t count_;
};

// Test that decrypting filters work correctly and are able to drop records.
TEST_F(TlsConnectStreamTls13, DropRecordServer) {
EnsureTlsSetup();
// Disable session tickets so that the server doesn't send an extra record.
EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
SSL_ENABLE_SESSION_TICKETS, PR_FALSE));

// 0 = ServerHello, 1 = other handshake, 2 = first write
server_->SetTlsRecordFilter(std::make_shared<DropTlsRecord>(2));
Connect();
server_->SendData(23, 23); // This should be dropped, so it won't be counted.
server_->ResetSentBytes();
SendReceive();
}

TEST_F(TlsConnectStreamTls13, DropRecordClient) {
EnsureTlsSetup();
// 0 = ClientHello, 1 = Finished, 2 = first write
client_->SetTlsRecordFilter(std::make_shared<DropTlsRecord>(2));
Connect();
client_->SendData(26, 26); // This should be dropped, so it won't be counted.
client_->ResetSentBytes();
SendReceive();
}

// The next two tests takes advantage of the fact that we
// automatically read the first 1024 bytes, so if
// we provide 1200 bytes, they overrun the read buffer
Expand Down
2 changes: 0 additions & 2 deletions gtests/ssl_gtest/ssl_record_unittest.cc
Expand Up @@ -137,7 +137,6 @@ TEST_F(TlsConnectStreamTls13, LargeRecord) {
const size_t record_limit = 16384;
auto replacer = std::make_shared<RecordReplacer>(record_limit);
client_->SetTlsRecordFilter(replacer);
replacer->EnableDecryption();
Connect();

replacer->Enable();
Expand All @@ -152,7 +151,6 @@ TEST_F(TlsConnectStreamTls13, TooLargeRecord) {
const size_t record_limit = 16384;
auto replacer = std::make_shared<RecordReplacer>(record_limit + 1);
client_->SetTlsRecordFilter(replacer);
replacer->EnableDecryption();
Connect();

replacer->Enable();
Expand Down
2 changes: 0 additions & 2 deletions gtests/ssl_gtest/ssl_skip_unittest.cc
Expand Up @@ -101,7 +101,6 @@ class Tls13SkipTest : public TlsConnectTestBase,
void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
EnsureTlsSetup();
server_->SetTlsRecordFilter(filter);
filter->EnableDecryption();
client_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
if (variant_ == ssl_variant_stream) {
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
Expand All @@ -120,7 +119,6 @@ class Tls13SkipTest : public TlsConnectTestBase,
void ClientSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
EnsureTlsSetup();
client_->SetTlsRecordFilter(filter);
filter->EnableDecryption();
server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
ConnectExpectFailOneSide(TlsAgent::SERVER);

Expand Down
2 changes: 2 additions & 0 deletions gtests/ssl_gtest/tls_agent.h
Expand Up @@ -81,9 +81,11 @@ class TlsAgent : public PollTarget {
adapter_->SetPeer(peer->adapter_);
}

// Set a filter that can access plaintext (TLS 1.3 only).
void SetTlsRecordFilter(std::shared_ptr<TlsRecordFilter> filter) {
filter->SetAgent(this);
adapter_->SetPacketFilter(filter);
filter->EnableDecryption();
}

void SetPacketFilter(std::shared_ptr<PacketFilter> filter) {
Expand Down
68 changes: 56 additions & 12 deletions gtests/ssl_gtest/tls_filter.cc
Expand Up @@ -57,12 +57,17 @@ void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending,
PRBool isServer = self->agent()->role() == TlsAgent::SERVER;

if (g_ssl_gtest_verbose) {
std::cerr << "Cipher spec changed. Role="
<< (isServer ? "server" : "client")
<< " direction=" << (sending ? "send" : "receive") << std::endl;
std::cerr << (isServer ? "server" : "client") << ": "
<< (sending ? "send" : "receive")
<< " cipher spec changed: " << newSpec->phase << std::endl;
}
if (!sending) {
return;
}
if (!sending) return;

self->in_sequence_number_ = 0;
self->out_sequence_number_ = 0;
self->dropped_record_ = false;
self->cipher_spec_.reset(new TlsCipherSpec());
bool ret =
self->cipher_spec_->Init(SSLInt_CipherSpecToAlgorithm(isServer, newSpec),
Expand All @@ -83,11 +88,23 @@ PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input,
TlsRecordHeader header;
DataBuffer record;

if (!header.Parse(&parser, &record)) {
if (!header.Parse(in_sequence_number_, &parser, &record)) {
ADD_FAILURE() << "not a valid record";
return KEEP;
}

// Track the sequence number, which is necessary for stream mode (the
// sequence number is in the header for datagram).
//
// This isn't perfectly robust. If there is a change from an active cipher
// spec to another active cipher spec (KeyUpdate for instance) AND writes
// are consolidated across that change AND packets were dropped from the
// older epoch, we will not correctly re-encrypt records in the old epoch to
// update their sequence numbers.
if (cipher_spec_ && header.content_type() == kTlsApplicationDataType) {
++in_sequence_number_;
}

if (FilterRecord(header, record, &offset, output) != KEEP) {
changed = true;
} else {
Expand Down Expand Up @@ -120,30 +137,49 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(
header.sequence_number()};

PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered);
// In stream mode, even if something doesn't change we need to re-encrypt if
// previous packets were dropped.
if (action == KEEP) {
return KEEP;
if (header.is_dtls() || !dropped_record_) {
return KEEP;
}
filtered = plaintext;
}

if (action == DROP) {
std::cerr << "record drop: " << record << std::endl;
dropped_record_ = true;
return DROP;
}

EXPECT_GT(0x10000U, filtered.len());
std::cerr << "record old: " << plaintext << std::endl;
std::cerr << "record new: " << filtered << std::endl;
if (action != KEEP) {
std::cerr << "record old: " << plaintext << std::endl;
std::cerr << "record new: " << filtered << std::endl;
}

uint64_t seq_num;
if (header.is_dtls() || !cipher_spec_ ||
header.content_type() != kTlsApplicationDataType) {
seq_num = header.sequence_number();
} else {
seq_num = out_sequence_number_++;
}
TlsRecordHeader out_header = {header.version(), header.content_type(),
seq_num};

DataBuffer ciphertext;
bool rv = Protect(header, inner_content_type, filtered, &ciphertext);
bool rv = Protect(out_header, inner_content_type, filtered, &ciphertext);
EXPECT_TRUE(rv);
if (!rv) {
return KEEP;
}
*offset = header.Write(output, *offset, ciphertext);
*offset = out_header.Write(output, *offset, ciphertext);
return CHANGE;
}

bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
bool TlsRecordHeader::Parse(uint64_t sequence_number, TlsParser* parser,
DataBuffer* body) {
if (!parser->Read(&content_type_)) {
return false;
}
Expand All @@ -154,7 +190,7 @@ bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
}
version_ = version;

sequence_number_ = 0;
// If this is DTLS, overwrite the sequence number.
if (IsDtls(version)) {
uint32_t tmp;
if (!parser->Read(&tmp, 4)) {
Expand All @@ -165,6 +201,8 @@ bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
return false;
}
sequence_number_ |= static_cast<uint64_t>(tmp);
} else {
sequence_number_ = sequence_number;
}
return parser->ReadVariable(body, 2);
}
Expand Down Expand Up @@ -193,6 +231,9 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header,
return true;
}

if (g_ssl_gtest_verbose) {
std::cerr << "unprotect: " << header.sequence_number() << std::endl;
}
if (!cipher_spec_->Unprotect(header, ciphertext, plaintext)) return false;

size_t len = plaintext->len();
Expand All @@ -218,6 +259,9 @@ bool TlsRecordFilter::Protect(const TlsRecordHeader& header,
*ciphertext = plaintext;
return true;
}
if (g_ssl_gtest_verbose) {
std::cerr << "protect: " << header.sequence_number() << std::endl;
}
DataBuffer padded = plaintext;
padded.Write(padded.len(), inner_content_type, 1);
return cipher_spec_->Protect(header, padded, ciphertext);
Expand Down
16 changes: 14 additions & 2 deletions gtests/ssl_gtest/tls_filter.h
Expand Up @@ -53,7 +53,7 @@ class TlsRecordHeader : public TlsVersioned {
size_t header_length() const { return is_dtls() ? 11 : 3; }

// Parse the header; return true if successful; body in an outparam if OK.
bool Parse(TlsParser* parser, DataBuffer* body);
bool Parse(uint64_t sequence_number, TlsParser* parser, DataBuffer* body);
// Write the header and body to a buffer at the given offset.
// Return the offset of the end of the write.
size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const;
Expand All @@ -66,7 +66,13 @@ class TlsRecordHeader : public TlsVersioned {
// Abstract filter that operates on entire (D)TLS records.
class TlsRecordFilter : public PacketFilter {
public:
TlsRecordFilter() : agent_(nullptr), count_(0), cipher_spec_() {}
TlsRecordFilter()
: agent_(nullptr),
count_(0),
cipher_spec_(),
dropped_record_(false),
in_sequence_number_(0),
out_sequence_number_(0) {}

void SetAgent(const TlsAgent* agent) { agent_ = agent; }
const TlsAgent* agent() const { return agent_; }
Expand Down Expand Up @@ -115,6 +121,12 @@ class TlsRecordFilter : public PacketFilter {
const TlsAgent* agent_;
size_t count_;
std::unique_ptr<TlsCipherSpec> cipher_spec_;
// Whether we dropped a record since the cipher spec changed.
bool dropped_record_;
// The sequence number we use for reading records as they are written.
uint64_t in_sequence_number_;
// The sequence number we use for writing modified records.
uint64_t out_sequence_number_;
};

inline std::ostream& operator<<(std::ostream& stream, TlsVersioned v) {
Expand Down

0 comments on commit cb58afa

Please sign in to comment.