/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */

#include "secerr.h"
#include "ssl.h"
#include "sslerr.h"
#include "sslproto.h"

extern "C" {
// This is not something that should make you happy.
#include "libssl_internals.h"
}

#include <deque>
#include <iostream>
#include <memory>
#include <string>
#include <utility>

#include "gtest_utils.h"
#include "nss_scoped_ptrs.h"
#include "tls_connect.h"
#include "tls_filter.h"
#include "tls_parser.h"

namespace nss_test {

// Installs an SSL secret callback on |agent| and checks, per direction, that
// the reported epochs advance 1 (0-RTT) -> 2 (handshake) -> 3 (application
// data), starting from the epoch given to the constructor.
class HandshakeSecretTracker {
 public:
  // |first_read_epoch| and |first_write_epoch| are the epochs expected in the
  // first callback for the read and write directions respectively.
  HandshakeSecretTracker(const std::shared_ptr<TlsAgent>& agent,
                         uint16_t first_read_epoch, uint16_t first_write_epoch)
      : agent_(agent),
        next_read_epoch_(first_read_epoch),
        next_write_epoch_(first_write_epoch) {
    EXPECT_EQ(SECSuccess,
              SSL_SecretCallback(agent_->ssl_fd(),
                                 HandshakeSecretTracker::SecretCb, this));
  }

  // |this| is registered with NSS, so copying would leave a stale pointer.
  HandshakeSecretTracker(const HandshakeSecretTracker&) = delete;
  HandshakeSecretTracker& operator=(const HandshakeSecretTracker&) = delete;

  // After epoch 3 has been seen in a direction, that direction's expected
  // epoch is set to the sentinel 0; both directions must have got there.
  void CheckComplete() const {
    EXPECT_EQ(0, next_read_epoch_);
    EXPECT_EQ(0, next_write_epoch_);
  }

 private:
  static void SecretCb(PRFileDesc* fd, PRUint16 epoch, SSLSecretDirection dir,
                       PK11SymKey* secret, void* arg) {
    // |arg| is the |this| pointer passed to SSL_SecretCallback above.
    auto* tracker = static_cast<HandshakeSecretTracker*>(arg);
    tracker->SecretUpdated(epoch, dir, secret);
  }

  void SecretUpdated(PRUint16 epoch, SSLSecretDirection dir,
                     PK11SymKey* secret) {
    if (g_ssl_gtest_verbose) {
      std::cerr << agent_->role_str() << ": secret callback for " << dir
                << " epoch " << epoch << std::endl;
    }

    EXPECT_TRUE(secret);
    if (dir != ssl_secret_read) {
      ASSERT_EQ(ssl_secret_write, dir);
    }
    uint16_t& next_epoch =
        (dir == ssl_secret_read) ? next_read_epoch_ : next_write_epoch_;
    EXPECT_EQ(next_epoch, epoch);
    switch (next_epoch) {
      case 1:  // 1 == 0-RTT, next should be handshake.
      case 2:  // 2 == handshake, next should be application data.
        next_epoch++;
        break;

      case 3:  // 3 == application data, there should be no more.
        // Use 0 as a sentinel value.
        next_epoch = 0;
        break;

      default:
        ADD_FAILURE() << "Unexpected next epoch: " << next_epoch;
    }
  }

  std::shared_ptr<TlsAgent> agent_;
  uint16_t next_read_epoch_;
  uint16_t next_write_epoch_;
};

TEST_F(TlsConnectTest, HandshakeSecrets) {
  ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
  EnsureTlsSetup();

  HandshakeSecretTracker c(client_, 2, 2);
  HandshakeSecretTracker s(server_, 2, 2);

  Connect();
  SendReceive();

  c.CheckComplete();
  s.CheckComplete();
}

TEST_F(TlsConnectTest, ZeroRttSecrets) {
  SetupForZeroRtt();

  HandshakeSecretTracker c(client_, 2, 1);
  HandshakeSecretTracker s(server_, 1, 2);

  client_->Set0RttEnabled(true);
  server_->Set0RttEnabled(true);
  ExpectResumption(RESUME_TICKET);
  ZeroRttSendReceive(true, true);
  Handshake();
  ExpectEarlyDataAccepted(true);
  CheckConnected();
  SendReceive();

  c.CheckComplete();
  s.CheckComplete();
}

// Installs an SSL secret callback on |agent| and checks that every secret it
// reports is for epoch 4, in the direction selected by |expect_read_secret|.
class KeyUpdateTracker {
 public:
  KeyUpdateTracker(const std::shared_ptr<TlsAgent>& agent,
                   bool expect_read_secret)
      : agent_(agent), expect_read_secret_(expect_read_secret), called_(false) {
    EXPECT_EQ(SECSuccess, SSL_SecretCallback(agent_->ssl_fd(),
                                             KeyUpdateTracker::SecretCb, this));
  }

