Skip to content

Commit

Permalink
Bug 1320962 - Add decryption and reencryption to TLS 1.3 gtests. r=mt
Browse files Browse the repository at this point in the history
  • Loading branch information
ekr committed Nov 29, 2016
1 parent 622d905 commit b8889f5
Show file tree
Hide file tree
Showing 19 changed files with 584 additions and 110 deletions.
34 changes: 34 additions & 0 deletions gtests/ssl_gtest/libssl_internals.c
Expand Up @@ -313,3 +313,37 @@ SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group) {

return groupDef->keaType;
}

SECStatus SSLInt_SetCipherSpecChangeFunc(PRFileDesc *fd,
sslCipherSpecChangedFunc func,
void *arg) {
sslSocket *ss;

ss = ssl_FindSocket(fd);
if (!ss) {
return SECFailure;
}

ss->ssl3.changedCipherSpecFunc = func;
ss->ssl3.changedCipherSpecArg = arg;

return SECSuccess;
}

static ssl3KeyMaterial *GetKeyingMaterial(PRBool isServer,
ssl3CipherSpec *spec) {
return isServer ? &spec->server : &spec->client;
}

PK11SymKey *SSLInt_CipherSpecToKey(PRBool isServer, ssl3CipherSpec *spec) {
return GetKeyingMaterial(isServer, spec)->write_key;
}

SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(PRBool isServer,
ssl3CipherSpec *spec) {
return spec->cipher_def->calg;
}

unsigned char *SSLInt_CipherSpecToIv(PRBool isServer, ssl3CipherSpec *spec) {
return GetKeyingMaterial(isServer, spec)->write_iv;
}
13 changes: 13 additions & 0 deletions gtests/ssl_gtest/libssl_internals.h
Expand Up @@ -38,4 +38,17 @@ SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to);
SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra);
SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group);

typedef struct ssl3CipherSpecStr ssl3CipherSpec;

typedef void (*sslCipherSpecChangedFunc)(void *arg, PRBool sending,
ssl3CipherSpec *newSpec);

SECStatus SSLInt_SetCipherSpecChangeFunc(PRFileDesc *fd,
sslCipherSpecChangedFunc func,
void *arg);
PK11SymKey *SSLInt_CipherSpecToKey(PRBool isServer, ssl3CipherSpec *spec);
SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(PRBool isServer,
ssl3CipherSpec *spec);
unsigned char *SSLInt_CipherSpecToIv(PRBool isServer, ssl3CipherSpec *spec);

#endif // ndef libssl_internals_h_
1 change: 1 addition & 0 deletions gtests/ssl_gtest/manifest.mn
Expand Up @@ -41,6 +41,7 @@ CPPSRCS = \
tls_hkdf_unittest.cc \
tls_filter.cc \
tls_parser.cc \
tls_protect.cc \
$(NULL)

INCLUDES += -I$(CORE_DEPTH)/gtests/google_test/gtest/include \
Expand Down
6 changes: 3 additions & 3 deletions gtests/ssl_gtest/ssl_auth_unittest.cc
Expand Up @@ -289,7 +289,7 @@ class BeforeFinished : public TlsRecordFilter {
state_(BEFORE_CCS) {}

protected:
virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& body,
DataBuffer* out) {
switch (state_) {
Expand Down Expand Up @@ -507,7 +507,7 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) {
EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());

// Remove this before closing or the close_notify alert will trigger it.
client_->SetPacketFilter(nullptr);
client_->DeletePacketFilter();
}

// TLS 1.3 handles a delayed AuthComplete callback differently since the
Expand All @@ -528,7 +528,7 @@ TEST_P(TlsConnectTls13, AuthCompleteDelayed) {
EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());

