Skip to content

Commit

Permalink
Bug 1315405 - Refactor DTLS fragment handling, r=ttaubert
Browse files Browse the repository at this point in the history
--HG--
extra : rebase_source : cdf66d5b05329f234d9483373e182d55d012f954
extra : amend_source : f577026edca15d3dff3b63dc3a21fcce8825fb4e
  • Loading branch information
martinthomson committed Nov 21, 2016
1 parent 8b1121d commit bab54a9
Show file tree
Hide file tree
Showing 6 changed files with 266 additions and 37 deletions.
1 change: 1 addition & 0 deletions gtests/ssl_gtest/manifest.mn
Expand Up @@ -24,6 +24,7 @@ CPPSRCS = \
ssl_ems_unittest.cc \
ssl_exporter_unittest.cc \
ssl_extension_unittest.cc \
ssl_fragment_unittest.cc \
ssl_fuzz_unittest.cc \
ssl_gtest.cc \
ssl_hrr_unittest.cc \
Expand Down
158 changes: 158 additions & 0 deletions gtests/ssl_gtest/ssl_fragment_unittest.cc
@@ -0,0 +1,158 @@
/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this file,
* You can obtain one at http://mozilla.org/MPL/2.0/. */

#include "secerr.h"
#include "ssl.h"
#include "sslerr.h"
#include "sslproto.h"

#include "gtest_utils.h"
#include "scoped_ptrs.h"
#include "tls_connect.h"
#include "tls_filter.h"
#include "tls_parser.h"