  // |this| is registered with NSS, so copying would leave a stale pointer.
  KeyUpdateTracker(const KeyUpdateTracker&) = delete;
  KeyUpdateTracker& operator=(const KeyUpdateTracker&) = delete;

  void CheckCalled() const { EXPECT_TRUE(called_); }

 private:
  static void SecretCb(PRFileDesc* fd, PRUint16 epoch, SSLSecretDirection dir,
                       PK11SymKey* secret, void* arg) {
    // |arg| is the |this| pointer passed to SSL_SecretCallback above.
    auto* tracker = static_cast<KeyUpdateTracker*>(arg);
    tracker->SecretUpdated(epoch, dir, secret);
  }

  void SecretUpdated(PRUint16 epoch, SSLSecretDirection dir,
                     PK11SymKey* secret) {
    EXPECT_EQ(4U, epoch);
    EXPECT_EQ(expect_read_secret_, dir == ssl_secret_read);
    EXPECT_TRUE(secret);
    called_ = true;
  }

  std::shared_ptr<TlsAgent> agent_;
  bool expect_read_secret_;
  bool called_;
};

TEST_F(TlsConnectTest, KeyUpdateSecrets) {
  ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
  Connect();
  // The update is to the client write secret; the server read secret.
  KeyUpdateTracker c(client_, false);
  KeyUpdateTracker s(server_, true);
  EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE));
  SendReceive(50);
  SendReceive(60);
  CheckEpochs(4, 3);
  c.CheckCalled();
  s.CheckCalled();
}

// BadPrSocket is an instance of a PR IO layer that crashes the test if it is
// ever used for reading or writing. It does that by failing to overwrite any
// of the DummyIOLayerMethods, which all crash when invoked.
class BadPrSocket : public DummyIOLayerMethods {
 public:
  explicit BadPrSocket(const std::shared_ptr<TlsAgent>& agent)
      : DummyIOLayerMethods() {
    static PRDescIdentity bad_identity = PR_GetUniqueIdentity("bad NSPR id");
    fd_ = DummyIOLayerMethods::CreateFD(bad_identity, this);

    // This is terrible, but NSPR doesn't provide an easy way to replace the
    // bottom layer of an IO stack. Take the DummyPrSocket and replace its
    // NSPR method vtable with the ones from this object.
    dummy_layer_ =
        PR_GetIdentitiesLayer(agent->ssl_fd(), DummyPrSocket::LayerId());
    EXPECT_TRUE(dummy_layer_);
    if (!dummy_layer_) {
      // EXPECT_TRUE has already recorded the failure; don't dereference null.
      return;
    }
    original_methods_ = dummy_layer_->methods;
    original_secret_ = dummy_layer_->secret;
    dummy_layer_->methods = fd_->methods;
    dummy_layer_->secret = reinterpret_cast<PRFilePrivate*>(this);
  }

  // Copying would restore the original layer state twice.
  BadPrSocket(const BadPrSocket&) = delete;
  BadPrSocket& operator=(const BadPrSocket&) = delete;

  // This will be destroyed before the agent, so we need to restore the state
  // before we tampered with it.
  virtual ~BadPrSocket() {
    if (!dummy_layer_) {
      return;  // Nothing was replaced in the constructor.
    }
    dummy_layer_->methods = original_methods_;
    dummy_layer_->secret = original_secret_;
  }

 private:
  ScopedPRFileDesc fd_;
  PRFileDesc* dummy_layer_ = nullptr;
  const PRIOMethods* original_methods_ = nullptr;
  PRFilePrivate* original_secret_ = nullptr;
};

// Captures every record that |agent| hands to its record-layer write callback
// so a test can later deliver them to a peer with SSL_RecordLayerData.
class StagedRecords {
 public:
  explicit StagedRecords(const std::shared_ptr<TlsAgent>& agent)
      : agent_(agent), records_() {
    EXPECT_EQ(SECSuccess,
              SSL_RecordLayerWriteCallback(
                  agent_->ssl_fd(), StagedRecords::StageRecordData, this));
  }