// This should allow the handshake to complete now.
client_->SetPacketFilter(nullptr);
client_->DeletePacketFilter();
EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
client_->Handshake(); // Send Finished
server_->Handshake(); // Transition to connected and send NewSessionTicket
Expand Down
13 changes: 6 additions & 7 deletions gtests/ssl_gtest/ssl_fragment_unittest.cc
Expand Up @@ -33,11 +33,11 @@ class RecordFragmenter : public PacketFilter {
sequence_number_(sequence_number) {}

private:
void WriteRecord(TlsRecordFilter::RecordHeader& record_header,
void WriteRecord(TlsRecordHeader& record_header,
DataBuffer& record_fragment) {
TlsRecordFilter::RecordHeader fragment_header(
record_header.version(), record_header.content_type(),
*sequence_number_);
TlsRecordHeader fragment_header(record_header.version(),
record_header.content_type(),
*sequence_number_);
++*sequence_number_;
if (::g_ssl_gtest_verbose) {
std::cerr << "Fragment: " << fragment_header << ' ' << record_fragment
Expand All @@ -46,8 +46,7 @@ class RecordFragmenter : public PacketFilter {
cursor_ = fragment_header.Write(output_, cursor_, record_fragment);
}

bool SplitRecord(TlsRecordFilter::RecordHeader& record_header,
DataBuffer& record) {
bool SplitRecord(TlsRecordHeader& record_header, DataBuffer& record) {
TlsParser parser(record);
while (parser.remaining()) {
TlsHandshakeFilter::HandshakeHeader handshake_header;
Expand Down Expand Up @@ -81,7 +80,7 @@ class RecordFragmenter : public PacketFilter {
bool Split() {
TlsParser parser(input_);
while (parser.remaining()) {
TlsRecordFilter::RecordHeader header;
TlsRecordHeader header;
DataBuffer record;
if (!header.Parse(&parser, &record)) {
ADD_FAILURE() << "bad record header";
Expand Down
6 changes: 3 additions & 3 deletions gtests/ssl_gtest/ssl_fuzz_unittest.cc
Expand Up @@ -23,7 +23,7 @@ class TlsApplicationDataRecorder : public TlsRecordFilter {
public:
TlsApplicationDataRecorder() : buffer_() {}

virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
DataBuffer* output) {
if (header.content_type() == kTlsApplicationDataType) {
Expand Down Expand Up @@ -130,8 +130,8 @@ TEST_P(TlsConnectGeneric, Fuzz_DeterministicTranscript) {
Connect();

// Ensure the filters go away before |buffer| does.
client_->SetPacketFilter(nullptr);
server_->SetPacketFilter(nullptr);
client_->DeletePacketFilter();
server_->DeletePacketFilter();

if (last.len() > 0) {
EXPECT_EQ(last, buffer);
Expand Down
3 changes: 2 additions & 1 deletion gtests/ssl_gtest/ssl_gtest.gyp
Expand Up @@ -40,7 +40,8 @@
'tls_connect.cc',
'tls_filter.cc',
'tls_hkdf_unittest.cc',
'tls_parser.cc'
'tls_parser.cc',
'tls_protect.cc'
],
'dependencies': [
'<(DEPTH)/exports.gyp:nss_exports',
Expand Down
9 changes: 4 additions & 5 deletions gtests/ssl_gtest/ssl_loopback_unittest.cc
Expand Up @@ -161,16 +161,15 @@ TEST_P(TlsConnectDatagram, TestDtlsHolddownExpiry) {
class TlsPreCCSHeaderInjector : public TlsRecordFilter {
public:
TlsPreCCSHeaderInjector() {}
virtual PacketFilter::Action FilterRecord(const RecordHeader& record_header,
const DataBuffer& input,
size_t* offset,
DataBuffer* output) override {
virtual PacketFilter::Action FilterRecord(
const TlsRecordHeader& record_header, const DataBuffer& input,
size_t* offset, DataBuffer* output) override {
if (record_header.content_type() != kTlsChangeCipherSpecType) return KEEP;

std::cerr << "Injecting Finished header before CCS\n";
const uint8_t hhdr[] = {kTlsHandshakeFinished, 0x00, 0x00, 0x0c};
DataBuffer hhdr_buf(hhdr, sizeof(hhdr));
RecordHeader nhdr(record_header.version(), kTlsHandshakeType, 0);
TlsRecordHeader nhdr(record_header.version(), kTlsHandshakeType, 0);
*offset = nhdr.Write(output, *offset, hhdr_buf);
*offset = record_header.Write(output, *offset, input);
return CHANGE;
Expand Down
1 change: 1 addition & 0 deletions gtests/ssl_gtest/ssl_resumption_unittest.cc
Expand Up @@ -21,6 +21,7 @@ extern "C" {
#include "tls_connect.h"
#include "tls_filter.h"
#include "tls_parser.h"
#include "tls_protect.h"

namespace nss_test {

Expand Down
72 changes: 68 additions & 4 deletions gtests/ssl_gtest/ssl_skip_unittest.cc
Expand Up @@ -28,9 +28,9 @@ class TlsHandshakeSkipFilter : public TlsRecordFilter {
protected:
// Takes a record; if it is a handshake record, it removes the first handshake
// message that is of handshake_type_ type.
virtual PacketFilter::Action FilterRecord(const RecordHeader& record_header,
const DataBuffer& input,
DataBuffer* output) {
virtual PacketFilter::Action FilterRecord(
const TlsRecordHeader& record_header, const DataBuffer& input,
DataBuffer* output) {
if (record_header.content_type() != kTlsHandshakeType) {
return KEEP;
}
Expand Down Expand Up @@ -98,6 +98,40 @@ class TlsSkipTest
}
};

class Tls13SkipTest : public TlsConnectTestBase,
public ::testing::WithParamInterface<std::string> {
protected:
Tls13SkipTest()
: TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}

void ServerSkipTest(TlsRecordFilter* filter, int32_t error) {
EnsureTlsSetup();
server_->SetPacketFilter(filter);
filter->EnableDecryption();
if (mode_ == STREAM) {
ConnectExpectFail();
} else {
ConnectExpectFailOneSide(TlsAgent::CLIENT);
}
client_->CheckErrorCode(error);
if (mode_ == STREAM) {
server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
} else {
ASSERT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
}
}

void ClientSkipTest(TlsRecordFilter* filter, int32_t error) {
EnsureTlsSetup();
client_->SetPacketFilter(filter);
filter->EnableDecryption();
ConnectExpectFailOneSide(TlsAgent::SERVER);

server_->CheckErrorCode(error);
ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
}
};

TEST_P(TlsSkipTest, SkipCertificateRsa) {
EnableOnlyStaticRsaCiphers();
ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
Expand Down Expand Up @@ -148,11 +182,41 @@ TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) {
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}

TEST_P(Tls13SkipTest, SkipEncryptedExtensions) {
ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeEncryptedExtensions),
SSL_ERROR_RX_UNEXPECTED_CERTIFICATE);
}

TEST_P(Tls13SkipTest, SkipServerCertificate) {
ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate),
SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
}

TEST_P(Tls13SkipTest, SkipServerCertificateVerify) {
ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificateVerify),
SSL_ERROR_RX_UNEXPECTED_FINISHED);
}

TEST_P(Tls13SkipTest, SkipClientCertificate) {
client_->SetupClientAuth();
server_->RequestClientAuth(true);
ClientSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate),
SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
}

TEST_P(Tls13SkipTest, SkipClientCertificateVerify) {
client_->SetupClientAuth();
server_->RequestClientAuth(true);
ClientSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificateVerify),
SSL_ERROR_RX_UNEXPECTED_FINISHED);
}

INSTANTIATE_TEST_CASE_P(SkipTls10, TlsSkipTest,
::testing::Combine(TlsConnectTestBase::kTlsModesStream,
TlsConnectTestBase::kTlsV10));
INSTANTIATE_TEST_CASE_P(SkipVariants, TlsSkipTest,
::testing::Combine(TlsConnectTestBase::kTlsModesAll,
TlsConnectTestBase::kTlsV11V12));

INSTANTIATE_TEST_CASE_P(Skip13Variants, Tls13SkipTest,
TlsConnectTestBase::kTlsModesAll);
} // namespace nss_test
10 changes: 9 additions & 1 deletion gtests/ssl_gtest/tls_agent.h
Expand Up @@ -14,6 +14,7 @@
#include <iostream>

#include "test_io.h"
#include "tls_filter.h"

#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"
Expand Down Expand Up @@ -85,10 +86,17 @@ class TlsAgent : public PollTarget {

void SetPeer(TlsAgent* peer) { adapter_->SetPeer(peer->adapter_); }

void SetPacketFilter(TlsRecordFilter* filter) {
filter->SetAgent(this);
adapter_->SetPacketFilter(filter);
}

void SetPacketFilter(PacketFilter* filter) {
adapter_->SetPacketFilter(filter);
}

void DeletePacketFilter() { adapter_->SetPacketFilter(nullptr); }

void StartConnect(PRFileDesc* model = nullptr);
void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group,
size_t kea_size = 0) const;
Expand Down Expand Up @@ -171,7 +179,7 @@ class TlsAgent : public PollTarget {

static const char* state_str(State state) { return states[state]; }

PRFileDesc* ssl_fd() { return ssl_fd_; }
PRFileDesc* ssl_fd() const { return ssl_fd_; }
DummyPrSocket* adapter() { return adapter_; }

bool is_compressed() const {
Expand Down
16 changes: 16 additions & 0 deletions gtests/ssl_gtest/tls_connect.cc
Expand Up @@ -373,6 +373,22 @@ void TlsConnectTestBase::ConnectExpectFail() {
ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state());
}

void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) {
server_->StartConnect();
client_->StartConnect();
client_->SetServerKeyBits(server_->server_key_bits());
client_->Handshake();
server_->Handshake();
TlsAgent* fail_agent;

if (failing_side == TlsAgent::CLIENT) {
fail_agent = client_;
} else {
fail_agent = server_;
}
ASSERT_TRUE_WAIT(fail_agent->state() == TlsAgent::STATE_ERROR, 5000);
}

void TlsConnectTestBase::ConfigureVersion(uint16_t version) {
client_->SetVersionRange(version, version);
server_->SetVersionRange(version, version);
Expand Down
1 change: 1 addition & 0 deletions gtests/ssl_gtest/tls_connect.h
Expand Up @@ -68,6 +68,7 @@ class TlsConnectTestBase : public ::testing::Test {
void CheckConnected();
// Connect and expect it to fail.
void ConnectExpectFail();
void ConnectExpectFailOneSide(TlsAgent::Role failingSide);
void ConnectWithCipherSuite(uint16_t cipher_suite);
// Check that the keys used in the handshake match expectations.
void CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
Expand Down

0 comments on commit b8889f5

Please sign in to comment.