Skip to content

Commit

Permalink
Bug 1235366 - Test case for writing between receipt of CCS and Finish…
Browse files Browse the repository at this point in the history
…ed, r=ekr

--HG--
extra : rebase_source : ed7e3b4064ff8a734c90606b5c8f813da326eba8
  • Loading branch information
martinthomson committed Feb 3, 2016
1 parent 7bbfadb commit 6271d25
Show file tree
Hide file tree
Showing 12 changed files with 542 additions and 303 deletions.
13 changes: 8 additions & 5 deletions external_tests/ssl_gtest/databuffer.h
Expand Up @@ -64,7 +64,8 @@ class DataBuffer {
}

// Write will do a new allocation and expand the size of the buffer if needed.
void Write(size_t index, const uint8_t* val, size_t count) {
// Returns the offset of the end of the write.
size_t Write(size_t index, const uint8_t* val, size_t count) {
if (index + count > len_) {
size_t newlen = index + count;
uint8_t* tmp = new uint8_t[newlen]; // Always > 0.
Expand All @@ -79,18 +80,20 @@ class DataBuffer {
}
memcpy(static_cast<void*>(data_ + index),
static_cast<const void*>(val), count);
return index + count;
}

void Write(size_t index, const DataBuffer& buf) {
Write(index, buf.data(), buf.len());
size_t Write(size_t index, const DataBuffer& buf) {
return Write(index, buf.data(), buf.len());
}

// Write an integer, also performing host-to-network order conversion.
void Write(size_t index, uint32_t val, size_t count) {
// Returns the offset of the end of the write.
size_t Write(size_t index, uint32_t val, size_t count) {
assert(count <= sizeof(uint32_t));
uint32_t nvalue = htonl(val);
auto* addr = reinterpret_cast<const uint8_t*>(&nvalue);
Write(index, addr + sizeof(uint32_t) - count, count);
return Write(index, addr + sizeof(uint32_t) - count, count);
}

// This can't use the same trick as Write(), since we might be reading from a
Expand Down
199 changes: 101 additions & 98 deletions external_tests/ssl_gtest/ssl_extension_unittest.cc
Expand Up @@ -17,146 +17,151 @@ namespace nss_test {

class TlsExtensionFilter : public TlsHandshakeFilter {
protected:
virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
const DataBuffer& input, DataBuffer* output) {
if (handshake_type == kTlsHandshakeClientHello) {
virtual PacketFilter::Action FilterHandshake(
const HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) {
if (header.handshake_type() == kTlsHandshakeClientHello) {
TlsParser parser(input);
if (!FindClientHelloExtensions(parser, version)) {
return false;
if (!FindClientHelloExtensions(&parser, header)) {
return KEEP;
}
return FilterExtensions(parser, input, output);
return FilterExtensions(&parser, input, output);
}
if (handshake_type == kTlsHandshakeServerHello) {
if (header.handshake_type() == kTlsHandshakeServerHello) {
TlsParser parser(input);
if (!FindServerHelloExtensions(parser, version)) {
return false;
if (!FindServerHelloExtensions(&parser, header.version())) {
return KEEP;
}
return FilterExtensions(parser, input, output);
return FilterExtensions(&parser, input, output);
}
return false;
return KEEP;
}

virtual bool FilterExtension(uint16_t extension_type,
const DataBuffer& input, DataBuffer* output) = 0;
virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer& input,
DataBuffer* output) = 0;

public:
static bool FindClientHelloExtensions(TlsParser& parser, uint16_t version) {
if (!parser.Skip(2 + 32)) { // version + random
static bool FindClientHelloExtensions(TlsParser* parser, const Versioned& header) {
if (!parser->Skip(2 + 32)) { // version + random
return false;
}
if (!parser.SkipVariable(1)) { // session ID
if (!parser->SkipVariable(1)) { // session ID
return false;
}
if (IsDtls(version) && !parser.SkipVariable(1)) { // DTLS cookie
if (header.is_dtls() && !parser->SkipVariable(1)) { // DTLS cookie
return false;
}
if (!parser.SkipVariable(2)) { // cipher suites
if (!parser->SkipVariable(2)) { // cipher suites
return false;
}
if (!parser.SkipVariable(1)) { // compression methods
if (!parser->SkipVariable(1)) { // compression methods
return false;
}
return true;
}

static bool FindServerHelloExtensions(TlsParser& parser, uint16_t version) {
if (!parser.Skip(2 + 32)) { // version + random
static bool FindServerHelloExtensions(TlsParser* parser, uint16_t version) {
if (!parser->Skip(2 + 32)) { // version + random
return false;
}
if (!parser.SkipVariable(1)) { // session ID
if (!parser->SkipVariable(1)) { // session ID
return false;
}
if (!parser.Skip(2)) { // cipher suite
if (!parser->Skip(2)) { // cipher suite
return false;
}
if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) {
if (!parser.Skip(1)) { // compression method
if (!parser->Skip(1)) { // compression method
return false;
}
}
return true;
}

private:
bool FilterExtensions(TlsParser& parser,
const DataBuffer& input, DataBuffer* output) {
size_t length_offset = parser.consumed();
PacketFilter::Action FilterExtensions(TlsParser* parser,
const DataBuffer& input,
DataBuffer* output) {
size_t length_offset = parser->consumed();
uint32_t all_extensions;
if (!parser.Read(&all_extensions, 2)) {
return false; // no extensions, odd but OK
if (!parser->Read(&all_extensions, 2)) {
return KEEP; // no extensions, odd but OK
}
if (all_extensions != parser.remaining()) {
return false; // malformed
if (all_extensions != parser->remaining()) {
return KEEP; // malformed
}

bool changed = false;

// Write out the start of the message.
output->Allocate(input.len());
output->Write(0, input.data(), parser.consumed());
size_t output_offset = parser.consumed();
size_t offset = output->Write(0, input.data(), parser->consumed());

while (parser.remaining()) {
while (parser->remaining()) {
uint32_t extension_type;
if (!parser.Read(&extension_type, 2)) {
return false; // malformed
if (!parser->Read(&extension_type, 2)) {
return KEEP; // malformed
}

// Copy extension type.
output->Write(output_offset, extension_type, 2);

DataBuffer extension;
if (!parser.ReadVariable(&extension, 2)) {
return false; // malformed
if (!parser->ReadVariable(&extension, 2)) {
return KEEP; // malformed
}

DataBuffer filtered;
PacketFilter::Action action = FilterExtension(extension_type, extension,
&filtered);
if (action == DROP) {
changed = true;
std::cerr << "extension drop: " << extension << std::endl;
continue;
}

const DataBuffer* source = &extension;
if (action == CHANGE) {
EXPECT_GT(0x10000, filtered.len());
changed = true;
std::cerr << "extension old: " << extension << std::endl;
std::cerr << "extension new: " << filtered << std::endl;
source = &filtered;
}
output_offset = ApplyFilter(static_cast<uint16_t>(extension_type), extension,
output, output_offset + 2, &changed);

// Write out extension.
offset = output->Write(offset, extension_type, 2);
offset = output->Write(offset, source->len(), 2);
offset = output->Write(offset, *source);
}
output->Truncate(output_offset);
output->Truncate(offset);

if (changed) {
size_t newlen = output->len() - length_offset - 2;
EXPECT_GT(0x10000, newlen);
if (newlen >= 0x10000) {
return false; // bad: size increased too much
return KEEP; // bad: size increased too much
}
output->Write(length_offset, newlen, 2);
return CHANGE;
}
return changed;
}

size_t ApplyFilter(uint16_t extension_type, const DataBuffer& extension,
DataBuffer* output, size_t offset, bool* changed) {
const DataBuffer* source = &extension;
DataBuffer filtered;
if (FilterExtension(extension_type, extension, &filtered) &&
filtered.len() < 0x10000) {
*changed = true;
std::cerr << "extension old: " << extension << std::endl;
std::cerr << "extension new: " << filtered << std::endl;
source = &filtered;
}

output->Write(offset, source->len(), 2);
output->Write(offset + 2, *source);
return offset + 2 + source->len();
return KEEP;
}
};

class TlsExtensionTruncator : public TlsExtensionFilter {
public:
TlsExtensionTruncator(uint16_t extension, size_t length)
: extension_(extension), length_(length) {}
virtual bool FilterExtension(uint16_t extension_type,
const DataBuffer& input, DataBuffer* output) {
virtual PacketFilter::Action FilterExtension(
uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
if (extension_type != extension_) {
return false;
return KEEP;
}
if (input.len() <= length_) {
return false;
return KEEP;
}

output->Assign(input.data(), length_);
return true;
return CHANGE;
}
private:
uint16_t extension_;
Expand All @@ -167,15 +172,15 @@ class TlsExtensionDamager : public TlsExtensionFilter {
public:
TlsExtensionDamager(uint16_t extension, size_t index)
: extension_(extension), index_(index) {}
virtual bool FilterExtension(uint16_t extension_type,
const DataBuffer& input, DataBuffer* output) {
virtual PacketFilter::Action FilterExtension(
uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
if (extension_type != extension_) {
return false;
return KEEP;
}

*output = input;
output->data()[index_] += 73; // Increment selected for maximum damage
return true;
return CHANGE;
}
private:
uint16_t extension_;
Expand All @@ -186,14 +191,14 @@ class TlsExtensionReplacer : public TlsExtensionFilter {
public:
TlsExtensionReplacer(uint16_t extension, const DataBuffer& data)
: extension_(extension), data_(data) {}
virtual bool FilterExtension(uint16_t extension_type,
const DataBuffer& input, DataBuffer* output) {
virtual PacketFilter::Action FilterExtension(
uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
if (extension_type != extension_) {
return false;
return KEEP;
}

*output = data_;
return true;
return CHANGE;
}
private:
const uint16_t extension_;
Expand All @@ -205,36 +210,31 @@ class TlsExtensionInjector : public TlsHandshakeFilter {
TlsExtensionInjector(uint16_t ext, DataBuffer& data)
: extension_(ext), data_(data) {}

virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
const DataBuffer& input, DataBuffer* output) {
virtual PacketFilter::Action FilterHandshake(
const HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) {
size_t offset;
if (handshake_type == kTlsHandshakeClientHello) {
if (header.handshake_type() == kTlsHandshakeClientHello) {
TlsParser parser(input);
if (!TlsExtensionFilter::FindClientHelloExtensions(parser, version)) {
return false;
if (!TlsExtensionFilter::FindClientHelloExtensions(&parser, header)) {
return KEEP;
}
offset = parser.consumed();
} else if (handshake_type == kTlsHandshakeServerHello) {
} else if (header.handshake_type() == kTlsHandshakeServerHello) {
TlsParser parser(input);
if (!TlsExtensionFilter::FindServerHelloExtensions(parser, version)) {
return false;
if (!TlsExtensionFilter::FindServerHelloExtensions(&parser, header.version())) {
return KEEP;
}
offset = parser.consumed();
} else {
return false;
return KEEP;
}

*output = input;

std::cerr << "Pre:" << input << std::endl;
std::cerr << "Lof:" << offset << std::endl;

// Increase the size of the extensions.
uint16_t* len_addr = reinterpret_cast<uint16_t*>(output->data() + offset);
std::cerr << "L-p:" << ntohs(*len_addr) << std::endl;
*len_addr = htons(ntohs(*len_addr) + data_.len() + 4);
std::cerr << "L-i:" << ntohs(*len_addr) << std::endl;


// Insert the extension type and length.
DataBuffer type_length;
Expand All @@ -246,8 +246,7 @@ class TlsExtensionInjector : public TlsHandshakeFilter {
// Insert the payload.
output->Splice(data_, offset + 6);

std::cerr << "Aft:" << *output << std::endl;
return true;
return CHANGE;
}

private:
Expand All @@ -260,12 +259,12 @@ class TlsExtensionCapture : public TlsExtensionFilter {
TlsExtensionCapture(uint16_t ext)
: extension_(ext), data_() {}

virtual bool FilterExtension(uint16_t extension_type,
const DataBuffer& input, DataBuffer* output) {
virtual PacketFilter::Action FilterExtension(
uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
if (extension_type == extension_) {
data_.Assign(input);
}
return false;
return KEEP;
}

const DataBuffer& extension() const { return data_; }
Expand Down Expand Up @@ -628,10 +627,14 @@ class SignedCertificateTimestampsExtractor {
public:
SignedCertificateTimestampsExtractor(TlsAgent& client) {
client.SetAuthCertificateCallback(
[&](TlsAgent& agent, PRBool checksig, PRBool isServer) {
[&](TlsAgent& agent, PRBool checksig, PRBool isServer) -> SECStatus {
const SECItem *scts = SSL_PeerSignedCertTimestamps(agent.ssl_fd());
ASSERT_TRUE(scts);
EXPECT_TRUE(scts);
if (!scts) {
return SECFailure;
}
auth_timestamps_.reset(new DataBuffer(scts->data, scts->len));
return SECSuccess;
}
);
client.SetHandshakeCallback(
Expand Down

0 comments on commit 6271d25

Please sign in to comment.