  // |this| is registered with NSS, so copying would leave a stale pointer.
  StagedRecords(const StagedRecords&) = delete;
  StagedRecords& operator=(const StagedRecords&) = delete;

  virtual ~StagedRecords() {
    // Uninstall so that the callback doesn't fire during cleanup.
    EXPECT_EQ(SECSuccess, SSL_RecordLayerWriteCallback(agent_->ssl_fd(),
                                                       nullptr, nullptr));
  }

  bool empty() const { return records_.empty(); }

  // Delivers all staged records to |peer| in the order they were staged, then
  // discards them.
  void ForwardAll(const std::shared_ptr<TlsAgent>& peer) {
    EXPECT_NE(agent_, peer) << "can't forward to self";
    for (auto& record : records_) {  // By reference: avoid copying buffers.
      record.Forward(peer);
    }
    records_.clear();
  }

  // This forwards all saved data and checks the resulting state.
  void ForwardAll(const std::shared_ptr<TlsAgent>& peer,
                  TlsAgent::State expected_state) {
    ForwardAll(peer);
    switch (expected_state) {
      case TlsAgent::STATE_CONNECTED:
        // The handshake callback should have been called, so check that before
        // checking that SSL_ForceHandshake succeeds.
        EXPECT_EQ(expected_state, peer->state());
        EXPECT_EQ(SECSuccess, SSL_ForceHandshake(peer->ssl_fd()));
        break;

      case TlsAgent::STATE_CONNECTING:
        // Check that SSL_ForceHandshake() blocks.
        EXPECT_EQ(SECFailure, SSL_ForceHandshake(peer->ssl_fd()));
        EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
        // Update and check the state.
        peer->Handshake();
        EXPECT_EQ(TlsAgent::STATE_CONNECTING, peer->state());
        break;

      default:
        ADD_FAILURE() << "No idea how to handle this state";
    }
  }

  // Forwards everything except the final byte of the last staged record (see
  // StagedRecord::SliceTail); that byte stays staged for a later ForwardAll.
  void ForwardPartial(const std::shared_ptr<TlsAgent>& peer) {
    if (records_.empty()) {
      ADD_FAILURE() << "No records to slice";
      return;
    }
    StagedRecord tail = records_.back().SliceTail();
    ForwardAll(peer, TlsAgent::STATE_CONNECTING);
    records_.push_back(std::move(tail));
    EXPECT_EQ(TlsAgent::STATE_CONNECTING, peer->state());
  }

 private:
  // A single record.
  class StagedRecord {
   public:
    // |role| is only used for verbose logging.
    StagedRecord(const std::string& role, uint16_t epoch, SSLContentType ct,
                 const uint8_t* data, size_t len)
        : role_(role), epoch_(epoch), content_type_(ct), data_(data, len) {
      if (g_ssl_gtest_verbose) {
        std::cerr << role_ << ": staged epoch " << epoch_ << " "
                  << content_type_ << ": " << data_ << std::endl;
      }
    }

    // This forwards staged data to the identified agent.
    void Forward(const std::shared_ptr<TlsAgent>& peer) {
      // Now there should be staged data.
      EXPECT_FALSE(data_.empty());
      if (g_ssl_gtest_verbose) {
        std::cerr << role_ << ": forward " << data_ << std::endl;
      }
      EXPECT_EQ(SECSuccess,
                SSL_RecordLayerData(peer->ssl_fd(), epoch_, content_type_,
                                    data_.data(),
                                    static_cast<unsigned int>(data_.len())));
    }

    // Slices the tail off this record and returns it.
    StagedRecord SliceTail() {
      size_t slice = 1;
      if (data_.len() <= slice) {
        ADD_FAILURE() << "record too small to slice in two";
        slice = 0;
      }
      size_t keep = data_.len() - slice;
      StagedRecord tail(role_, epoch_, content_type_, data_.data() + keep,
                        slice);
      data_.Truncate(keep);
      return tail;
    }

   private:
    std::string role_;
    uint16_t epoch_;
    SSLContentType content_type_;
    DataBuffer data_;
  };