namespace nss_test {

// This class cuts every unencrypted handshake record into two parts.
class RecordFragmenter : public PacketFilter {
public:
RecordFragmenter() : sequence_number_(0), splitting_(true) {}

private:
class HandshakeSplitter {
public:
HandshakeSplitter(const DataBuffer& input, DataBuffer* output,
uint64_t* sequence_number)
: input_(input),
output_(output),
cursor_(0),
sequence_number_(sequence_number) {}

private:
void WriteRecord(TlsRecordFilter::RecordHeader& record_header,
DataBuffer& record_fragment) {
TlsRecordFilter::RecordHeader fragment_header(
record_header.version(), record_header.content_type(),
*sequence_number_);
++*sequence_number_;
if (::g_ssl_gtest_verbose) {
std::cerr << "Fragment: " << fragment_header << ' ' << record_fragment
<< std::endl;
}
cursor_ = fragment_header.Write(output_, cursor_, record_fragment);
}

bool SplitRecord(TlsRecordFilter::RecordHeader& record_header,
DataBuffer& record) {
TlsParser parser(record);
while (parser.remaining()) {
TlsHandshakeFilter::HandshakeHeader handshake_header;
DataBuffer handshake_body;
if (!handshake_header.Parse(&parser, record_header, &handshake_body)) {
ADD_FAILURE() << "couldn't parse handshake header";
return false;
}

DataBuffer record_fragment;
// We can't fragment handshake records that are too small.
if (handshake_body.len() < 2) {
handshake_header.Write(&record_fragment, 0U, handshake_body);
WriteRecord(record_header, record_fragment);
continue;
}

size_t cut = handshake_body.len() / 2;
handshake_header.WriteFragment(&record_fragment, 0U, handshake_body, 0U,
cut);
WriteRecord(record_header, record_fragment);

handshake_header.WriteFragment(&record_fragment, 0U, handshake_body,
cut, handshake_body.len() - cut);
WriteRecord(record_header, record_fragment);
}
return true;
}

public:
bool Split() {
TlsParser parser(input_);
while (parser.remaining()) {
TlsRecordFilter::RecordHeader header;
DataBuffer record;
if (!header.Parse(&parser, &record)) {
ADD_FAILURE() << "bad record header";
return false;
}

if (::g_ssl_gtest_verbose) {
std::cerr << "Record: " << header << ' ' << record << std::endl;
}

// Don't touch packets from a non-zero epoch. Leave these unmodified.
if ((header.sequence_number() >> 48) != 0ULL) {
cursor_ = header.Write(output_, cursor_, record);
continue;
}

// Just rewrite the sequence number (CCS only).
if (header.content_type() != kTlsHandshakeType) {
EXPECT_EQ(kTlsChangeCipherSpecType, header.content_type());
WriteRecord(header, record);
continue;
}

if (!SplitRecord(header, record)) {
return false;
}
}
return true;
}

private:
const DataBuffer& input_;
DataBuffer* output_;
size_t cursor_;
uint64_t* sequence_number_;
};

protected:
virtual PacketFilter::Action Filter(const DataBuffer& input,
DataBuffer* output) override {
if (!splitting_) {
return KEEP;
}

output->Allocate(input.len());
HandshakeSplitter splitter(input, output, &sequence_number_);
if (!splitter.Split()) {
// If splitting fails, we obviously reached encrypted packets.
// Stop splitting from that point onward.
splitting_ = false;
return KEEP;
}

return CHANGE;
}

private:
uint64_t sequence_number_;
bool splitting_;
};

TEST_P(TlsConnectDatagram, FragmentClientPackets) {
client_->SetPacketFilter(new RecordFragmenter());
Connect();
SendReceive();
}

TEST_P(TlsConnectDatagram, FragmentServerPackets) {
server_->SetPacketFilter(new RecordFragmenter());
Connect();
SendReceive();
}

} // namespace nss_test
1 change: 1 addition & 0 deletions gtests/ssl_gtest/ssl_gtest.gyp
Expand Up @@ -25,6 +25,7 @@
'ssl_exporter_unittest.cc',
'ssl_extension_unittest.cc',
'ssl_fuzz_unittest.cc',
'ssl_fragment_unittest.cc',
'ssl_gtest.cc',
'ssl_hrr_unittest.cc',
'ssl_loopback_unittest.cc',
Expand Down
50 changes: 45 additions & 5 deletions gtests/ssl_gtest/tls_filter.cc
Expand Up @@ -18,6 +18,32 @@ extern "C" {

namespace nss_test {

void TlsRecordFilter::Versioned::WriteStream(std::ostream& stream) const {
stream << (is_dtls() ? "DTLS " : "TLS ");
switch (version()) {
case 0:
stream << "(no version)";
break;
case SSL_LIBRARY_VERSION_TLS_1_0:
stream << "1.0";
break;
case SSL_LIBRARY_VERSION_DTLS_1_0_WIRE:
case SSL_LIBRARY_VERSION_TLS_1_1:
stream << (is_dtls() ? "1.0" : "1.1");
break;
case SSL_LIBRARY_VERSION_DTLS_1_2_WIRE:
case SSL_LIBRARY_VERSION_TLS_1_2:
stream << "1.2";
break;
case SSL_LIBRARY_VERSION_TLS_1_3:
stream << "1.3";
break;
default:
stream << "Invalid version: " << version();
break;
}
}

PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input,
DataBuffer* output) {
bool changed = false;
Expand All @@ -29,6 +55,7 @@ PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input,
RecordHeader header;
DataBuffer record;
if (!header.Parse(&parser, &record)) {
ADD_FAILURE() << "not a valid record";
return KEEP;
}

Expand Down Expand Up @@ -205,15 +232,28 @@ bool TlsHandshakeFilter::HandshakeHeader::Parse(
return parser->Read(body, length);
}

size_t TlsHandshakeFilter::HandshakeHeader::Write(
DataBuffer* buffer, size_t offset, const DataBuffer& body) const {
size_t TlsHandshakeFilter::HandshakeHeader::WriteFragment(
DataBuffer* buffer, size_t offset, const DataBuffer& body,
size_t fragment_offset, size_t fragment_length) const {
EXPECT_TRUE(is_dtls());
EXPECT_GE(body.len(), fragment_offset + fragment_length);
offset = buffer->Write(offset, handshake_type(), 1);
offset = buffer->Write(offset, body.len(), 3);
offset = buffer->Write(offset, message_seq_, 2);
offset = buffer->Write(offset, fragment_offset, 3);
offset = buffer->Write(offset, fragment_length, 3);
offset =
buffer->Write(offset, body.data() + fragment_offset, fragment_length);
return offset;
}

size_t TlsHandshakeFilter::HandshakeHeader::Write(
DataBuffer* buffer, size_t offset, const DataBuffer& body) const {
if (is_dtls()) {
offset = buffer->Write(offset, message_seq_, 2);
offset = buffer->Write(offset, 0U, 3); // fragment_offset
offset = buffer->Write(offset, body.len(), 3);
return WriteFragment(buffer, offset, body, 0U, body.len());
}
offset = buffer->Write(offset, handshake_type(), 1);
offset = buffer->Write(offset, body.len(), 3);
offset = buffer->Write(offset, body);
return offset;
}
Expand Down
35 changes: 35 additions & 0 deletions gtests/ssl_gtest/tls_filter.h
Expand Up @@ -35,6 +35,8 @@ class TlsRecordFilter : public PacketFilter {
bool is_dtls() const { return IsDtls(version_); }
uint16_t version() const { return version_; }

void WriteStream(std::ostream& stream) const;

protected:
uint16_t version_;
};
Expand Down Expand Up @@ -90,6 +92,36 @@ class TlsRecordFilter : public PacketFilter {
size_t count_;
};

inline std::ostream& operator<<(std::ostream& stream,
const TlsRecordFilter::Versioned v) {
v.WriteStream(stream);
return stream;
}

inline std::ostream& operator<<(std::ostream& stream,
const TlsRecordFilter::RecordHeader& hdr) {
hdr.WriteStream(stream);
stream << ' ';
switch (hdr.content_type()) {
case kTlsChangeCipherSpecType:
stream << "CCS";
break;
case kTlsAlertType:
stream << "Alert";
break;
case kTlsHandshakeType:
stream << "Handshake";
break;
case kTlsApplicationDataType:
stream << "Data";
break;
default:
stream << '<' << hdr.content_type() << '>';
break;
}
return stream << ' ' << std::hex << hdr.sequence_number() << std::dec;
}

// Abstract filter that operates on handshake messages rather than records.
// This assumes that the handshake messages are written in a block as entire
// records and that they don't span records or anything crazy like that.
Expand All @@ -106,6 +138,9 @@ class TlsHandshakeFilter : public TlsRecordFilter {
DataBuffer* body);
size_t Write(DataBuffer* buffer, size_t offset,
const DataBuffer& body) const;
size_t WriteFragment(DataBuffer* buffer, size_t offset,
const DataBuffer& body, size_t fragment_offset,
size_t fragment_length) const;

private:
// Reads the length from the record header.
Expand Down
58 changes: 26 additions & 32 deletions lib/ssl/dtlscon.c
Expand Up @@ -235,6 +235,26 @@ dtls_RetransmitDetected(sslSocket *ss)
return rv;
}

static SECStatus
dtls_HandleHandshakeMessage(sslSocket *ss, SSL3Opaque *data, PRBool last)
{

/* At this point we are advancing our state machine, so we can free our last
* flight of messages. */
dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight);
ss->ssl3.hs.recvdHighWater = -1;

/* Reset the timer to the initial value if the retry counter
* is 0, per Sec. 4.2.4.1 */
dtls_CancelTimer(ss);
if (ss->ssl3.hs.rtRetries == 0) {
ss->ssl3.hs.rtTimeoutMs = DTLS_RETRANSMIT_INITIAL_MS;
}

return ssl3_HandleHandshakeMessage(ss, data, ss->ssl3.hs.msg_len,
last);
}

/* Called only from ssl3_HandleRecord, for each (deciphered) DTLS record.
* origBuf is the decrypted ssl record content and is expected to contain
* complete handshake records
Expand Down Expand Up @@ -329,23 +349,10 @@ dtls_HandleHandshake(sslSocket *ss, sslBuffer *origBuf)
ss->ssl3.hs.msg_type = (SSL3HandshakeType)type;
ss->ssl3.hs.msg_len = message_length;

/* At this point we are advancing our state machine, so
* we can free our last flight of messages */
dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight);
ss->ssl3.hs.recvdHighWater = -1;
dtls_CancelTimer(ss);

