Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Bug 1336855 - Use shared_ptr for DummyPRSocket, r=franziskus
--HG--
extra : amend_source : aa61b67592456ceb4966a9560845abb7d9d27a4b
extra : histedit_source : 4443bcadd1c2d69ea22eee3c0c185bc26518a07e
  • Loading branch information
martinthomson committed Feb 6, 2017
1 parent e5302b4 commit 461a681
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 111 deletions.
2 changes: 2 additions & 0 deletions gtests/common/scoped_ptrs.h
Expand Up @@ -25,6 +25,7 @@ struct ScopedDelete {
}
void operator()(PK11SlotInfo* slot) { PK11_FreeSlot(slot); }
void operator()(PK11SymKey* key) { PK11_FreeSymKey(key); }
void operator()(PRFileDesc* fd) { PR_Close(fd); }
void operator()(SECAlgorithmID* id) { SECOID_DestroyAlgorithmID(id, true); }
void operator()(SECItem* item) { SECITEM_FreeItem(item, true); }
void operator()(SECKEYPublicKey* key) { SECKEY_DestroyPublicKey(key); }
Expand All @@ -49,6 +50,7 @@ SCOPED(CERTCertList);
SCOPED(CERTSubjectPublicKeyInfo);
SCOPED(PK11SlotInfo);
SCOPED(PK11SymKey);
SCOPED(PRFileDesc);
SCOPED(SECAlgorithmID);
SCOPED(SECItem);
SCOPED(SECKEYPublicKey);
Expand Down
1 change: 0 additions & 1 deletion gtests/ssl_gtest/ssl_hrr_unittest.cc
Expand Up @@ -172,7 +172,6 @@ TEST_F(TlsConnectTest, Select12AfterHelloRetryRequest) {
// Here we replace the TLS server with one that does TLS 1.2 only.
// This will happily send the client a TLS 1.2 ServerHello.
server_.reset(new TlsAgent(server_->name(), TlsAgent::SERVER, mode_));
server_->Init();
client_->SetPeer(server_);
server_->SetPeer(client_);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
Expand Down
56 changes: 21 additions & 35 deletions gtests/ssl_gtest/test_io.cc
Expand Up @@ -52,10 +52,8 @@ class Packet : public DataBuffer {

// Implementation of NSPR methods
static PRStatus DummyClose(PRFileDesc *f) {
DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
f->secret = nullptr;
f->dtor(f);
delete io;
return PR_SUCCESS;
}

Expand All @@ -74,7 +72,7 @@ static int32_t DummyAvailable(PRFileDesc *f) {
return -1;
}

int64_t DummyAvailable64(PRFileDesc *f) {
static int64_t DummyAvailable64(PRFileDesc *f) {
UNIMPLEMENTED();
return -1;
}
Expand Down Expand Up @@ -265,10 +263,7 @@ void DummyPrSocket::SetPacketFilter(std::shared_ptr<PacketFilter> filter) {
}

void DummyPrSocket::Reset() {
if (peer_) {
peer_->SetPeer(nullptr);
peer_ = nullptr;
}
peer_.reset();
while (!input_.empty()) {
Packet *front = input_.front();
input_.pop();
Expand Down Expand Up @@ -296,21 +291,16 @@ static const struct PRIOMethods DummyMethods = {
DummyReserved, DummyReserved,
DummyReserved, DummyReserved};

PRFileDesc *DummyPrSocket::CreateFD(const std::string &name, Mode mode) {
PRFileDesc *DummyPrSocket::CreateFD() {
if (test_fd_identity == PR_INVALID_IO_LAYER) {
test_fd_identity = PR_GetUniqueIdentity("testtransportadapter");
}

PRFileDesc *fd = (PR_CreateIOLayerStub(test_fd_identity, &DummyMethods));
fd->secret = reinterpret_cast<PRFilePrivate *>(new DummyPrSocket(name, mode));

PRFileDesc *fd = PR_CreateIOLayerStub(test_fd_identity, &DummyMethods);
fd->secret = reinterpret_cast<PRFilePrivate *>(this);
return fd;
}

DummyPrSocket *DummyPrSocket::GetAdapter(PRFileDesc *fd) {
return reinterpret_cast<DummyPrSocket *>(fd->secret);
}

void DummyPrSocket::PacketReceived(const DataBuffer &packet) {
input_.push(new Packet(packet));
}
Expand Down Expand Up @@ -367,7 +357,8 @@ int32_t DummyPrSocket::Recv(void *buf, int32_t buflen) {
}

int32_t DummyPrSocket::Write(const void *buf, int32_t length) {
if (!peer_ || !writeable_) {
auto peer = peer_.lock();
if (!peer || !writeable_) {
PR_SetError(PR_IO_ERROR, 0);
return -1;
}
Expand All @@ -383,14 +374,14 @@ int32_t DummyPrSocket::Write(const void *buf, int32_t length) {
case PacketFilter::CHANGE:
LOG("Original packet: " << packet);
LOG("Filtered packet: " << filtered);
peer_->PacketReceived(filtered);
peer->PacketReceived(filtered);
break;
case PacketFilter::DROP:
LOG("Droppped packet: " << packet);
break;
case PacketFilter::KEEP:
LOGV("Packet: " << packet);
peer_->PacketReceived(packet);
peer->PacketReceived(packet);
break;
}
// libssl can't handle it if this reports something other than the length
Expand Down Expand Up @@ -419,35 +410,31 @@ Poller::~Poller() {
}
}

void Poller::Wait(Event event, DummyPrSocket *adapter, PollTarget *target,
PollCallback cb) {
auto it = waiters_.find(adapter);
Waiter *waiter;
void Poller::Wait(Event event, std::shared_ptr<DummyPrSocket> &adapter,
PollTarget *target, PollCallback cb) {
assert(event < TIMER_EVENT);
if (event >= TIMER_EVENT) return;

std::unique_ptr<Waiter> waiter;
auto it = waiters_.find(adapter);
if (it == waiters_.end()) {
waiter = new Waiter(adapter);
waiter.reset(new Waiter(adapter));
} else {
waiter = it->second;
waiter = std::move(it->second);
}

assert(event < TIMER_EVENT);
if (event >= TIMER_EVENT) return;

waiter->targets_[event] = target;
waiter->callbacks_[event] = cb;
waiters_[adapter] = waiter;
waiters_[adapter] = std::move(waiter);
}

void Poller::Cancel(Event event, DummyPrSocket *adapter) {
void Poller::Cancel(Event event, std::shared_ptr<DummyPrSocket> &adapter) {
auto it = waiters_.find(adapter);
Waiter *waiter;

if (it == waiters_.end()) {
return;
}

waiter = it->second;

auto &waiter = it->second;
waiter->targets_[event] = nullptr;
waiter->callbacks_[event] = nullptr;

Expand All @@ -456,7 +443,6 @@ void Poller::Cancel(Event event, DummyPrSocket *adapter) {
if (waiter->callbacks_[i]) return;
}

delete waiter;
waiters_.erase(adapter);
}

Expand Down Expand Up @@ -489,7 +475,7 @@ bool Poller::Poll() {
}

for (auto it = waiters_.begin(); it != waiters_.end(); ++it) {
Waiter *waiter = it->second;
auto &waiter = it->second;

if (waiter->callbacks_[READABLE_EVENT]) {
if (waiter->io_->readable()) {
Expand Down
37 changes: 17 additions & 20 deletions gtests/ssl_gtest/test_io.h
Expand Up @@ -50,14 +50,19 @@ inline std::ostream& operator<<(std::ostream& os, Mode m) {

class DummyPrSocket {
public:
DummyPrSocket(const std::string& name, Mode mode)
: name_(name),
mode_(mode),
peer_(),
input_(),
filter_(nullptr),
writeable_(true) {}
~DummyPrSocket();

static PRFileDesc* CreateFD(const std::string& name,
Mode mode); // Returns an FD.
static DummyPrSocket* GetAdapter(PRFileDesc* fd);
PRFileDesc* CreateFD();

DummyPrSocket* peer() const { return peer_; }
void SetPeer(DummyPrSocket* peer) { peer_ = peer; }
std::weak_ptr<DummyPrSocket>& peer() { return peer_; }
void SetPeer(const std::shared_ptr<DummyPrSocket>& peer) { peer_ = peer; }
void SetPacketFilter(std::shared_ptr<PacketFilter> filter);
// Drops peer, packet filter and any outstanding packets.
void Reset();
Expand All @@ -72,17 +77,9 @@ class DummyPrSocket {
bool readable() const { return !input_.empty(); }

private:
DummyPrSocket(const std::string& name, Mode mode)
: name_(name),
mode_(mode),
peer_(nullptr),
input_(),
filter_(nullptr),
writeable_(true) {}

const std::string name_;
Mode mode_;
DummyPrSocket* peer_;
std::weak_ptr<DummyPrSocket> peer_;
std::queue<Packet*> input_;
std::shared_ptr<PacketFilter> filter_;
bool writeable_;
Expand Down Expand Up @@ -111,9 +108,9 @@ class Poller {
PollCallback callback_;
};

void Wait(Event event, DummyPrSocket* adapter, PollTarget* target,
PollCallback cb);
void Cancel(Event event, DummyPrSocket* adapter);
void Wait(Event event, std::shared_ptr<DummyPrSocket>& adapter,
PollTarget* target, PollCallback cb);
void Cancel(Event event, std::shared_ptr<DummyPrSocket>& adapter);
void SetTimer(uint32_t timer_ms, PollTarget* target, PollCallback cb,
Timer** handle);
bool Poll();
Expand All @@ -124,13 +121,13 @@ class Poller {

class Waiter {
public:
Waiter(DummyPrSocket* io) : io_(io) {
Waiter(std::shared_ptr<DummyPrSocket> io) : io_(io) {
memset(&callbacks_[0], 0, sizeof(callbacks_));
}

void WaitFor(Event event, PollCallback callback);

DummyPrSocket* io_;
std::shared_ptr<DummyPrSocket> io_;
PollTarget* targets_[TIMER_EVENT];
PollCallback callbacks_[TIMER_EVENT];
};
Expand All @@ -143,7 +140,7 @@ class Poller {
};

static Poller* instance;
std::map<DummyPrSocket*, Waiter*> waiters_;
std::map<std::shared_ptr<DummyPrSocket>, std::unique_ptr<Waiter>> waiters_;
std::priority_queue<Timer*, std::vector<Timer*>, TimerComparator> timers_;
};

Expand Down
52 changes: 26 additions & 26 deletions gtests/ssl_gtest/tls_agent.cc
Expand Up @@ -46,11 +46,10 @@ const std::string TlsAgent::kServerDsa = "dsa";
TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode)
: name_(name),
mode_(mode),
role_(role),
server_key_bits_(0),
pr_fd_(nullptr),
adapter_(nullptr),
adapter_(new DummyPrSocket(role_str(), mode)),
ssl_fd_(nullptr),
role_(role),
state_(STATE_INIT),
timer_handle_(nullptr),
falsestart_enabled_(false),
Expand Down Expand Up @@ -78,16 +77,12 @@ TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode)
}

TlsAgent::~TlsAgent() {
if (adapter_) {
Poller::Instance()->Cancel(READABLE_EVENT, adapter_);
// The adapter is closed when the FD closes.
}
if (timer_handle_) {
timer_handle_->Cancel();
}

if (pr_fd_) {
PR_Close(pr_fd_);
if (adapter_) {
Poller::Instance()->Cancel(READABLE_EVENT, adapter_);
}

if (ssl_fd_) {
Expand Down Expand Up @@ -143,15 +138,22 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) {
// Don't set up twice
if (ssl_fd_) return true;

ScopedPRFileDesc dummy_fd(adapter_->CreateFD());
EXPECT_NE(nullptr, dummy_fd);
if (!dummy_fd) {
return false;
}
if (adapter_->mode() == STREAM) {
ssl_fd_ = SSL_ImportFD(modelSocket, pr_fd_);
ssl_fd_ = SSL_ImportFD(modelSocket, dummy_fd.get());
} else {
ssl_fd_ = DTLS_ImportFD(modelSocket, pr_fd_);
ssl_fd_ = DTLS_ImportFD(modelSocket, dummy_fd.get());
}

EXPECT_NE(nullptr, ssl_fd_);
if (!ssl_fd_) return false;
pr_fd_ = nullptr;
if (!ssl_fd_) {
return false;
}
dummy_fd.release(); // Now subsumed by ssl_fd_.

SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_);
EXPECT_EQ(SECSuccess, rv);
Expand Down Expand Up @@ -795,7 +797,12 @@ void TlsAgent::StartRenegotiate() {

void TlsAgent::SendDirect(const DataBuffer& buf) {
LOG("Send Direct " << buf);
adapter_->peer()->PacketReceived(buf);
auto peer = adapter_->peer().lock();
if (peer) {
peer->PacketReceived(buf);
} else {
LOG("Send Direct peer absent");
}
}

static bool ErrorIsNonFatal(PRErrorCode code) {
Expand Down Expand Up @@ -894,29 +901,22 @@ void TlsAgentTestBase::SetUp() {
}

void TlsAgentTestBase::TearDown() {
delete agent_;
agent_ = nullptr;
SSL_ClearSessionCache();
SSL_ShutdownServerSessionIDCache();
}

void TlsAgentTestBase::Reset(const std::string& server_name) {
delete agent_;
Init(server_name);
}

void TlsAgentTestBase::Init(const std::string& server_name) {
agent_ =
agent_.reset(
new TlsAgent(role_ == TlsAgent::CLIENT ? TlsAgent::kClient : server_name,
role_, mode_);
agent_->Init();
fd_ = DummyPrSocket::CreateFD(agent_->role_str(), mode_);
agent_->adapter()->SetPeer(DummyPrSocket::GetAdapter(fd_));
role_, mode_));
agent_->adapter()->SetPeer(sink_adapter_);
agent_->StartConnect();
}

void TlsAgentTestBase::EnsureInit() {
if (!agent_) {
Init();
Reset();
}
const std::vector<SSLNamedGroup> groups = {
ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
Expand Down

0 comments on commit 461a681

Please sign in to comment.