  // This is an SSLRecordWriteCallback that stages data.
  static SECStatus StageRecordData(PRFileDesc* fd, PRUint16 epoch,
                                   SSLContentType content_type,
                                   const PRUint8* data, unsigned int len,
                                   void* arg) {
    // |arg| is the |this| pointer passed to SSL_RecordLayerWriteCallback.
    auto* stage = static_cast<StagedRecords*>(arg);
    stage->records_.emplace_back(stage->agent_->role_str(), epoch,
                                 content_type, data,
                                 static_cast<size_t>(len));
    return SECSuccess;
  }

  // Held by value (like the secret trackers) so it cannot dangle.
  std::shared_ptr<TlsAgent> agent_;
  std::deque<StagedRecord> records_;
};

// Attempting to feed application data in before the handshake is complete
// should be caught.
static void RefuseApplicationData(const std::shared_ptr<TlsAgent>& peer,
                                  uint16_t epoch) {
  static const uint8_t d[] = {1, 2, 3};
  EXPECT_EQ(SECFailure,
            SSL_RecordLayerData(peer->ssl_fd(), epoch, ssl_ct_application_data,
                                d, static_cast<unsigned int>(sizeof(d))));
  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
}

// Has |sender| send a small amount of application data, forwards the records
// captured by |sender_stage| to |receiver|, and reads the data there.
static void SendForwardReceive(const std::shared_ptr<TlsAgent>& sender,
                               StagedRecords& sender_stage,
                               const std::shared_ptr<TlsAgent>& receiver) {
  const size_t count = 10;
  sender->SendData(count, count);
  sender_stage.ForwardAll(receiver);
  receiver->ReadBytes(count);
}

TEST_P(TlsConnectStream, ReplaceRecordLayer) {
  StartConnect();
  client_->SetServerKeyBits(server_->server_key_bits());

  // BadPrSocket installs an IO layer that crashes when the SSL layer attempts
  // to read or write.
  BadPrSocket bad_layer_client(client_);
  BadPrSocket bad_layer_server(server_);

  // StagedRecords installs a handler for unprotected data from the socket, and
  // captures that data.
  StagedRecords client_stage(client_);
  StagedRecords server_stage(server_);

  // Both peers should refuse application data from epoch 0.
  RefuseApplicationData(client_, 0);
  RefuseApplicationData(server_, 0);

  // This first call forwards nothing, but it causes the client to handshake,
  // which starts things off. This stages the ClientHello as a result.
  server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
  // This processes the ClientHello and stages the first server flight.
  client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING);
  RefuseApplicationData(server_, 1);

  if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
    // Process the server flight and the client is done.
    server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED);
    client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
  } else {
    server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
    RefuseApplicationData(client_, 1);
    client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
    server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED);
  }
  CheckKeys();

  // Reading and writing application data should work.
  SendForwardReceive(client_, client_stage, server_);
  SendForwardReceive(server_, server_stage, client_);
}

// Certificate authentication callback that always defers the decision.
static SECStatus AuthCompleteBlock(TlsAgent*, PRBool, PRBool) {
  return SECWouldBlock;
}

TEST_P(TlsConnectStream, ReplaceRecordLayerAsyncLateAuth) {
  StartConnect();
  client_->SetServerKeyBits(server_->server_key_bits());

  BadPrSocket bad_layer_client(client_);
  BadPrSocket bad_layer_server(server_);
  StagedRecords client_stage(client_);
  StagedRecords server_stage(server_);

  client_->SetAuthCertificateCallback(AuthCompleteBlock);

  server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
  client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING);
  server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);

  // Prior to TLS 1.3, the client sends its second flight immediately. But in
  // TLS 1.3, a client won't send a Finished until it is happy with the server
  // certificate. So blocking certificate validation causes the client to send
  // nothing.
  if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
    ASSERT_TRUE(client_stage.empty());

    // Client should have stopped reading when it saw the Certificate message,
    // so it will be reading handshake epoch, and writing cleartext.
    client_->CheckEpochs(2, 0);
    // Server should be reading handshake, and writing application data.
    server_->CheckEpochs(2, 3);

    // Handshake again and the client will read the remainder of the server's
    // flight, but it will remain blocked.
    client_->Handshake();
    ASSERT_TRUE(client_stage.empty());
    EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
  } else {
    // In prior versions, the client's second flight is always sent.
    ASSERT_FALSE(client_stage.empty());
  }

  // Now declare the certificate good.
  EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
  client_->Handshake();
  ASSERT_FALSE(client_stage.empty());

  if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
    EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
    client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
  } else {
    client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
    server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED);
  }
  CheckKeys();

  // Reading and writing application data should work.
  SendForwardReceive(client_, client_stage, server_);
}