/* Reset the timer to the initial value if the retry counter
* is 0, per Sec. 4.2.4.1 */
if (ss->ssl3.hs.rtRetries == 0) {
ss->ssl3.hs.rtTimeoutMs = DTLS_RETRANSMIT_INITIAL_MS;
}

rv = ssl3_HandleHandshakeMessage(ss, buf.buf, ss->ssl3.hs.msg_len,
rv = dtls_HandleHandshakeMessage(ss, buf.buf,
buf.len == fragment_length);
if (rv == SECFailure) {
/* Do not attempt to process rest of messages in this record */
break;
break; /* Discard the remainder of the record. */
}
} else {
if (message_seq < ss->ssl3.hs.recvMessageSeq) {
Expand Down Expand Up @@ -446,24 +453,11 @@ dtls_HandleHandshake(sslSocket *ss, sslBuffer *origBuf)

/* If we have all the bytes, then we are good to go */
if (ss->ssl3.hs.recvdHighWater == ss->ssl3.hs.msg_len) {
ss->ssl3.hs.recvdHighWater = -1;
rv = dtls_HandleHandshakeMessage(ss, ss->ssl3.hs.msg_body.buf,
buf.len == fragment_length);

rv = ssl3_HandleHandshakeMessage(
ss,
ss->ssl3.hs.msg_body.buf, ss->ssl3.hs.msg_len,
buf.len == fragment_length);
if (rv == SECFailure)
break; /* Skip rest of record */

/* At this point we are advancing our state machine, so
* we can free our last flight of messages */
dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight);
dtls_CancelTimer(ss);

/* If there have been no retries this time, reset the
* timer value to the default per Section 4.2.4.1 */
if (ss->ssl3.hs.rtRetries == 0) {
ss->ssl3.hs.rtTimeoutMs = DTLS_RETRANSMIT_INITIAL_MS;
if (rv == SECFailure) {
break; /* Discard the rest of the record. */
}
}
}
Expand Down

0 comments on commit bab54a9

Please sign in to comment.