Skip to content

Commit

Permalink
Bug 1143900 - Refactor NSS TLS unit tests to handle versions more cle…
Browse files Browse the repository at this point in the history
…anly. r=mt
  • Loading branch information
ekr committed Mar 17, 2015
1 parent 3449bd8 commit 74b0a40
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 34 deletions.
42 changes: 18 additions & 24 deletions external_tests/ssl_gtest/ssl_loopback_unittest.cc
Expand Up @@ -152,29 +152,6 @@ TEST_P(TlsConnectGeneric, ConnectClientNoneServerBoth) {
CheckResumption(RESUME_NONE);
}

TEST_P(TlsConnectGeneric, ConnectTLS_1_1_Only) {
EnsureTlsSetup();
client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
SSL_LIBRARY_VERSION_TLS_1_1);

server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
SSL_LIBRARY_VERSION_TLS_1_1);

Connect();

client_->CheckVersion(SSL_LIBRARY_VERSION_TLS_1_1);
}

TEST_P(TlsConnectGeneric, ConnectTLS_1_2_Only) {
EnsureTlsSetup();
client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
SSL_LIBRARY_VERSION_TLS_1_2);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
SSL_LIBRARY_VERSION_TLS_1_2);
Connect();
client_->CheckVersion(SSL_LIBRARY_VERSION_TLS_1_2);
}

TEST_P(TlsConnectGeneric, ConnectAlpn) {
EnableAlpn();
Connect();
Expand Down Expand Up @@ -257,7 +234,24 @@ TEST_F(TlsConnectTest, ConnectECDHETwiceNewKey) {
dhe1.public_key_.len())));
}

TEST_P(TlsConnectGenericSingleVersion, Connect) {
Connect();
}

static const std::string kTls[] = {"TLS"};
static const std::string kTlsDtls[] = {"TLS", "DTLS"};
static const uint16_t kTlsV10[] = {SSL_LIBRARY_VERSION_TLS_1_0};
static const uint16_t kTlsV11V12[] = {SSL_LIBRARY_VERSION_TLS_1_1,
SSL_LIBRARY_VERSION_TLS_1_2};
INSTANTIATE_TEST_CASE_P(Variants, TlsConnectGeneric,
::testing::Values("TLS", "DTLS"));
::testing::ValuesIn(kTlsDtls));
INSTANTIATE_TEST_CASE_P(VersionsStream, TlsConnectGenericSingleVersion,
::testing::Combine(
::testing::ValuesIn(kTls),
::testing::ValuesIn(kTlsV10)));
INSTANTIATE_TEST_CASE_P(VersionsByVariants, TlsConnectGenericSingleVersion,
::testing::Combine(
::testing::ValuesIn(kTlsDtls),
::testing::ValuesIn(kTlsV11V12)));

} // namespace nspr_test
18 changes: 13 additions & 5 deletions external_tests/ssl_gtest/tls_agent.cc
Expand Up @@ -58,8 +58,12 @@ bool TlsAgent::EnsureTlsSetup() {
if (rv != SECSuccess) return false;
}

SECStatus rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook,
reinterpret_cast<void*>(this));
SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_);
EXPECT_EQ(SECSuccess, rv);
if (rv != SECSuccess) return false;

rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook,
reinterpret_cast<void*>(this));
EXPECT_EQ(SECSuccess, rv);
if (rv != SECSuccess) return false;

Expand Down Expand Up @@ -104,8 +108,13 @@ void TlsAgent::SetSessionCacheEnabled(bool en) {
}

void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) {
SSLVersionRange range = {minver, maxver};
ASSERT_EQ(SECSuccess, SSL_VersionRangeSet(ssl_fd_, &range));
vrange_.min = minver;
vrange_.max = maxver;

if (ssl_fd_) {
SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_);
ASSERT_EQ(SECSuccess, rv);
}
}

void TlsAgent::CheckKEAType(SSLKEAType type) const {
Expand Down Expand Up @@ -154,7 +163,6 @@ void TlsAgent::CheckSrtp() {
ASSERT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual);
}


