From 461a68111f0f8e36024e59647c0c688dd73d48bd Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Mon, 6 Feb 2017 15:36:42 +1100 Subject: [PATCH] Bug 1336855 - Use shared_ptr for DummyPRSocket, r=franziskus --HG-- extra : amend_source : aa61b67592456ceb4966a9560845abb7d9d27a4b extra : histedit_source : 4443bcadd1c2d69ea22eee3c0c185bc26518a07e --- gtests/common/scoped_ptrs.h | 2 + gtests/ssl_gtest/ssl_hrr_unittest.cc | 1 - gtests/ssl_gtest/test_io.cc | 56 +++++++++++----------------- gtests/ssl_gtest/test_io.h | 37 +++++++++--------- gtests/ssl_gtest/tls_agent.cc | 52 +++++++++++++------------- gtests/ssl_gtest/tls_agent.h | 33 ++++++---------- gtests/ssl_gtest/tls_connect.cc | 7 ---- 7 files changed, 77 insertions(+), 111 deletions(-) diff --git a/gtests/common/scoped_ptrs.h b/gtests/common/scoped_ptrs.h index 2a96ee94a0..4707393ad1 100644 --- a/gtests/common/scoped_ptrs.h +++ b/gtests/common/scoped_ptrs.h @@ -25,6 +25,7 @@ struct ScopedDelete { } void operator()(PK11SlotInfo* slot) { PK11_FreeSlot(slot); } void operator()(PK11SymKey* key) { PK11_FreeSymKey(key); } + void operator()(PRFileDesc* fd) { PR_Close(fd); } void operator()(SECAlgorithmID* id) { SECOID_DestroyAlgorithmID(id, true); } void operator()(SECItem* item) { SECITEM_FreeItem(item, true); } void operator()(SECKEYPublicKey* key) { SECKEY_DestroyPublicKey(key); } @@ -49,6 +50,7 @@ SCOPED(CERTCertList); SCOPED(CERTSubjectPublicKeyInfo); SCOPED(PK11SlotInfo); SCOPED(PK11SymKey); +SCOPED(PRFileDesc); SCOPED(SECAlgorithmID); SCOPED(SECItem); SCOPED(SECKEYPublicKey); diff --git a/gtests/ssl_gtest/ssl_hrr_unittest.cc b/gtests/ssl_gtest/ssl_hrr_unittest.cc index 3930bdb99d..8d3c126e0b 100644 --- a/gtests/ssl_gtest/ssl_hrr_unittest.cc +++ b/gtests/ssl_gtest/ssl_hrr_unittest.cc @@ -172,7 +172,6 @@ TEST_F(TlsConnectTest, Select12AfterHelloRetryRequest) { // Here we replace the TLS server with one that does TLS 1.2 only. // This will happily send the client a TLS 1.2 ServerHello. server_.reset(new TlsAgent(server_->name(), TlsAgent::SERVER, mode_)); - server_->Init(); client_->SetPeer(server_); server_->SetPeer(client_); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, diff --git a/gtests/ssl_gtest/test_io.cc b/gtests/ssl_gtest/test_io.cc index 7ac12eb3f4..ac9d71513b 100644 --- a/gtests/ssl_gtest/test_io.cc +++ b/gtests/ssl_gtest/test_io.cc @@ -52,10 +52,8 @@ class Packet : public DataBuffer { // Implementation of NSPR methods static PRStatus DummyClose(PRFileDesc *f) { - DummyPrSocket *io = reinterpret_cast(f->secret); f->secret = nullptr; f->dtor(f); - delete io; return PR_SUCCESS; } @@ -74,7 +72,7 @@ static int32_t DummyAvailable(PRFileDesc *f) { return -1; } -int64_t DummyAvailable64(PRFileDesc *f) { +static int64_t DummyAvailable64(PRFileDesc *f) { UNIMPLEMENTED(); return -1; } @@ -265,10 +263,7 @@ void DummyPrSocket::SetPacketFilter(std::shared_ptr filter) { } void DummyPrSocket::Reset() { - if (peer_) { - peer_->SetPeer(nullptr); - peer_ = nullptr; - } + peer_.reset(); while (!input_.empty()) { Packet *front = input_.front(); input_.pop(); @@ -296,21 +291,16 @@ static const struct PRIOMethods DummyMethods = { DummyReserved, DummyReserved, DummyReserved, DummyReserved}; -PRFileDesc *DummyPrSocket::CreateFD(const std::string &name, Mode mode) { +PRFileDesc *DummyPrSocket::CreateFD() { if (test_fd_identity == PR_INVALID_IO_LAYER) { test_fd_identity = PR_GetUniqueIdentity("testtransportadapter"); } - PRFileDesc *fd = (PR_CreateIOLayerStub(test_fd_identity, &DummyMethods)); - fd->secret = reinterpret_cast(new DummyPrSocket(name, mode)); - + PRFileDesc *fd = PR_CreateIOLayerStub(test_fd_identity, &DummyMethods); + fd->secret = reinterpret_cast(this); return fd; } -DummyPrSocket *DummyPrSocket::GetAdapter(PRFileDesc *fd) { - return reinterpret_cast(fd->secret); -} - void DummyPrSocket::PacketReceived(const DataBuffer &packet) { input_.push(new Packet(packet)); } @@ -367,7 +357,8 @@ int32_t DummyPrSocket::Recv(void *buf, int32_t buflen) { } int32_t DummyPrSocket::Write(const void *buf, int32_t length) { - if (!peer_ || !writeable_) { + auto peer = peer_.lock(); + if (!peer || !writeable_) { PR_SetError(PR_IO_ERROR, 0); return -1; } @@ -383,14 +374,14 @@ int32_t DummyPrSocket::Write(const void *buf, int32_t length) { case PacketFilter::CHANGE: LOG("Original packet: " << packet); LOG("Filtered packet: " << filtered); - peer_->PacketReceived(filtered); + peer->PacketReceived(filtered); break; case PacketFilter::DROP: LOG("Droppped packet: " << packet); break; case PacketFilter::KEEP: LOGV("Packet: " << packet); - peer_->PacketReceived(packet); + peer->PacketReceived(packet); break; } // libssl can't handle it if this reports something other than the length @@ -419,35 +410,31 @@ Poller::~Poller() { } } -void Poller::Wait(Event event, DummyPrSocket *adapter, PollTarget *target, - PollCallback cb) { - auto it = waiters_.find(adapter); - Waiter *waiter; +void Poller::Wait(Event event, std::shared_ptr &adapter, + PollTarget *target, PollCallback cb) { + assert(event < TIMER_EVENT); + if (event >= TIMER_EVENT) return; + std::unique_ptr waiter; + auto it = waiters_.find(adapter); if (it == waiters_.end()) { - waiter = new Waiter(adapter); + waiter.reset(new Waiter(adapter)); } else { - waiter = it->second; + waiter = std::move(it->second); } - assert(event < TIMER_EVENT); - if (event >= TIMER_EVENT) return; - waiter->targets_[event] = target; waiter->callbacks_[event] = cb; - waiters_[adapter] = waiter; + waiters_[adapter] = std::move(waiter); } -void Poller::Cancel(Event event, DummyPrSocket *adapter) { +void Poller::Cancel(Event event, std::shared_ptr &adapter) { auto it = waiters_.find(adapter); - Waiter *waiter; - if (it == waiters_.end()) { return; } - waiter = it->second; - + auto &waiter = it->second; waiter->targets_[event] = nullptr; waiter->callbacks_[event] = nullptr; @@ -456,7 +443,6 @@ void Poller::Cancel(Event event, DummyPrSocket *adapter) { if (waiter->callbacks_[i]) return; } - delete waiter; waiters_.erase(adapter); } @@ -489,7 +475,7 @@ bool Poller::Poll() { } for (auto it = waiters_.begin(); it != waiters_.end(); ++it) { - Waiter *waiter = it->second; + auto &waiter = it->second; if (waiter->callbacks_[READABLE_EVENT]) { if (waiter->io_->readable()) { diff --git a/gtests/ssl_gtest/test_io.h b/gtests/ssl_gtest/test_io.h index 9e91845fd0..d4a7655691 100644 --- a/gtests/ssl_gtest/test_io.h +++ b/gtests/ssl_gtest/test_io.h @@ -50,14 +50,19 @@ inline std::ostream& operator<<(std::ostream& os, Mode m) { class DummyPrSocket { public: + DummyPrSocket(const std::string& name, Mode mode) + : name_(name), + mode_(mode), + peer_(), + input_(), + filter_(nullptr), + writeable_(true) {} ~DummyPrSocket(); - static PRFileDesc* CreateFD(const std::string& name, - Mode mode); // Returns an FD. - static DummyPrSocket* GetAdapter(PRFileDesc* fd); + PRFileDesc* CreateFD(); - DummyPrSocket* peer() const { return peer_; } - void SetPeer(DummyPrSocket* peer) { peer_ = peer; } + std::weak_ptr& peer() { return peer_; } + void SetPeer(const std::shared_ptr& peer) { peer_ = peer; } void SetPacketFilter(std::shared_ptr filter); // Drops peer, packet filter and any outstanding packets. void Reset(); @@ -72,17 +77,9 @@ class DummyPrSocket { bool readable() const { return !input_.empty(); } private: - DummyPrSocket(const std::string& name, Mode mode) - : name_(name), - mode_(mode), - peer_(nullptr), - input_(), - filter_(nullptr), - writeable_(true) {} - const std::string name_; Mode mode_; - DummyPrSocket* peer_; + std::weak_ptr peer_; std::queue input_; std::shared_ptr filter_; bool writeable_; @@ -111,9 +108,9 @@ class Poller { PollCallback callback_; }; - void Wait(Event event, DummyPrSocket* adapter, PollTarget* target, - PollCallback cb); - void Cancel(Event event, DummyPrSocket* adapter); + void Wait(Event event, std::shared_ptr& adapter, + PollTarget* target, PollCallback cb); + void Cancel(Event event, std::shared_ptr& adapter); void SetTimer(uint32_t timer_ms, PollTarget* target, PollCallback cb, Timer** handle); bool Poll(); @@ -124,13 +121,13 @@ class Poller { class Waiter { public: - Waiter(DummyPrSocket* io) : io_(io) { + Waiter(std::shared_ptr io) : io_(io) { memset(&callbacks_[0], 0, sizeof(callbacks_)); } void WaitFor(Event event, PollCallback callback); - DummyPrSocket* io_; + std::shared_ptr io_; PollTarget* targets_[TIMER_EVENT]; PollCallback callbacks_[TIMER_EVENT]; }; @@ -143,7 +140,7 @@ class Poller { }; static Poller* instance; - std::map waiters_; + std::map, std::unique_ptr> waiters_; std::priority_queue, TimerComparator> timers_; }; diff --git a/gtests/ssl_gtest/tls_agent.cc b/gtests/ssl_gtest/tls_agent.cc index c2c596d57b..912b5fec4a 100644 --- a/gtests/ssl_gtest/tls_agent.cc +++ b/gtests/ssl_gtest/tls_agent.cc @@ -46,11 +46,10 @@ const std::string TlsAgent::kServerDsa = "dsa"; TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode) : name_(name), mode_(mode), + role_(role), server_key_bits_(0), - pr_fd_(nullptr), - adapter_(nullptr), + adapter_(new DummyPrSocket(role_str(), mode)), ssl_fd_(nullptr), - role_(role), state_(STATE_INIT), timer_handle_(nullptr), falsestart_enabled_(false), @@ -78,16 +77,12 @@ TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode) } TlsAgent::~TlsAgent() { - if (adapter_) { - Poller::Instance()->Cancel(READABLE_EVENT, adapter_); - // The adapter is closed when the FD closes. - } if (timer_handle_) { timer_handle_->Cancel(); } - if (pr_fd_) { - PR_Close(pr_fd_); + if (adapter_) { + Poller::Instance()->Cancel(READABLE_EVENT, adapter_); } if (ssl_fd_) { @@ -143,15 +138,22 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) { // Don't set up twice if (ssl_fd_) return true; + ScopedPRFileDesc dummy_fd(adapter_->CreateFD()); + EXPECT_NE(nullptr, dummy_fd); + if (!dummy_fd) { + return false; + } if (adapter_->mode() == STREAM) { - ssl_fd_ = SSL_ImportFD(modelSocket, pr_fd_); + ssl_fd_ = SSL_ImportFD(modelSocket, dummy_fd.get()); } else { - ssl_fd_ = DTLS_ImportFD(modelSocket, pr_fd_); + ssl_fd_ = DTLS_ImportFD(modelSocket, dummy_fd.get()); } EXPECT_NE(nullptr, ssl_fd_); - if (!ssl_fd_) return false; - pr_fd_ = nullptr; + if (!ssl_fd_) { + return false; + } + dummy_fd.release(); // Now subsumed by ssl_fd_. SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_); EXPECT_EQ(SECSuccess, rv); @@ -795,7 +797,12 @@ void TlsAgent::StartRenegotiate() { void TlsAgent::SendDirect(const DataBuffer& buf) { LOG("Send Direct " << buf); - adapter_->peer()->PacketReceived(buf); + auto peer = adapter_->peer().lock(); + if (peer) { + peer->PacketReceived(buf); + } else { + LOG("Send Direct peer absent"); + } } static bool ErrorIsNonFatal(PRErrorCode code) { @@ -894,29 +901,22 @@ void TlsAgentTestBase::SetUp() { } void TlsAgentTestBase::TearDown() { - delete agent_; + agent_ = nullptr; SSL_ClearSessionCache(); SSL_ShutdownServerSessionIDCache(); } void TlsAgentTestBase::Reset(const std::string& server_name) { - delete agent_; - Init(server_name); -} - -void TlsAgentTestBase::Init(const std::string& server_name) { - agent_ = + agent_.reset( new TlsAgent(role_ == TlsAgent::CLIENT ? TlsAgent::kClient : server_name, - role_, mode_); - agent_->Init(); - fd_ = DummyPrSocket::CreateFD(agent_->role_str(), mode_); - agent_->adapter()->SetPeer(DummyPrSocket::GetAdapter(fd_)); + role_, mode_)); + agent_->adapter()->SetPeer(sink_adapter_); agent_->StartConnect(); } void TlsAgentTestBase::EnsureInit() { if (!agent_) { - Init(); + Reset(); } const std::vector groups = { ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, diff --git a/gtests/ssl_gtest/tls_agent.h b/gtests/ssl_gtest/tls_agent.h index f44e9dc17f..5274dac5f0 100644 --- a/gtests/ssl_gtest/tls_agent.h +++ b/gtests/ssl_gtest/tls_agent.h @@ -77,16 +77,6 @@ class TlsAgent : public PollTarget { TlsAgent(const std::string& name, Role role, Mode mode); virtual ~TlsAgent(); - bool Init() { - pr_fd_ = DummyPrSocket::CreateFD(role_str(), mode_); - if (!pr_fd_) return false; - - adapter_ = DummyPrSocket::GetAdapter(pr_fd_); - if (!adapter_) return false; - - return true; - } - void SetPeer(std::shared_ptr& peer) { adapter_->SetPeer(peer->adapter_); } @@ -189,7 +179,7 @@ class TlsAgent : public PollTarget { static const char* state_str(State state) { return states[state]; } PRFileDesc* ssl_fd() const { return ssl_fd_; } - DummyPrSocket* adapter() { return adapter_; } + std::shared_ptr& adapter() { return adapter_; } bool is_compressed() const { return info_.compressionMethod != ssl_compression_null; @@ -352,11 +342,10 @@ class TlsAgent : public PollTarget { const std::string name_; Mode mode_; + Role role_; uint16_t server_key_bits_; - PRFileDesc* pr_fd_; - DummyPrSocket* adapter_; + std::shared_ptr adapter_; PRFileDesc* ssl_fd_; - Role role_; State state_; Poller::Timer* timer_handle_; bool falsestart_enabled_; @@ -391,12 +380,11 @@ class TlsAgentTestBase : public ::testing::Test { static ::testing::internal::ParamGenerator kTlsRolesAll; TlsAgentTestBase(TlsAgent::Role role, Mode mode) - : agent_(nullptr), fd_(nullptr), role_(role), mode_(mode) {} - ~TlsAgentTestBase() { - if (fd_) { - PR_Close(fd_); - } - } + : agent_(nullptr), + role_(role), + mode_(mode), + sink_adapter_(new DummyPrSocket("sink", mode)) {} + virtual ~TlsAgentTestBase() {} void SetUp(); void TearDown(); @@ -430,10 +418,11 @@ class TlsAgentTestBase : public ::testing::Test { void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state, int32_t error_code = 0); - TlsAgent* agent_; - PRFileDesc* fd_; + std::unique_ptr agent_; TlsAgent::Role role_; Mode mode_; + // This adapter is here just to accept packets from this agent. + std::shared_ptr sink_adapter_; }; class TlsAgentTest : public TlsAgentTestBase, diff --git a/gtests/ssl_gtest/tls_connect.cc b/gtests/ssl_gtest/tls_connect.cc index 8813fb72e5..f6269a0af1 100644 --- a/gtests/ssl_gtest/tls_connect.cc +++ b/gtests/ssl_gtest/tls_connect.cc @@ -188,9 +188,6 @@ void TlsConnectTestBase::TearDown() { } void TlsConnectTestBase::Init() { - EXPECT_TRUE(client_->Init()); - EXPECT_TRUE(server_->Init()); - client_->SetPeer(server_); server_->SetPeer(client_); @@ -498,10 +495,6 @@ void TlsConnectTestBase::EnsureModelSockets() { server_model_.reset( new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, mode_)); } - - // Initialise agents. - ASSERT_TRUE(client_model_->Init()); - ASSERT_TRUE(server_model_->Init()); } void TlsConnectTestBase::CheckAlpn(const std::string& val) {