TEST_F(TlsConnectStreamTls13, ReplaceRecordLayerAsyncPostHandshake) {
  StartConnect();
  client_->SetServerKeyBits(server_->server_key_bits());

  BadPrSocket bad_layer_client(client_);
  BadPrSocket bad_layer_server(server_);
  StagedRecords client_stage(client_);
  StagedRecords server_stage(server_);

  client_->SetAuthCertificateCallback(AuthCompleteBlock);

  server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
  client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING);
  server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);

  ASSERT_TRUE(client_stage.empty());
  client_->Handshake();
  ASSERT_TRUE(client_stage.empty());
  EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());

  // Now declare the certificate good.
  EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
  client_->Handshake();
  ASSERT_FALSE(client_stage.empty());

  // This fixture is TLS 1.3 only, so the client finishes here and a single
  // forward completes the server as well.
  EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
  client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
  CheckKeys();

  // Reading and writing application data should work.
  SendForwardReceive(client_, client_stage, server_);

  // Post-handshake messages should work here.
  EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), nullptr, 0));
  SendForwardReceive(server_, server_stage, client_);
}

// This test ensures that data is correctly forwarded when the handshake is
// resumed after asynchronous server certificate authentication, when
// SSL_AuthCertificateComplete() is called. The logic for resuming the
// handshake involves a different code path than the usual one, so this test
// exercises that code fully.
TEST_F(TlsConnectStreamTls13, ReplaceRecordLayerAsyncEarlyAuth) {
  StartConnect();
  client_->SetServerKeyBits(server_->server_key_bits());

  BadPrSocket bad_layer_client(client_);
  BadPrSocket bad_layer_server(server_);
  StagedRecords client_stage(client_);
  StagedRecords server_stage(server_);

  client_->SetAuthCertificateCallback(AuthCompleteBlock);

  server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
  client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING);

  // Send a partial flight on to the client.
  // This includes enough to trigger the certificate callback.
  server_stage.ForwardPartial(client_);
  EXPECT_TRUE(client_stage.empty());

  // Declare the certificate good.
  EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
  client_->Handshake();
  EXPECT_TRUE(client_stage.empty());

  // Send the remainder of the server flight.
  PRBool pending = PR_FALSE;
  EXPECT_EQ(SECSuccess,
            SSLInt_HasPendingHandshakeData(client_->ssl_fd(), &pending));
  EXPECT_EQ(PR_TRUE, pending);
  EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
  server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED);
  client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
  CheckKeys();

  SendForwardReceive(server_, server_stage, client_);
}

TEST_P(TlsConnectStream, ForwardDataFromWrongEpoch) {
  const uint8_t data[] = {1};
  Connect();
  uint16_t next_epoch;
  if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
    EXPECT_EQ(SECFailure,
              SSL_RecordLayerData(client_->ssl_fd(), 2, ssl_ct_application_data,
                                  data, static_cast<unsigned int>(sizeof(data))));
    EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError())
        << "Passing data from an old epoch is rejected";
    next_epoch = 4;
  } else {
    // Prior to TLS 1.3, the epoch is only updated once during the handshake.
    next_epoch = 2;
  }
  EXPECT_EQ(SECFailure,
            SSL_RecordLayerData(client_->ssl_fd(), next_epoch,
                                ssl_ct_application_data, data,
                                static_cast<unsigned int>(sizeof(data))));
  EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError())
      << "Passing data from a future epoch blocks";
}

TEST_F(TlsConnectStreamTls13, ForwardInvalidData) {
  const uint8_t data[1] = {0};

  EnsureTlsSetup();
  // Zero-length data.
  EXPECT_EQ(SECFailure, SSL_RecordLayerData(client_->ssl_fd(), 0,
                                            ssl_ct_application_data, data, 0));
  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());

  // NULL data.
  EXPECT_EQ(SECFailure,
            SSL_RecordLayerData(client_->ssl_fd(), 0, ssl_ct_application_data,
                                nullptr, 1));
  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
}

TEST_F(TlsConnectDatagram13, ForwardDataDtls) {
  EnsureTlsSetup();
  const uint8_t data[1] = {0};
  EXPECT_EQ(SECFailure,
            SSL_RecordLayerData(client_->ssl_fd(), 0, ssl_ct_application_data,
                                data, static_cast<unsigned int>(sizeof(data))));
  EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
}

}  // namespace nss_test