void TlsAgent::Handshake() {
SECStatus rv = SSL_ForceHandshake(ssl_fd_);
if (rv == SECSuccess) {
Expand Down
17 changes: 17 additions & 0 deletions external_tests/ssl_gtest/tls_agent.h
Expand Up @@ -14,6 +14,9 @@

#include "test_io.h"

#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"

namespace nss_test {

#define LOG(msg) std::cerr << name_ << ": " << msg << std::endl
Expand All @@ -40,6 +43,10 @@ class TlsAgent : public PollTarget {
state_(INIT) {
memset(&info_, 0, sizeof(info_));
memset(&csinfo_, 0, sizeof(csinfo_));
SECStatus rv = SSL_VersionRangeGetDefault(mode_ == STREAM ?
ssl_variant_stream : ssl_variant_datagram,
&vrange_);
EXPECT_EQ(SECSuccess, rv);
}

~TlsAgent() {
Expand Down Expand Up @@ -95,6 +102,9 @@ class TlsAgent : public PollTarget {

PRFileDesc* ssl_fd() { return ssl_fd_; }

uint16_t min_version() const { return vrange_.min; }
uint16_t max_version() const { return vrange_.max; }

bool version(uint16_t* version) const {
if (state_ != CONNECTED) return false;

Expand All @@ -103,6 +113,12 @@ class TlsAgent : public PollTarget {
return true;
}

uint16_t version() const {
EXPECT_EQ(CONNECTED, state_);

return info_.protocolVersion;
}

bool cipher_suite(int16_t* cipher_suite) const {
if (state_ != CONNECTED) return false;

Expand Down Expand Up @@ -163,6 +179,7 @@ class TlsAgent : public PollTarget {
State state_;
SSLChannelInfo info_;
SSLCipherSuiteInfo csinfo_;
SSLVersionRange vrange_;
};

} // namespace nss_test
Expand Down
23 changes: 19 additions & 4 deletions external_tests/ssl_gtest/tls_connect.cc
Expand Up @@ -17,7 +17,9 @@ namespace nss_test {
TlsConnectTestBase::TlsConnectTestBase(Mode mode)
: mode_(mode),
client_(new TlsAgent("client", TlsAgent::CLIENT, mode_)),
server_(new TlsAgent("server", TlsAgent::SERVER, mode_)) {}
server_(new TlsAgent("server", TlsAgent::SERVER, mode_)),
version_(0),
session_ids_() {}

TlsConnectTestBase::~TlsConnectTestBase() {
delete client_;
Expand Down Expand Up @@ -58,6 +60,11 @@ void TlsConnectTestBase::Reset() {
client_ = new TlsAgent("client", TlsAgent::CLIENT, mode_);
server_ = new TlsAgent("server", TlsAgent::SERVER, mode_);

if (version_) {
client_->SetVersionRange(version_, version_);
server_->SetVersionRange(version_, version_);
}

Init();
}

Expand All @@ -72,14 +79,21 @@ void TlsConnectTestBase::Handshake() {
client_->Handshake();
server_->Handshake();

ASSERT_TRUE_WAIT(client_->state() != TlsAgent::CONNECTING &&
server_->state() != TlsAgent::CONNECTING,
ASSERT_TRUE_WAIT((client_->state() != TlsAgent::CONNECTING) &&
(server_->state() != TlsAgent::CONNECTING),
5000);

}

void TlsConnectTestBase::Connect() {
Handshake();

// Check the version is as expected
ASSERT_EQ(client_->version(), server_->version());
ASSERT_EQ(std::min(client_->max_version(),
server_->max_version()),
client_->version());

ASSERT_EQ(TlsAgent::CONNECTED, client_->state());
ASSERT_EQ(TlsAgent::CONNECTED, server_->state());

Expand All @@ -90,7 +104,8 @@ void TlsConnectTestBase::Connect() {
ASSERT_TRUE(ret);
ASSERT_EQ(cipher_suite1, cipher_suite2);

std::cerr << "Connected with cipher suite " << client_->cipher_suite_name()
std::cerr << "Connected with version " << client_->version()
<< " cipher suite " << client_->cipher_suite_name()
<< std::endl;

// Check and store session ids.
Expand Down
22 changes: 21 additions & 1 deletion external_tests/ssl_gtest/tls_connect.h
Expand Up @@ -7,6 +7,8 @@
#ifndef tls_connect_h_
#define tls_connect_h_

#include <tuple>

#include "sslt.h"

#include "tls_agent.h"
Expand Down Expand Up @@ -46,11 +48,12 @@ class TlsConnectTestBase : public ::testing::Test {
void EnableAlpn();
void EnableSrtp();
void CheckSrtp();

protected:

Mode mode_;
TlsAgent* client_;
TlsAgent* server_;
uint16_t version_;
std::vector<std::vector<uint8_t>> session_ids_;
};

Expand All @@ -74,6 +77,23 @@ class TlsConnectGeneric : public TlsConnectTestBase,
TlsConnectGeneric();
};

// A generic test class that is a single version of TLS. This is configured
// in ssl_loopback_unittest.cc. All uses of this should use TEST_P().
class TlsConnectGenericSingleVersion : public TlsConnectTestBase,
public ::testing::WithParamInterface<
std::tuple<std::string,uint16_t>> {
public:
TlsConnectGenericSingleVersion() : TlsConnectTestBase(
std::get<0>(GetParam()) == "TLS" ? STREAM : DGRAM) {
uint16_t version = std::get<1>(GetParam());

std::cerr << "Version : " << version << std::endl;
client_->SetVersionRange(version, version);
server_->SetVersionRange(version, version);
version_ = version;
}
};

} // namespace nss_test

#endif

0 comments on commit 74b0a40

Please sign in to comment.