diff --git a/cpputil/tls_parser.h b/cpputil/tls_parser.h index 2708b4dbf3..15ba3b175b 100644 --- a/cpputil/tls_parser.h +++ b/cpputil/tls_parser.h @@ -16,7 +16,6 @@ #include #endif #include "databuffer.h" - #include "sslt.h" namespace nss_test { @@ -79,6 +78,10 @@ static const uint8_t kTls13PskDhKe = 1; static const uint8_t kTls13PskAuth = 0; static const uint8_t kTls13PskSignAuth = 1; +inline std::ostream& operator<<(std::ostream& os, SSLProtocolVariant v) { + return os << ((v == ssl_variant_stream) ? "TLS" : "DTLS"); +} + inline bool IsDtls(uint16_t version) { return (version & 0x8000) == 0x8000; } inline uint16_t NormalizeTlsVersion(uint16_t version) { @@ -135,10 +138,6 @@ class TlsParser { size_t offset_; }; -inline std::ostream& operator<<(std::ostream& os, SSLProtocolVariant v) { - return os << ((v == ssl_variant_stream) ? "TLS" : "DTLS"); -} - } // namespace nss_test #endif diff --git a/gtests/ssl_gtest/ssl_0rtt_unittest.cc b/gtests/ssl_gtest/ssl_0rtt_unittest.cc index a3c0075cc5..85b7011a1c 100644 --- a/gtests/ssl_gtest/ssl_0rtt_unittest.cc +++ b/gtests/ssl_gtest/ssl_0rtt_unittest.cc @@ -227,7 +227,7 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngrade) { client_->Set0RttEnabled(true); client_->ExpectSendAlert(kTlsAlertIllegalParameter); - if (mode_ == STREAM) { + if (variant_ == ssl_variant_stream) { server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); } client_->Handshake(); @@ -237,7 +237,7 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngrade) { // DTLS will timeout as we bump the epoch when installing the early app data // cipher suite. Thus the encrypted alert will be ignored. - if (mode_ == STREAM) { + if (variant_ == ssl_variant_stream) { // The client sends an encrypted alert message. ASSERT_TRUE_WAIT( (server_->error_code() == SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA), @@ -269,7 +269,7 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngradeEarlyData) { client_->Set0RttEnabled(true); ZeroRttSendReceive(true, false, [this]() { client_->ExpectSendAlert(kTlsAlertIllegalParameter); - if (mode_ == STREAM) { + if (variant_ == ssl_variant_stream) { server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); } return true; @@ -282,7 +282,7 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngradeEarlyData) { // DTLS will timeout as we bump the epoch when installing the early app data // cipher suite. Thus the encrypted alert will be ignored. - if (mode_ == STREAM) { + if (variant_ == ssl_variant_stream) { // The server sends an alert when receiving the early app data record. ASSERT_TRUE_WAIT( (server_->error_code() == SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA), @@ -316,7 +316,7 @@ TEST_P(TlsConnectTls13, SendTooMuchEarlyData) { PRInt32 sent; // Writing more than the limit will succeed in TLS, but fail in DTLS. - if (mode_ == STREAM) { + if (variant_ == ssl_variant_stream) { sent = PR_Write(client_->ssl_fd(), big_message, static_cast(strlen(big_message))); } else { @@ -377,7 +377,7 @@ TEST_P(TlsConnectTls13, ReceiveTooMuchEarlyData) { const PRInt32 message_len = static_cast(strlen(message)); EXPECT_EQ(message_len, PR_Write(client_->ssl_fd(), message, message_len)); - if (mode_ == STREAM) { + if (variant_ == ssl_variant_stream) { // This error isn't fatal for DTLS. ExpectAlert(server_, kTlsAlertUnexpectedMessage); } @@ -388,13 +388,13 @@ TEST_P(TlsConnectTls13, ReceiveTooMuchEarlyData) { // Attempt to read early data. std::vector buf(strlen(message) + 1); EXPECT_GT(0, PR_Read(server_->ssl_fd(), buf.data(), buf.capacity())); - if (mode_ == STREAM) { + if (variant_ == ssl_variant_stream) { server_->CheckErrorCode(SSL_ERROR_TOO_MUCH_EARLY_DATA); } client_->Handshake(); // Process the handshake. client_->Handshake(); // Process the alert. - if (mode_ == STREAM) { + if (variant_ == ssl_variant_stream) { client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); } } diff --git a/gtests/ssl_gtest/ssl_agent_unittest.cc b/gtests/ssl_gtest/ssl_agent_unittest.cc index 13861c543f..5035a338d2 100644 --- a/gtests/ssl_gtest/ssl_agent_unittest.cc +++ b/gtests/ssl_gtest/ssl_agent_unittest.cc @@ -204,14 +204,15 @@ TEST_F(TlsAgentStreamTestServer, Set0RttOptionClientHelloThenRead) { ProcessMessage(buffer, TlsAgent::STATE_ERROR, SSL_ERROR_BAD_MAC_READ); } -INSTANTIATE_TEST_CASE_P(AgentTests, TlsAgentTest, - ::testing::Combine(TlsAgentTestBase::kTlsRolesAll, - TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsVAll)); +INSTANTIATE_TEST_CASE_P( + AgentTests, TlsAgentTest, + ::testing::Combine(TlsAgentTestBase::kTlsRolesAll, + TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll)); INSTANTIATE_TEST_CASE_P(ClientTests, TlsAgentTestClient, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsVAll)); INSTANTIATE_TEST_CASE_P(ClientTests13, TlsAgentTestClient13, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV13)); } // namespace nss_test diff --git a/gtests/ssl_gtest/ssl_auth_unittest.cc b/gtests/ssl_gtest/ssl_auth_unittest.cc index aa2dc195bc..dbcbc9aa33 100644 --- a/gtests/ssl_gtest/ssl_auth_unittest.cc +++ b/gtests/ssl_gtest/ssl_auth_unittest.cc @@ -716,8 +716,8 @@ TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPss) { &ServerCertDataRsaPss)); } -// mode, version, certificate, auth type, signature scheme -typedef std::tuple SignatureSchemeProfile; @@ -778,7 +778,7 @@ TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigBoth) { INSTANTIATE_TEST_CASE_P( SignatureSchemeRsa, TlsSignatureSchemeConfiguration, ::testing::Combine( - TlsConnectTestBase::kTlsModesAll, TlsConnectTestBase::kTlsV12Plus, + TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus, ::testing::Values(TlsAgent::kServerRsaSign), ::testing::Values(ssl_auth_rsa_sign), ::testing::Values(ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pkcs1_sha384, @@ -787,42 +787,42 @@ INSTANTIATE_TEST_CASE_P( // PSS with SHA-512 needs a bigger key to work. INSTANTIATE_TEST_CASE_P( SignatureSchemeBigRsa, TlsSignatureSchemeConfiguration, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus, ::testing::Values(TlsAgent::kRsa2048), ::testing::Values(ssl_auth_rsa_sign), ::testing::Values(ssl_sig_rsa_pss_sha512))); INSTANTIATE_TEST_CASE_P( SignatureSchemeRsaSha1, TlsSignatureSchemeConfiguration, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12, ::testing::Values(TlsAgent::kServerRsa), ::testing::Values(ssl_auth_rsa_sign), ::testing::Values(ssl_sig_rsa_pkcs1_sha1))); INSTANTIATE_TEST_CASE_P( SignatureSchemeEcdsaP256, TlsSignatureSchemeConfiguration, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus, ::testing::Values(TlsAgent::kServerEcdsa256), ::testing::Values(ssl_auth_ecdsa), ::testing::Values(ssl_sig_ecdsa_secp256r1_sha256))); INSTANTIATE_TEST_CASE_P( SignatureSchemeEcdsaP384, TlsSignatureSchemeConfiguration, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus, ::testing::Values(TlsAgent::kServerEcdsa384), ::testing::Values(ssl_auth_ecdsa), ::testing::Values(ssl_sig_ecdsa_secp384r1_sha384))); INSTANTIATE_TEST_CASE_P( SignatureSchemeEcdsaP521, TlsSignatureSchemeConfiguration, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus, ::testing::Values(TlsAgent::kServerEcdsa521), ::testing::Values(ssl_auth_ecdsa), ::testing::Values(ssl_sig_ecdsa_secp521r1_sha512))); INSTANTIATE_TEST_CASE_P( SignatureSchemeEcdsaSha1, TlsSignatureSchemeConfiguration, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12, ::testing::Values(TlsAgent::kServerEcdsa256, TlsAgent::kServerEcdsa384), diff --git a/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc b/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc index 93c1b85f3a..85c30b2bfa 100644 --- a/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc +++ b/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc @@ -22,17 +22,17 @@ extern "C" { namespace nss_test { -// mode, version, cipher suite -typedef std::tuple CipherSuiteProfile; class TlsCipherSuiteTestBase : public TlsConnectTestBase { public: - TlsCipherSuiteTestBase(const std::string &mode, uint16_t version, + TlsCipherSuiteTestBase(SSLProtocolVariant variant, uint16_t version, uint16_t cipher_suite, SSLNamedGroup group, SSLSignatureScheme signature_scheme) - : TlsConnectTestBase(mode, version), + : TlsConnectTestBase(variant, version), cipher_suite_(cipher_suite), group_(group), signature_scheme_(signature_scheme), @@ -259,7 +259,7 @@ TEST_P(TlsCipherSuiteTest, ReadLimit) { static const uint8_t payload[18] = {6}; DataBuffer record; uint64_t epoch; - if (mode_ == DGRAM) { + if (variant_ == ssl_variant_datagram) { if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) { epoch = 3; // Application traffic keys. } else { @@ -268,7 +268,7 @@ TEST_P(TlsCipherSuiteTest, ReadLimit) { } else { epoch = 0; } - TlsAgentTestBase::MakeRecord(mode_, kTlsApplicationDataType, version_, + TlsAgentTestBase::MakeRecord(variant_, kTlsApplicationDataType, version_, payload, sizeof(payload), &record, (epoch << 48) | record_limit()); server_->adapter()->PacketReceived(record); @@ -296,7 +296,7 @@ TEST_P(TlsCipherSuiteTest, WriteLimit) { k##name##Ciphers = ::testing::ValuesIn(k##name##CiphersArr); \ INSTANTIATE_TEST_CASE_P( \ CipherSuite##name, TlsCipherSuiteTest, \ - ::testing::Combine(TlsConnectTestBase::kTlsModes##modes, \ + ::testing::Combine(TlsConnectTestBase::kTlsVariants##modes, \ TlsConnectTestBase::kTls##versions, k##name##Ciphers, \ groups, sigalgs)); @@ -405,7 +405,7 @@ class SecurityStatusTest public ::testing::WithParamInterface { public: SecurityStatusTest() - : TlsCipherSuiteTestBase("TLS", GetParam().version, + : TlsCipherSuiteTestBase(ssl_variant_stream, GetParam().version, GetParam().cipher_suite, ssl_grp_none, ssl_sig_none) {} }; diff --git a/gtests/ssl_gtest/ssl_damage_unittest.cc b/gtests/ssl_gtest/ssl_damage_unittest.cc index dac76aed44..69fd003313 100644 --- a/gtests/ssl_gtest/ssl_damage_unittest.cc +++ b/gtests/ssl_gtest/ssl_damage_unittest.cc @@ -82,7 +82,7 @@ TEST_P(TlsConnectTls13, DamageServerSignature) { filter->EnableDecryption(); client_->ExpectSendAlert(kTlsAlertDecryptError); // The server can't read the client's alert, so it also sends an alert. - if (mode_ == STREAM) { + if (variant_ == ssl_variant_stream) { server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); diff --git a/gtests/ssl_gtest/ssl_dhe_unittest.cc b/gtests/ssl_gtest/ssl_dhe_unittest.cc index e2dbbdb53f..97943303ad 100644 --- a/gtests/ssl_gtest/ssl_dhe_unittest.cc +++ b/gtests/ssl_gtest/ssl_dhe_unittest.cc @@ -272,10 +272,11 @@ class TlsDheSkeChangeYClient : public TlsDheSkeChangeY { std::shared_ptr server_filter_; }; -/* This matrix includes: mode (stream/datagram), TLS version, what change to +/* This matrix includes: variant (stream/datagram), TLS version, what change to * make to dh_Ys, whether the client will be configured to require DH named * groups. Test all combinations. */ -typedef std::tuple +typedef std::tuple DamageDHYProfile; class TlsDamageDHYTest : public TlsConnectTestBase, @@ -358,13 +359,13 @@ static const bool kTrueFalseArr[] = {true, false}; static ::testing::internal::ParamGenerator kTrueFalse = ::testing::ValuesIn(kTrueFalseArr); -INSTANTIATE_TEST_CASE_P(DamageYStream, TlsDamageDHYTest, - ::testing::Combine(TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsV10ToV12, - kAllY, kTrueFalse)); +INSTANTIATE_TEST_CASE_P( + DamageYStream, TlsDamageDHYTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10ToV12, kAllY, kTrueFalse)); INSTANTIATE_TEST_CASE_P( DamageYDatagram, TlsDamageDHYTest, - ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, TlsConnectTestBase::kTlsV11V12, kAllY, kTrueFalse)); class TlsDheSkeMakePEven : public TlsHandshakeFilter { diff --git a/gtests/ssl_gtest/ssl_ecdh_unittest.cc b/gtests/ssl_gtest/ssl_ecdh_unittest.cc index b22a3ca4e0..1e406b6c20 100644 --- a/gtests/ssl_gtest/ssl_ecdh_unittest.cc +++ b/gtests/ssl_gtest/ssl_ecdh_unittest.cc @@ -574,12 +574,12 @@ TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyClientPoint) { } INSTANTIATE_TEST_CASE_P(KeyExchangeTest, TlsKeyExchangeTest, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV11Plus)); #ifndef NSS_DISABLE_TLS_1_3 INSTANTIATE_TEST_CASE_P(KeyExchangeTest, TlsKeyExchangeTest13, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV13)); #endif diff --git a/gtests/ssl_gtest/ssl_extension_unittest.cc b/gtests/ssl_gtest/ssl_extension_unittest.cc index 23842f7263..d15139419a 100644 --- a/gtests/ssl_gtest/ssl_extension_unittest.cc +++ b/gtests/ssl_gtest/ssl_extension_unittest.cc @@ -166,10 +166,8 @@ class TlsExtensionAppender : public TlsHandshakeFilter { class TlsExtensionTestBase : public TlsConnectTestBase { protected: - TlsExtensionTestBase(Mode mode, uint16_t version) - : TlsConnectTestBase(mode, version) {} - TlsExtensionTestBase(const std::string& mode, uint16_t version) - : TlsConnectTestBase(mode, version) {} + TlsExtensionTestBase(SSLProtocolVariant variant, uint16_t version) + : TlsConnectTestBase(variant, version) {} void ClientHelloErrorTest(std::shared_ptr filter, uint8_t desc = kTlsAlertDecodeError) { @@ -216,29 +214,31 @@ class TlsExtensionTestBase : public TlsConnectTestBase { class TlsExtensionTestDtls : public TlsExtensionTestBase, public ::testing::WithParamInterface { public: - TlsExtensionTestDtls() : TlsExtensionTestBase(DGRAM, GetParam()) {} + TlsExtensionTestDtls() + : TlsExtensionTestBase(ssl_variant_datagram, GetParam()) {} }; -class TlsExtensionTest12Plus - : public TlsExtensionTestBase, - public ::testing::WithParamInterface> { +class TlsExtensionTest12Plus : public TlsExtensionTestBase, + public ::testing::WithParamInterface< + std::tuple> { public: TlsExtensionTest12Plus() : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) { } }; -class TlsExtensionTest12 - : public TlsExtensionTestBase, - public ::testing::WithParamInterface> { +class TlsExtensionTest12 : public TlsExtensionTestBase, + public ::testing::WithParamInterface< + std::tuple> { public: TlsExtensionTest12() : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) { } }; -class TlsExtensionTest13 : public TlsExtensionTestBase, - public ::testing::WithParamInterface { +class TlsExtensionTest13 + : public TlsExtensionTestBase, + public ::testing::WithParamInterface { public: TlsExtensionTest13() : TlsExtensionTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} @@ -266,21 +266,21 @@ class TlsExtensionTest13 : public TlsExtensionTestBase, class TlsExtensionTest13Stream : public TlsExtensionTestBase { public: TlsExtensionTest13Stream() - : TlsExtensionTestBase(STREAM, SSL_LIBRARY_VERSION_TLS_1_3) {} + : TlsExtensionTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3) {} }; -class TlsExtensionTestGeneric - : public TlsExtensionTestBase, - public ::testing::WithParamInterface> { +class TlsExtensionTestGeneric : public TlsExtensionTestBase, + public ::testing::WithParamInterface< + std::tuple> { public: TlsExtensionTestGeneric() : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) { } }; -class TlsExtensionTestPre13 - : public TlsExtensionTestBase, - public ::testing::WithParamInterface> { +class TlsExtensionTestPre13 : public TlsExtensionTestBase, + public ::testing::WithParamInterface< + std::tuple> { public: TlsExtensionTestPre13() : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) { @@ -992,9 +992,9 @@ TEST_P(TlsExtensionTest13, OddVersionList) { // TODO: this only tests extensions in server messages. The client can extend // Certificate messages, which is not checked here. -class TlsBogusExtensionTest - : public TlsConnectTestBase, - public ::testing::WithParamInterface> { +class TlsBogusExtensionTest : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple> { public: TlsBogusExtensionTest() : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} @@ -1044,7 +1044,7 @@ class TlsBogusExtensionTest13 : public TlsBogusExtensionTest { client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); client_->Handshake(); - if (mode_ == STREAM) { + if (variant_ == ssl_variant_stream) { server_->ExpectSendAlert(kTlsAlertBadRecordMac); } server_->Handshake(); @@ -1139,40 +1139,43 @@ TEST_P(TlsConnectStream, IncludePadding) { EXPECT_TRUE(capture->captured()); } -INSTANTIATE_TEST_CASE_P(ExtensionStream, TlsExtensionTestGeneric, - ::testing::Combine(TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsVAll)); +INSTANTIATE_TEST_CASE_P( + ExtensionStream, TlsExtensionTestGeneric, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll)); INSTANTIATE_TEST_CASE_P( ExtensionDatagram, TlsExtensionTestGeneric, - ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, TlsConnectTestBase::kTlsV11Plus)); INSTANTIATE_TEST_CASE_P(ExtensionDatagramOnly, TlsExtensionTestDtls, TlsConnectTestBase::kTlsV11Plus); INSTANTIATE_TEST_CASE_P(ExtensionTls12Plus, TlsExtensionTest12Plus, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus)); -INSTANTIATE_TEST_CASE_P(ExtensionPre13Stream, TlsExtensionTestPre13, - ::testing::Combine(TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsV10ToV12)); +INSTANTIATE_TEST_CASE_P( + ExtensionPre13Stream, TlsExtensionTestPre13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10ToV12)); INSTANTIATE_TEST_CASE_P(ExtensionPre13Datagram, TlsExtensionTestPre13, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV11V12)); INSTANTIATE_TEST_CASE_P(ExtensionTls13, TlsExtensionTest13, - TlsConnectTestBase::kTlsModesAll); + TlsConnectTestBase::kTlsVariantsAll); -INSTANTIATE_TEST_CASE_P(BogusExtensionStream, TlsBogusExtensionTestPre13, - ::testing::Combine(TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsV10ToV12)); +INSTANTIATE_TEST_CASE_P( + BogusExtensionStream, TlsBogusExtensionTestPre13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10ToV12)); INSTANTIATE_TEST_CASE_P( BogusExtensionDatagram, TlsBogusExtensionTestPre13, - ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, TlsConnectTestBase::kTlsV11V12)); INSTANTIATE_TEST_CASE_P(BogusExtension13, TlsBogusExtensionTest13, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV13)); } // namespace nss_test diff --git a/gtests/ssl_gtest/ssl_gather_unittest.cc b/gtests/ssl_gtest/ssl_gather_unittest.cc index 66976a796e..f47b2f4452 100644 --- a/gtests/ssl_gtest/ssl_gather_unittest.cc +++ b/gtests/ssl_gtest/ssl_gather_unittest.cc @@ -11,7 +11,7 @@ namespace nss_test { class GatherV2ClientHelloTest : public TlsConnectTestBase { public: - GatherV2ClientHelloTest() : TlsConnectTestBase(STREAM, 0) {} + GatherV2ClientHelloTest() : TlsConnectTestBase(ssl_variant_stream, 0) {} void ConnectExpectMalformedClientHello(const DataBuffer &data) { EnsureTlsSetup(); diff --git a/gtests/ssl_gtest/ssl_hrr_unittest.cc b/gtests/ssl_gtest/ssl_hrr_unittest.cc index b81dc020b2..39055f6419 100644 --- a/gtests/ssl_gtest/ssl_hrr_unittest.cc +++ b/gtests/ssl_gtest/ssl_hrr_unittest.cc @@ -106,7 +106,7 @@ TEST_P(TlsConnectTls13, SecondClientHelloRejectEarlyDataXtn) { // A new client that tries to resume with 0-RTT but doesn't send the // correct key share(s). The server will respond with an HRR. auto orig_client = - std::make_shared(client_->name(), TlsAgent::CLIENT, mode_); + std::make_shared(client_->name(), TlsAgent::CLIENT, variant_); client_.swap(orig_client); client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, SSL_LIBRARY_VERSION_TLS_1_3); @@ -130,7 +130,7 @@ TEST_P(TlsConnectTls13, SecondClientHelloRejectEarlyDataXtn) { orig_client.reset(); // Correct the DTLS message sequence number after an HRR. - if (mode_ == DGRAM) { + if (variant_ == ssl_variant_datagram) { client_->SetPacketFilter( std::make_shared()); } @@ -253,7 +253,7 @@ TEST_F(TlsConnectTest, Select12AfterHelloRetryRequest) { // Here we replace the TLS server with one that does TLS 1.2 only. // This will happily send the client a TLS 1.2 ServerHello. - server_.reset(new TlsAgent(server_->name(), TlsAgent::SERVER, mode_)); + server_.reset(new TlsAgent(server_->name(), TlsAgent::SERVER, variant_)); client_->SetPeer(server_); server_->SetPeer(client_); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, @@ -357,11 +357,11 @@ TEST_P(HelloRetryRequestAgentTest, HandleHelloRetryRequestCookie) { } INSTANTIATE_TEST_CASE_P(HelloRetryRequestAgentTests, HelloRetryRequestAgentTest, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV13)); #ifndef NSS_DISABLE_TLS_1_3 INSTANTIATE_TEST_CASE_P(HelloRetryRequestKeyExchangeTests, TlsKeyExchange13, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV13)); #endif diff --git a/gtests/ssl_gtest/ssl_loopback_unittest.cc b/gtests/ssl_gtest/ssl_loopback_unittest.cc index 7b4caf6357..fd05754d80 100644 --- a/gtests/ssl_gtest/ssl_loopback_unittest.cc +++ b/gtests/ssl_gtest/ssl_loopback_unittest.cc @@ -130,7 +130,7 @@ TEST_P(TlsConnectTls13, CaptureAlertClient) { client_->ExpectSendAlert(kTlsAlertDecodeError); server_->Handshake(); client_->Handshake(); - if (mode_ == STREAM) { + if (variant_ == ssl_variant_stream) { // DTLS just drops the alert it can't decrypt. server_->ExpectSendAlert(kTlsAlertBadRecordMac); } @@ -227,7 +227,8 @@ TEST_P(TlsConnectGeneric, ConnectWithCompressionMaybe) { client_->EnableCompression(); server_->EnableCompression(); Connect(); - EXPECT_EQ(client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 && mode_ != DGRAM, + EXPECT_EQ(client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 && + variant_ != ssl_variant_datagram, client_->is_compressed()); SendReceive(); } @@ -320,12 +321,13 @@ TEST_F(TlsConnectStreamTls13, NegotiateShortHeaders) { Connect(); } -INSTANTIATE_TEST_CASE_P(GenericStream, TlsConnectGeneric, - ::testing::Combine(TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsVAll)); +INSTANTIATE_TEST_CASE_P( + GenericStream, TlsConnectGeneric, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll)); INSTANTIATE_TEST_CASE_P( GenericDatagram, TlsConnectGeneric, - ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, TlsConnectTestBase::kTlsV11Plus)); INSTANTIATE_TEST_CASE_P(StreamOnly, TlsConnectStream, @@ -333,33 +335,35 @@ INSTANTIATE_TEST_CASE_P(StreamOnly, TlsConnectStream, INSTANTIATE_TEST_CASE_P(DatagramOnly, TlsConnectDatagram, TlsConnectTestBase::kTlsV11Plus); -INSTANTIATE_TEST_CASE_P(Pre12Stream, TlsConnectPre12, - ::testing::Combine(TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsV10V11)); +INSTANTIATE_TEST_CASE_P( + Pre12Stream, TlsConnectPre12, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10V11)); INSTANTIATE_TEST_CASE_P( Pre12Datagram, TlsConnectPre12, - ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, TlsConnectTestBase::kTlsV11)); INSTANTIATE_TEST_CASE_P(Version12Only, TlsConnectTls12, - TlsConnectTestBase::kTlsModesAll); + TlsConnectTestBase::kTlsVariantsAll); #ifndef NSS_DISABLE_TLS_1_3 INSTANTIATE_TEST_CASE_P(Version13Only, TlsConnectTls13, - TlsConnectTestBase::kTlsModesAll); + TlsConnectTestBase::kTlsVariantsAll); #endif -INSTANTIATE_TEST_CASE_P(Pre13Stream, TlsConnectGenericPre13, - ::testing::Combine(TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsV10ToV12)); +INSTANTIATE_TEST_CASE_P( + Pre13Stream, TlsConnectGenericPre13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10ToV12)); INSTANTIATE_TEST_CASE_P( Pre13Datagram, TlsConnectGenericPre13, - ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, TlsConnectTestBase::kTlsV11V12)); INSTANTIATE_TEST_CASE_P(Pre13StreamOnly, TlsConnectStreamPre13, TlsConnectTestBase::kTlsV10ToV12); INSTANTIATE_TEST_CASE_P(Version12Plus, TlsConnectTls12Plus, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus)); } // namespace nspr_test diff --git a/gtests/ssl_gtest/ssl_resumption_unittest.cc b/gtests/ssl_gtest/ssl_resumption_unittest.cc index cde9f84e38..7b43870b00 100644 --- a/gtests/ssl_gtest/ssl_resumption_unittest.cc +++ b/gtests/ssl_gtest/ssl_resumption_unittest.cc @@ -523,7 +523,7 @@ class SelectedVersionReplacer : public TlsHandshakeFilter { // lower version number on resumption. TEST_P(TlsConnectGenericPre13, TestResumptionOverrideVersion) { uint16_t override_version = 0; - if (mode_ == STREAM) { + if (variant_ == ssl_variant_stream) { switch (version_) { case SSL_LIBRARY_VERSION_TLS_1_0: return; // Skip the test. diff --git a/gtests/ssl_gtest/ssl_skip_unittest.cc b/gtests/ssl_gtest/ssl_skip_unittest.cc index 65e3fcd54c..a130ef77fb 100644 --- a/gtests/ssl_gtest/ssl_skip_unittest.cc +++ b/gtests/ssl_gtest/ssl_skip_unittest.cc @@ -78,9 +78,9 @@ class TlsHandshakeSkipFilter : public TlsRecordFilter { bool skipped_; }; -class TlsSkipTest - : public TlsConnectTestBase, - public ::testing::WithParamInterface> { +class TlsSkipTest : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple> { protected: TlsSkipTest() : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} @@ -93,7 +93,7 @@ class TlsSkipTest }; class Tls13SkipTest : public TlsConnectTestBase, - public ::testing::WithParamInterface { + public ::testing::WithParamInterface { protected: Tls13SkipTest() : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} @@ -103,14 +103,14 @@ class Tls13SkipTest : public TlsConnectTestBase, server_->SetTlsRecordFilter(filter); filter->EnableDecryption(); client_->ExpectSendAlert(kTlsAlertUnexpectedMessage); - if (mode_ == STREAM) { + if (variant_ == ssl_variant_stream) { server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); } else { ConnectExpectFailOneSide(TlsAgent::CLIENT); } client_->CheckErrorCode(error); - if (mode_ == STREAM) { + if (variant_ == ssl_variant_stream) { server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); } else { ASSERT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); @@ -227,12 +227,13 @@ TEST_P(Tls13SkipTest, SkipClientCertificateVerify) { SSL_ERROR_RX_UNEXPECTED_FINISHED); } -INSTANTIATE_TEST_CASE_P(SkipTls10, TlsSkipTest, - ::testing::Combine(TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsV10)); +INSTANTIATE_TEST_CASE_P( + SkipTls10, TlsSkipTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10)); INSTANTIATE_TEST_CASE_P(SkipVariants, TlsSkipTest, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV11V12)); INSTANTIATE_TEST_CASE_P(Skip13Variants, Tls13SkipTest, - TlsConnectTestBase::kTlsModesAll); + TlsConnectTestBase::kTlsVariantsAll); } // namespace nss_test diff --git a/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc b/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc index 096b76dd48..110e3e0b6f 100644 --- a/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc +++ b/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc @@ -141,10 +141,11 @@ class SSLv2ClientHelloFilter : public PacketFilter { class SSLv2ClientHelloTestF : public TlsConnectTestBase { public: - SSLv2ClientHelloTestF() : TlsConnectTestBase(STREAM, 0), filter_(nullptr) {} + SSLv2ClientHelloTestF() + : TlsConnectTestBase(ssl_variant_stream, 0), filter_(nullptr) {} - SSLv2ClientHelloTestF(Mode mode, uint16_t version) - : TlsConnectTestBase(mode, version), filter_(nullptr) {} + SSLv2ClientHelloTestF(SSLProtocolVariant variant, uint16_t version) + : TlsConnectTestBase(variant, version), filter_(nullptr) {} void SetUp() { TlsConnectTestBase::SetUp(); @@ -193,7 +194,8 @@ class SSLv2ClientHelloTestF : public TlsConnectTestBase { class SSLv2ClientHelloTest : public SSLv2ClientHelloTestF, public ::testing::WithParamInterface { public: - SSLv2ClientHelloTest() : SSLv2ClientHelloTestF(STREAM, GetParam()) {} + SSLv2ClientHelloTest() + : SSLv2ClientHelloTestF(ssl_variant_stream, GetParam()) {} }; // Test negotiating TLS 1.0 - 1.2. diff --git a/gtests/ssl_gtest/ssl_version_unittest.cc b/gtests/ssl_gtest/ssl_version_unittest.cc index a6c6ce17af..379a67e350 100644 --- a/gtests/ssl_gtest/ssl_version_unittest.cc +++ b/gtests/ssl_gtest/ssl_version_unittest.cc @@ -260,7 +260,7 @@ TEST_P(TlsConnectGeneric, AlertBeforeServerHello) { static const uint8_t kWarningAlert[] = {kTlsAlertWarning, kTlsAlertUnrecognizedName}; DataBuffer alert; - TlsAgentTestBase::MakeRecord(mode_, kTlsAlertType, + TlsAgentTestBase::MakeRecord(variant_, kTlsAlertType, SSL_LIBRARY_VERSION_TLS_1_0, kWarningAlert, PR_ARRAY_SIZE(kWarningAlert), &alert); client_->adapter()->PacketReceived(alert); diff --git a/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc b/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc index fdd35e7a46..eda96831c7 100644 --- a/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc +++ b/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc @@ -92,12 +92,8 @@ class TestPolicyVersionRange public ::testing::WithParamInterface { public: TestPolicyVersionRange() - : TlsConnectTestBase(((static_cast( - std::get<0>(GetParam())) == ssl_variant_stream) - ? STREAM - : DGRAM), - 0), - variant_(static_cast(std::get<0>(GetParam()))), + : TlsConnectTestBase(std::get<0>(GetParam()), 0), + variant_(std::get<0>(GetParam())), policy_("policy", std::get<1>(GetParam()), std::get<2>(GetParam())), input_("input", std::get<3>(GetParam()), std::get<4>(GetParam())), library_("supported-by-library", @@ -124,9 +120,7 @@ class TestPolicyVersionRange void CreateDummySocket(std::shared_ptr* dummy_socket, ScopedPRFileDesc* ssl_fd) { - (*dummy_socket) - .reset(new DummyPrSocket( - "dummy", (variant_ == ssl_variant_stream) ? STREAM : DGRAM)); + (*dummy_socket).reset(new DummyPrSocket("dummy", variant_)); *ssl_fd = (*dummy_socket)->CreateFD(); if (variant_ == ssl_variant_stream) { SSL_ImportFD(nullptr, ssl_fd->get()); @@ -275,11 +269,6 @@ static const uint16_t kExpandedVersionsArr[] = { static ::testing::internal::ParamGenerator kExpandedVersions = ::testing::ValuesIn(kExpandedVersionsArr); -static const SSLProtocolVariant kVariantsArr[] = {ssl_variant_stream, - ssl_variant_datagram}; -static ::testing::internal::ParamGenerator kVariants = - ::testing::ValuesIn(kVariantsArr); - TEST_P(TestPolicyVersionRange, TestAllTLSVersionsAndPolicyCombinations) { ASSERT_TRUE(variant_ == ssl_variant_stream || variant_ == ssl_variant_datagram) @@ -398,7 +387,8 @@ TEST_P(TestPolicyVersionRange, TestAllTLSVersionsAndPolicyCombinations) { } INSTANTIATE_TEST_CASE_P(TLSVersionRanges, TestPolicyVersionRange, - ::testing::Combine(kVariants, kExpandedVersions, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, kExpandedVersions, kExpandedVersions, + kExpandedVersions, kExpandedVersions)); } // namespace nss_test diff --git a/gtests/ssl_gtest/test_io.cc b/gtests/ssl_gtest/test_io.cc index 42470a1a8e..b9f0c672e8 100644 --- a/gtests/ssl_gtest/test_io.cc +++ b/gtests/ssl_gtest/test_io.cc @@ -40,9 +40,8 @@ void DummyPrSocket::PacketReceived(const DataBuffer &packet) { } int32_t DummyPrSocket::Read(PRFileDesc *f, void *data, int32_t len) { - PR_ASSERT(mode_ == STREAM); - - if (mode_ != STREAM) { + PR_ASSERT(variant_ == ssl_variant_stream); + if (variant_ != ssl_variant_stream) { PR_SetError(PR_INVALID_METHOD_ERROR, 0); return -1; } @@ -75,7 +74,7 @@ int32_t DummyPrSocket::Recv(PRFileDesc *f, void *buf, int32_t buflen, return -1; } - if (mode() != DGRAM) { + if (variant() != ssl_variant_datagram) { return Read(f, buf, buflen); } diff --git a/gtests/ssl_gtest/test_io.h b/gtests/ssl_gtest/test_io.h index 3cd7571db5..ac24972228 100644 --- a/gtests/ssl_gtest/test_io.h +++ b/gtests/ssl_gtest/test_io.h @@ -18,6 +18,7 @@ #include "dummy_io.h" #include "prio.h" #include "scoped_ptrs.h" +#include "sslt.h" namespace nss_test { @@ -44,17 +45,11 @@ class PacketFilter { virtual Action Filter(const DataBuffer& input, DataBuffer* output) = 0; }; -enum Mode { STREAM, DGRAM }; - -inline std::ostream& operator<<(std::ostream& os, Mode m) { - return os << ((m == STREAM) ? "TLS" : "DTLS"); -} - class DummyPrSocket : public DummyIOLayerMethods { public: - DummyPrSocket(const std::string& name, Mode mode) + DummyPrSocket(const std::string& name, SSLProtocolVariant variant) : name_(name), - mode_(mode), + variant_(variant), peer_(), input_(), filter_(nullptr), @@ -78,7 +73,7 @@ class DummyPrSocket : public DummyIOLayerMethods { int32_t Write(PRFileDesc* f, const void* buf, int32_t length) override; void CloseWrites() { writeable_ = false; } - Mode mode() const { return mode_; } + SSLProtocolVariant variant() const { return variant_; } bool readable() const { return !input_.empty(); } private: @@ -99,7 +94,7 @@ class DummyPrSocket : public DummyIOLayerMethods { }; const std::string name_; - Mode mode_; + SSLProtocolVariant variant_; std::weak_ptr peer_; std::queue input_; std::shared_ptr filter_; diff --git a/gtests/ssl_gtest/tls_agent.cc b/gtests/ssl_gtest/tls_agent.cc index 03d31f972e..a53cf8868b 100644 --- a/gtests/ssl_gtest/tls_agent.cc +++ b/gtests/ssl_gtest/tls_agent.cc @@ -43,12 +43,13 @@ const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa"; const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa"; const std::string TlsAgent::kServerDsa = "dsa"; -TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode) +TlsAgent::TlsAgent(const std::string& name, Role role, + SSLProtocolVariant variant) : name_(name), - mode_(mode), + variant_(variant), role_(role), server_key_bits_(0), - adapter_(new DummyPrSocket(role_str(), mode)), + adapter_(new DummyPrSocket(role_str(), variant)), ssl_fd_(nullptr), state_(STATE_INIT), timer_handle_(nullptr), @@ -76,8 +77,7 @@ TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode) skip_version_checks_(false) { memset(&info_, 0, sizeof(info_)); memset(&csinfo_, 0, sizeof(csinfo_)); - SECStatus rv = SSL_VersionRangeGetDefault( - mode_ == STREAM ? ssl_variant_stream : ssl_variant_datagram, &vrange_); + SECStatus rv = SSL_VersionRangeGetDefault(variant_, &vrange_); EXPECT_EQ(SECSuccess, rv); } @@ -154,7 +154,7 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) { if (!dummy_fd) { return false; } - if (adapter_->mode() == STREAM) { + if (adapter_->variant() == ssl_variant_stream) { ssl_fd_.reset(SSL_ImportFD(modelSocket, dummy_fd.get())); } else { ssl_fd_.reset(DTLS_ImportFD(modelSocket, dummy_fd.get())); @@ -757,7 +757,8 @@ void TlsAgent::Connected() { PRInt32 cipherSuites = SSLInt_CountTls13CipherSpecs(ssl_fd()); // We use one ciphersuite in each direction, plus one that's kept around // by DTLS for retransmission. - PRInt32 expected = ((mode_ == DGRAM) && (role_ == CLIENT)) ? 3 : 2; + PRInt32 expected = + ((variant_ == ssl_variant_datagram) && (role_ == CLIENT)) ? 3 : 2; EXPECT_EQ(expected, cipherSuites); if (expected != cipherSuites) { SSLInt_PrintTls13CipherSpecs(ssl_fd()); @@ -835,7 +836,7 @@ void TlsAgent::Handshake() { int32_t err = PR_GetError(); if (err == PR_WOULD_BLOCK_ERROR) { LOGV("Would have blocked"); - if (mode_ == DGRAM) { + if (variant_ == ssl_variant_datagram) { if (timer_handle_) { timer_handle_->Cancel(); timer_handle_ = nullptr; @@ -986,7 +987,7 @@ void TlsAgentTestBase::TearDown() { void TlsAgentTestBase::Reset(const std::string& server_name) { agent_.reset( new TlsAgent(role_ == TlsAgent::CLIENT ? TlsAgent::kClient : server_name, - role_, mode_)); + role_, variant_)); if (version_) { agent_->SetVersionRange(version_, version_); } @@ -1024,14 +1025,16 @@ void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer, } } -void TlsAgentTestBase::MakeRecord(Mode mode, uint8_t type, uint16_t version, - const uint8_t* buf, size_t len, - DataBuffer* out, uint64_t seq_num) { +void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type, + uint16_t version, const uint8_t* buf, + size_t len, DataBuffer* out, + uint64_t seq_num) { size_t index = 0; index = out->Write(index, type, 1); - index = out->Write( - index, mode == STREAM ? version : TlsVersionToDtlsVersion(version), 2); - if (mode == DGRAM) { + if (variant == ssl_variant_stream) { + index = out->Write(index, version, 2); + } else { + index = out->Write(index, TlsVersionToDtlsVersion(version), 2); index = out->Write(index, seq_num >> 32, 4); index = out->Write(index, seq_num & PR_UINT32_MAX, 4); } @@ -1042,7 +1045,7 @@ void TlsAgentTestBase::MakeRecord(Mode mode, uint8_t type, uint16_t version, void TlsAgentTestBase::MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf, size_t len, DataBuffer* out, uint64_t seq_num) const { - MakeRecord(mode_, type, version, buf, len, out, seq_num); + MakeRecord(variant_, type, version, buf, len, out, seq_num); } void TlsAgentTestBase::MakeHandshakeMessage(uint8_t hs_type, @@ -1061,7 +1064,7 @@ void TlsAgentTestBase::MakeHandshakeMessageFragment( if (!fragment_length) fragment_length = hs_len; index = out->Write(index, hs_type, 1); // Handshake record type. index = out->Write(index, hs_len, 3); // Handshake length - if (mode_ == DGRAM) { + if (variant_ == ssl_variant_datagram) { index = out->Write(index, seq_num, 2); index = out->Write(index, fragment_offset, 3); index = out->Write(index, fragment_length, 3); diff --git a/gtests/ssl_gtest/tls_agent.h b/gtests/ssl_gtest/tls_agent.h index aad401da19..32f6175b66 100644 --- a/gtests/ssl_gtest/tls_agent.h +++ b/gtests/ssl_gtest/tls_agent.h @@ -74,7 +74,7 @@ class TlsAgent : public PollTarget { static const std::string kServerEcdhRsa; static const std::string kServerDsa; - TlsAgent(const std::string& name, Role role, Mode mode); + TlsAgent(const std::string& name, Role role, SSLProtocolVariant variant); virtual ~TlsAgent(); void SetPeer(std::shared_ptr& peer) { @@ -358,7 +358,7 @@ class TlsAgent : public PollTarget { void Connected(); const std::string name_; - Mode mode_; + SSLProtocolVariant variant_; Role role_; uint16_t server_key_bits_; std::shared_ptr adapter_; @@ -401,12 +401,13 @@ class TlsAgentTestBase : public ::testing::Test { public: static ::testing::internal::ParamGenerator kTlsRolesAll; - TlsAgentTestBase(TlsAgent::Role role, Mode mode, uint16_t version = 0) + TlsAgentTestBase(TlsAgent::Role role, SSLProtocolVariant variant, + uint16_t version = 0) : agent_(nullptr), role_(role), - mode_(mode), + variant_(variant), version_(version), - sink_adapter_(new DummyPrSocket("sink", mode)) {} + sink_adapter_(new DummyPrSocket("sink", variant)) {} virtual ~TlsAgentTestBase() {} void SetUp(); @@ -414,9 +415,9 @@ class TlsAgentTestBase : public ::testing::Test { void ExpectAlert(uint8_t alert); - static void MakeRecord(Mode mode, uint8_t type, uint16_t version, - const uint8_t* buf, size_t len, DataBuffer* out, - uint64_t seq_num = 0); + static void MakeRecord(SSLProtocolVariant variant, uint8_t type, + uint16_t version, const uint8_t* buf, size_t len, + DataBuffer* out, uint64_t seq_num = 0); void MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf, size_t len, DataBuffer* out, uint64_t seq_num = 0) const; void MakeHandshakeMessage(uint8_t hs_type, const uint8_t* data, size_t hs_len, @@ -431,10 +432,6 @@ class TlsAgentTestBase : public ::testing::Test { return str == "CLIENT" ? TlsAgent::CLIENT : TlsAgent::SERVER; } - static inline Mode ToMode(const std::string& str) { - return str == "TLS" ? STREAM : DGRAM; - } - void Init(const std::string& server_name = TlsAgent::kServerRsa); void Reset(const std::string& server_name = TlsAgent::kServerRsa); @@ -445,28 +442,28 @@ class TlsAgentTestBase : public ::testing::Test { std::unique_ptr agent_; TlsAgent::Role role_; - Mode mode_; + SSLProtocolVariant variant_; uint16_t version_; // This adapter is here just to accept packets from this agent. std::shared_ptr sink_adapter_; }; -class TlsAgentTest : public TlsAgentTestBase, - public ::testing::WithParamInterface< - std::tuple> { +class TlsAgentTest + : public TlsAgentTestBase, + public ::testing::WithParamInterface< + std::tuple> { public: TlsAgentTest() : TlsAgentTestBase(ToRole(std::get<0>(GetParam())), - ToMode(std::get<1>(GetParam())), - std::get<2>(GetParam())) {} + std::get<1>(GetParam()), std::get<2>(GetParam())) {} }; -class TlsAgentTestClient - : public TlsAgentTestBase, - public ::testing::WithParamInterface> { +class TlsAgentTestClient : public TlsAgentTestBase, + public ::testing::WithParamInterface< + std::tuple> { public: TlsAgentTestClient() - : TlsAgentTestBase(TlsAgent::CLIENT, ToMode(std::get<0>(GetParam())), + : TlsAgentTestBase(TlsAgent::CLIENT, std::get<0>(GetParam()), std::get<1>(GetParam())) {} }; @@ -474,17 +471,20 @@ class TlsAgentTestClient13 : public TlsAgentTestClient {}; class TlsAgentStreamTestClient : public TlsAgentTestBase { public: - TlsAgentStreamTestClient() : TlsAgentTestBase(TlsAgent::CLIENT, STREAM) {} + TlsAgentStreamTestClient() + : TlsAgentTestBase(TlsAgent::CLIENT, ssl_variant_stream) {} }; class TlsAgentStreamTestServer : public TlsAgentTestBase { public: - TlsAgentStreamTestServer() : TlsAgentTestBase(TlsAgent::SERVER, STREAM) {} + TlsAgentStreamTestServer() + : TlsAgentTestBase(TlsAgent::SERVER, ssl_variant_stream) {} }; class TlsAgentDgramTestClient : public TlsAgentTestBase { public: - TlsAgentDgramTestClient() : TlsAgentTestBase(TlsAgent::CLIENT, DGRAM) {} + TlsAgentDgramTestClient() + : TlsAgentTestBase(TlsAgent::CLIENT, ssl_variant_datagram) {} }; inline bool operator==(const SSLVersionRange& vr1, const SSLVersionRange& vr2) { diff --git a/gtests/ssl_gtest/tls_connect.cc b/gtests/ssl_gtest/tls_connect.cc index 93d48be02b..861d162ae3 100644 --- a/gtests/ssl_gtest/tls_connect.cc +++ b/gtests/ssl_gtest/tls_connect.cc @@ -20,17 +20,20 @@ extern std::string g_working_dir_path; namespace nss_test { -static const std::string kTlsModesStreamArr[] = {"TLS"}; -::testing::internal::ParamGenerator - TlsConnectTestBase::kTlsModesStream = - ::testing::ValuesIn(kTlsModesStreamArr); -static const std::string kTlsModesDatagramArr[] = {"DTLS"}; -::testing::internal::ParamGenerator - TlsConnectTestBase::kTlsModesDatagram = - ::testing::ValuesIn(kTlsModesDatagramArr); -static const std::string kTlsModesAllArr[] = {"TLS", "DTLS"}; -::testing::internal::ParamGenerator - TlsConnectTestBase::kTlsModesAll = ::testing::ValuesIn(kTlsModesAllArr); +static const SSLProtocolVariant kTlsVariantsStreamArr[] = {ssl_variant_stream}; +::testing::internal::ParamGenerator + TlsConnectTestBase::kTlsVariantsStream = + ::testing::ValuesIn(kTlsVariantsStreamArr); +static const SSLProtocolVariant kTlsVariantsDatagramArr[] = { + ssl_variant_datagram}; +::testing::internal::ParamGenerator + TlsConnectTestBase::kTlsVariantsDatagram = + ::testing::ValuesIn(kTlsVariantsDatagramArr); +static const SSLProtocolVariant kTlsVariantsAllArr[] = {ssl_variant_stream, + ssl_variant_datagram}; +::testing::internal::ParamGenerator + TlsConnectTestBase::kTlsVariantsAll = + ::testing::ValuesIn(kTlsVariantsAllArr); static const uint16_t kTlsV10Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0}; ::testing::internal::ParamGenerator TlsConnectTestBase::kTlsV10 = @@ -100,10 +103,11 @@ std::string VersionString(uint16_t version) { } } -TlsConnectTestBase::TlsConnectTestBase(Mode mode, uint16_t version) - : mode_(mode), - client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, mode_)), - server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, mode_)), +TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant, + uint16_t version) + : variant_(variant), + client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_)), + server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_)), client_model_(nullptr), server_model_(nullptr), version_(version), @@ -113,18 +117,15 @@ TlsConnectTestBase::TlsConnectTestBase(Mode mode, uint16_t version) expect_early_data_accepted_(false), skip_version_checks_(false) { std::string v; - if (mode_ == DGRAM && version_ == SSL_LIBRARY_VERSION_TLS_1_1) { + if (variant_ == ssl_variant_datagram && + version_ == SSL_LIBRARY_VERSION_TLS_1_1) { v = "1.0"; } else { v = VersionString(version_); } - std::cerr << "Version: " << mode_ << " " << v << std::endl; + std::cerr << "Version: " << variant_ << " " << v << std::endl; } -TlsConnectTestBase::TlsConnectTestBase(const std::string& mode, - uint16_t version) - : TlsConnectTestBase(TlsConnectTestBase::ToMode(mode), version) {} - TlsConnectTestBase::~TlsConnectTestBase() {} // Check the group of each of the supported groups @@ -208,8 +209,8 @@ void TlsConnectTestBase::Reset() { void TlsConnectTestBase::Reset(const std::string& server_name, const std::string& client_name) { - client_.reset(new TlsAgent(client_name, TlsAgent::CLIENT, mode_)); - server_.reset(new TlsAgent(server_name, TlsAgent::SERVER, mode_)); + client_.reset(new TlsAgent(client_name, TlsAgent::CLIENT, variant_)); + server_.reset(new TlsAgent(server_name, TlsAgent::SERVER, variant_)); if (skip_version_checks_) { client_->SkipVersionChecks(); server_->SkipVersionChecks(); @@ -514,9 +515,9 @@ void TlsConnectTestBase::EnsureModelSockets() { if (!client_model_) { ASSERT_EQ(server_model_, nullptr); client_model_.reset( - new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, mode_)); + new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_)); server_model_.reset( - new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, mode_)); + new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_)); if (skip_version_checks_) { client_model_->SkipVersionChecks(); server_model_->SkipVersionChecks(); diff --git a/gtests/ssl_gtest/tls_connect.h b/gtests/ssl_gtest/tls_connect.h index 3211f20f78..73e8dc81a9 100644 --- a/gtests/ssl_gtest/tls_connect.h +++ b/gtests/ssl_gtest/tls_connect.h @@ -25,9 +25,12 @@ extern std::string VersionString(uint16_t version); // A generic TLS connection test base. class TlsConnectTestBase : public ::testing::Test { public: - static ::testing::internal::ParamGenerator kTlsModesStream; - static ::testing::internal::ParamGenerator kTlsModesDatagram; - static ::testing::internal::ParamGenerator kTlsModesAll; + static ::testing::internal::ParamGenerator + kTlsVariantsStream; + static ::testing::internal::ParamGenerator + kTlsVariantsDatagram; + static ::testing::internal::ParamGenerator + kTlsVariantsAll; static ::testing::internal::ParamGenerator kTlsV10; static ::testing::internal::ParamGenerator kTlsV11; static ::testing::internal::ParamGenerator kTlsV12; @@ -39,8 +42,7 @@ class TlsConnectTestBase : public ::testing::Test { static ::testing::internal::ParamGenerator kTlsV12Plus; static ::testing::internal::ParamGenerator kTlsVAll; - TlsConnectTestBase(Mode mode, uint16_t version); - TlsConnectTestBase(const std::string& mode, uint16_t version); + TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version); virtual ~TlsConnectTestBase(); void SetUp(); @@ -114,7 +116,7 @@ class TlsConnectTestBase : public ::testing::Test { void SkipVersionChecks(); protected: - Mode mode_; + SSLProtocolVariant variant_; std::shared_ptr client_; std::shared_ptr server_; std::unique_ptr client_model_; @@ -130,10 +132,6 @@ class TlsConnectTestBase : public ::testing::Test { const uint8_t alpn_dummy_val_[4] = {0x01, 0x62, 0x01, 0x61}; private: - static inline Mode ToMode(const std::string& str) { - return str == "TLS" ? STREAM : DGRAM; - } - void CheckResumption(SessionResumptionMode expected); void CheckExtendedMasterSecret(); void CheckEarlyDataAccepted(); @@ -159,20 +157,20 @@ class TlsConnectTestBase : public ::testing::Test { // A non-parametrized TLS test base. class TlsConnectTest : public TlsConnectTestBase { public: - TlsConnectTest() : TlsConnectTestBase(STREAM, 0) {} + TlsConnectTest() : TlsConnectTestBase(ssl_variant_stream, 0) {} }; // A non-parametrized DTLS-only test base. class DtlsConnectTest : public TlsConnectTestBase { public: - DtlsConnectTest() : TlsConnectTestBase(DGRAM, 0) {} + DtlsConnectTest() : TlsConnectTestBase(ssl_variant_datagram, 0) {} }; // A TLS-only test base. class TlsConnectStream : public TlsConnectTestBase, public ::testing::WithParamInterface { public: - TlsConnectStream() : TlsConnectTestBase(STREAM, GetParam()) {} + TlsConnectStream() : TlsConnectTestBase(ssl_variant_stream, GetParam()) {} }; // A TLS-only test base for tests before 1.3 @@ -182,30 +180,30 @@ class TlsConnectStreamPre13 : public TlsConnectStream {}; class TlsConnectDatagram : public TlsConnectTestBase, public ::testing::WithParamInterface { public: - TlsConnectDatagram() : TlsConnectTestBase(DGRAM, GetParam()) {} + TlsConnectDatagram() : TlsConnectTestBase(ssl_variant_datagram, GetParam()) {} }; -// A generic test class that can be either STREAM or DGRAM and a single version -// of TLS. This is configured in ssl_loopback_unittest.cc. All uses of this -// should use TEST_P(). -class TlsConnectGeneric - : public TlsConnectTestBase, - public ::testing::WithParamInterface> { +// A generic test class that can be either stream or datagram and a single +// version of TLS. This is configured in ssl_loopback_unittest.cc. +class TlsConnectGeneric : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple> { public: TlsConnectGeneric(); }; // A Pre TLS 1.2 generic test. -class TlsConnectPre12 - : public TlsConnectTestBase, - public ::testing::WithParamInterface> { +class TlsConnectPre12 : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple> { public: TlsConnectPre12(); }; // A TLS 1.2 only generic test. -class TlsConnectTls12 : public TlsConnectTestBase, - public ::testing::WithParamInterface { +class TlsConnectTls12 + : public TlsConnectTestBase, + public ::testing::WithParamInterface { public: TlsConnectTls12(); }; @@ -214,20 +212,21 @@ class TlsConnectTls12 : public TlsConnectTestBase, class TlsConnectStreamTls12 : public TlsConnectTestBase { public: TlsConnectStreamTls12() - : TlsConnectTestBase(STREAM, SSL_LIBRARY_VERSION_TLS_1_2) {} + : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_2) {} }; // A TLS 1.2+ generic test. -class TlsConnectTls12Plus - : public TlsConnectTestBase, - public ::testing::WithParamInterface> { +class TlsConnectTls12Plus : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple> { public: TlsConnectTls12Plus(); }; // A TLS 1.3 only generic test. -class TlsConnectTls13 : public TlsConnectTestBase, - public ::testing::WithParamInterface { +class TlsConnectTls13 + : public TlsConnectTestBase, + public ::testing::WithParamInterface { public: TlsConnectTls13(); }; @@ -236,13 +235,13 @@ class TlsConnectTls13 : public TlsConnectTestBase, class TlsConnectStreamTls13 : public TlsConnectTestBase { public: TlsConnectStreamTls13() - : TlsConnectTestBase(STREAM, SSL_LIBRARY_VERSION_TLS_1_3) {} + : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3) {} }; class TlsConnectDatagram13 : public TlsConnectTestBase { public: TlsConnectDatagram13() - : TlsConnectTestBase(DGRAM, SSL_LIBRARY_VERSION_TLS_1_3) {} + : TlsConnectTestBase(ssl_variant_datagram, SSL_LIBRARY_VERSION_TLS_1_3) {} }; // A variant that is used only with Pre13.