Commit 461a6811 authored by Martin Thomson's avatar Martin Thomson

Bug 1336855 - Use shared_ptr for DummyPRSocket, r=franziskus

--HG--
extra : amend_source : aa61b67592456ceb4966a9560845abb7d9d27a4b
extra : histedit_source : 4443bcadd1c2d69ea22eee3c0c185bc26518a07e
parent e5302b40
......@@ -25,6 +25,7 @@ struct ScopedDelete {
}
void operator()(PK11SlotInfo* slot) { PK11_FreeSlot(slot); }
void operator()(PK11SymKey* key) { PK11_FreeSymKey(key); }
void operator()(PRFileDesc* fd) { PR_Close(fd); }
void operator()(SECAlgorithmID* id) { SECOID_DestroyAlgorithmID(id, true); }
void operator()(SECItem* item) { SECITEM_FreeItem(item, true); }
void operator()(SECKEYPublicKey* key) { SECKEY_DestroyPublicKey(key); }
......@@ -49,6 +50,7 @@ SCOPED(CERTCertList);
SCOPED(CERTSubjectPublicKeyInfo);
SCOPED(PK11SlotInfo);
SCOPED(PK11SymKey);
SCOPED(PRFileDesc);
SCOPED(SECAlgorithmID);
SCOPED(SECItem);
SCOPED(SECKEYPublicKey);
......
......@@ -172,7 +172,6 @@ TEST_F(TlsConnectTest, Select12AfterHelloRetryRequest) {
// Here we replace the TLS server with one that does TLS 1.2 only.
// This will happily send the client a TLS 1.2 ServerHello.
server_.reset(new TlsAgent(server_->name(), TlsAgent::SERVER, mode_));
server_->Init();
client_->SetPeer(server_);
server_->SetPeer(client_);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
......
......@@ -52,10 +52,8 @@ class Packet : public DataBuffer {
// Implementation of NSPR methods
static PRStatus DummyClose(PRFileDesc *f) {
DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
f->secret = nullptr;
f->dtor(f);
delete io;
return PR_SUCCESS;
}
......@@ -74,7 +72,7 @@ static int32_t DummyAvailable(PRFileDesc *f) {
return -1;
}
int64_t DummyAvailable64(PRFileDesc *f) {
static int64_t DummyAvailable64(PRFileDesc *f) {
UNIMPLEMENTED();
return -1;
}
......@@ -265,10 +263,7 @@ void DummyPrSocket::SetPacketFilter(std::shared_ptr<PacketFilter> filter) {
}
void DummyPrSocket::Reset() {
if (peer_) {
peer_->SetPeer(nullptr);
peer_ = nullptr;
}
peer_.reset();
while (!input_.empty()) {
Packet *front = input_.front();
input_.pop();
......@@ -296,21 +291,16 @@ static const struct PRIOMethods DummyMethods = {
DummyReserved, DummyReserved,
DummyReserved, DummyReserved};
PRFileDesc *DummyPrSocket::CreateFD(const std::string &name, Mode mode) {
PRFileDesc *DummyPrSocket::CreateFD() {
if (test_fd_identity == PR_INVALID_IO_LAYER) {
test_fd_identity = PR_GetUniqueIdentity("testtransportadapter");
}
PRFileDesc *fd = (PR_CreateIOLayerStub(test_fd_identity, &DummyMethods));
fd->secret = reinterpret_cast<PRFilePrivate *>(new DummyPrSocket(name, mode));
PRFileDesc *fd = PR_CreateIOLayerStub(test_fd_identity, &DummyMethods);
fd->secret = reinterpret_cast<PRFilePrivate *>(this);
return fd;
}
DummyPrSocket *DummyPrSocket::GetAdapter(PRFileDesc *fd) {
return reinterpret_cast<DummyPrSocket *>(fd->secret);
}
void DummyPrSocket::PacketReceived(const DataBuffer &packet) {
input_.push(new Packet(packet));
}
......@@ -367,7 +357,8 @@ int32_t DummyPrSocket::Recv(void *buf, int32_t buflen) {
}
int32_t DummyPrSocket::Write(const void *buf, int32_t length) {
if (!peer_ || !writeable_) {
auto peer = peer_.lock();
if (!peer || !writeable_) {
PR_SetError(PR_IO_ERROR, 0);
return -1;
}
......@@ -383,14 +374,14 @@ int32_t DummyPrSocket::Write(const void *buf, int32_t length) {
case PacketFilter::CHANGE:
LOG("Original packet: " << packet);
LOG("Filtered packet: " << filtered);
peer_->PacketReceived(filtered);
peer->PacketReceived(filtered);
break;
case PacketFilter::DROP:
LOG("Droppped packet: " << packet);
break;
case PacketFilter::KEEP:
LOGV("Packet: " << packet);
peer_->PacketReceived(packet);
peer->PacketReceived(packet);
break;
}
// libssl can't handle it if this reports something other than the length
......@@ -419,35 +410,31 @@ Poller::~Poller() {
}
}
void Poller::Wait(Event event, DummyPrSocket *adapter, PollTarget *target,
PollCallback cb) {
auto it = waiters_.find(adapter);
Waiter *waiter;
void Poller::Wait(Event event, std::shared_ptr<DummyPrSocket> &adapter,
PollTarget *target, PollCallback cb) {
assert(event < TIMER_EVENT);
if (event >= TIMER_EVENT) return;
std::unique_ptr<Waiter> waiter;
auto it = waiters_.find(adapter);
if (it == waiters_.end()) {
waiter = new Waiter(adapter);
waiter.reset(new Waiter(adapter));
} else {
waiter = it->second;
waiter = std::move(it->second);
}
assert(event < TIMER_EVENT);
if (event >= TIMER_EVENT) return;
waiter->targets_[event] = target;
waiter->callbacks_[event] = cb;
waiters_[adapter] = waiter;
waiters_[adapter] = std::move(waiter);
}
void Poller::Cancel(Event event, DummyPrSocket *adapter) {
void Poller::Cancel(Event event, std::shared_ptr<DummyPrSocket> &adapter) {
auto it = waiters_.find(adapter);
Waiter *waiter;
if (it == waiters_.end()) {
return;
}
waiter = it->second;
auto &waiter = it->second;
waiter->targets_[event] = nullptr;
waiter->callbacks_[event] = nullptr;
......@@ -456,7 +443,6 @@ void Poller::Cancel(Event event, DummyPrSocket *adapter) {
if (waiter->callbacks_[i]) return;
}
delete waiter;
waiters_.erase(adapter);
}
......@@ -489,7 +475,7 @@ bool Poller::Poll() {
}
for (auto it = waiters_.begin(); it != waiters_.end(); ++it) {
Waiter *waiter = it->second;
auto &waiter = it->second;
if (waiter->callbacks_[READABLE_EVENT]) {
if (waiter->io_->readable()) {
......
......@@ -50,14 +50,19 @@ inline std::ostream& operator<<(std::ostream& os, Mode m) {
class DummyPrSocket {
public:
DummyPrSocket(const std::string& name, Mode mode)
: name_(name),
mode_(mode),
peer_(),
input_(),
filter_(nullptr),
writeable_(true) {}
~DummyPrSocket();
static PRFileDesc* CreateFD(const std::string& name,
Mode mode); // Returns an FD.
static DummyPrSocket* GetAdapter(PRFileDesc* fd);
PRFileDesc* CreateFD();
DummyPrSocket* peer() const { return peer_; }
void SetPeer(DummyPrSocket* peer) { peer_ = peer; }
std::weak_ptr<DummyPrSocket>& peer() { return peer_; }
void SetPeer(const std::shared_ptr<DummyPrSocket>& peer) { peer_ = peer; }
void SetPacketFilter(std::shared_ptr<PacketFilter> filter);
// Drops peer, packet filter and any outstanding packets.
void Reset();
......@@ -72,17 +77,9 @@ class DummyPrSocket {
bool readable() const { return !input_.empty(); }
private:
DummyPrSocket(const std::string& name, Mode mode)
: name_(name),
mode_(mode),
peer_(nullptr),
input_(),
filter_(nullptr),
writeable_(true) {}
const std::string name_;
Mode mode_;
DummyPrSocket* peer_;
std::weak_ptr<DummyPrSocket> peer_;
std::queue<Packet*> input_;
std::shared_ptr<PacketFilter> filter_;
bool writeable_;
......@@ -111,9 +108,9 @@ class Poller {
PollCallback callback_;
};
void Wait(Event event, DummyPrSocket* adapter, PollTarget* target,
PollCallback cb);
void Cancel(Event event, DummyPrSocket* adapter);
void Wait(Event event, std::shared_ptr<DummyPrSocket>& adapter,
PollTarget* target, PollCallback cb);
void Cancel(Event event, std::shared_ptr<DummyPrSocket>& adapter);
void SetTimer(uint32_t timer_ms, PollTarget* target, PollCallback cb,
Timer** handle);
bool Poll();
......@@ -124,13 +121,13 @@ class Poller {
class Waiter {
public:
Waiter(DummyPrSocket* io) : io_(io) {
Waiter(std::shared_ptr<DummyPrSocket> io) : io_(io) {
memset(&callbacks_[0], 0, sizeof(callbacks_));
}
void WaitFor(Event event, PollCallback callback);
DummyPrSocket* io_;
std::shared_ptr<DummyPrSocket> io_;
PollTarget* targets_[TIMER_EVENT];
PollCallback callbacks_[TIMER_EVENT];
};
......@@ -143,7 +140,7 @@ class Poller {
};
static Poller* instance;
std::map<DummyPrSocket*, Waiter*> waiters_;
std::map<std::shared_ptr<DummyPrSocket>, std::unique_ptr<Waiter>> waiters_;
std::priority_queue<Timer*, std::vector<Timer*>, TimerComparator> timers_;
};
......
......@@ -46,11 +46,10 @@ const std::string TlsAgent::kServerDsa = "dsa";
TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode)
: name_(name),
mode_(mode),
role_(role),
server_key_bits_(0),
pr_fd_(nullptr),
adapter_(nullptr),
adapter_(new DummyPrSocket(role_str(), mode)),
ssl_fd_(nullptr),
role_(role),
state_(STATE_INIT),
timer_handle_(nullptr),
falsestart_enabled_(false),
......@@ -78,16 +77,12 @@ TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode)
}
TlsAgent::~TlsAgent() {
if (adapter_) {
Poller::Instance()->Cancel(READABLE_EVENT, adapter_);
// The adapter is closed when the FD closes.
}
if (timer_handle_) {
timer_handle_->Cancel();
}
if (pr_fd_) {
PR_Close(pr_fd_);
if (adapter_) {
Poller::Instance()->Cancel(READABLE_EVENT, adapter_);
}
if (ssl_fd_) {
......@@ -143,15 +138,22 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) {
// Don't set up twice
if (ssl_fd_) return true;
ScopedPRFileDesc dummy_fd(adapter_->CreateFD());
EXPECT_NE(nullptr, dummy_fd);
if (!dummy_fd) {
return false;
}
if (adapter_->mode() == STREAM) {
ssl_fd_ = SSL_ImportFD(modelSocket, pr_fd_);
ssl_fd_ = SSL_ImportFD(modelSocket, dummy_fd.get());
} else {
ssl_fd_ = DTLS_ImportFD(modelSocket, pr_fd_);
ssl_fd_ = DTLS_ImportFD(modelSocket, dummy_fd.get());
}
EXPECT_NE(nullptr, ssl_fd_);
if (!ssl_fd_) return false;
pr_fd_ = nullptr;
if (!ssl_fd_) {
return false;
}
dummy_fd.release(); // Now subsumed by ssl_fd_.
SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_);
EXPECT_EQ(SECSuccess, rv);
......@@ -795,7 +797,12 @@ void TlsAgent::StartRenegotiate() {
void TlsAgent::SendDirect(const DataBuffer& buf) {
LOG("Send Direct " << buf);
adapter_->peer()->PacketReceived(buf);
auto peer = adapter_->peer().lock();
if (peer) {
peer->PacketReceived(buf);
} else {
LOG("Send Direct peer absent");
}
}
static bool ErrorIsNonFatal(PRErrorCode code) {
......@@ -894,29 +901,22 @@ void TlsAgentTestBase::SetUp() {
}
void TlsAgentTestBase::TearDown() {
delete agent_;
agent_ = nullptr;
SSL_ClearSessionCache();
SSL_ShutdownServerSessionIDCache();
}
void TlsAgentTestBase::Reset(const std::string& server_name) {
delete agent_;
Init(server_name);
}
void TlsAgentTestBase::Init(const std::string& server_name) {
agent_ =
agent_.reset(
new TlsAgent(role_ == TlsAgent::CLIENT ? TlsAgent::kClient : server_name,
role_, mode_);
agent_->Init();
fd_ = DummyPrSocket::CreateFD(agent_->role_str(), mode_);
agent_->adapter()->SetPeer(DummyPrSocket::GetAdapter(fd_));
role_, mode_));
agent_->adapter()->SetPeer(sink_adapter_);
agent_->StartConnect();
}
void TlsAgentTestBase::EnsureInit() {
if (!agent_) {
Init();
Reset();
}
const std::vector<SSLNamedGroup> groups = {
ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
......
......@@ -77,16 +77,6 @@ class TlsAgent : public PollTarget {
TlsAgent(const std::string& name, Role role, Mode mode);
virtual ~TlsAgent();
bool Init() {
pr_fd_ = DummyPrSocket::CreateFD(role_str(), mode_);
if (!pr_fd_) return false;
adapter_ = DummyPrSocket::GetAdapter(pr_fd_);
if (!adapter_) return false;
return true;
}
void SetPeer(std::shared_ptr<TlsAgent>& peer) {
adapter_->SetPeer(peer->adapter_);
}
......@@ -189,7 +179,7 @@ class TlsAgent : public PollTarget {
static const char* state_str(State state) { return states[state]; }
PRFileDesc* ssl_fd() const { return ssl_fd_; }
DummyPrSocket* adapter() { return adapter_; }
std::shared_ptr<DummyPrSocket>& adapter() { return adapter_; }
bool is_compressed() const {
return info_.compressionMethod != ssl_compression_null;
......@@ -352,11 +342,10 @@ class TlsAgent : public PollTarget {
const std::string name_;
Mode mode_;
Role role_;
uint16_t server_key_bits_;
PRFileDesc* pr_fd_;
DummyPrSocket* adapter_;
std::shared_ptr<DummyPrSocket> adapter_;
PRFileDesc* ssl_fd_;
Role role_;
State state_;
Poller::Timer* timer_handle_;
bool falsestart_enabled_;
......@@ -391,12 +380,11 @@ class TlsAgentTestBase : public ::testing::Test {
static ::testing::internal::ParamGenerator<std::string> kTlsRolesAll;
TlsAgentTestBase(TlsAgent::Role role, Mode mode)
: agent_(nullptr), fd_(nullptr), role_(role), mode_(mode) {}
~TlsAgentTestBase() {
if (fd_) {
PR_Close(fd_);
}
}
: agent_(nullptr),
role_(role),
mode_(mode),
sink_adapter_(new DummyPrSocket("sink", mode)) {}
virtual ~TlsAgentTestBase() {}
void SetUp();
void TearDown();
......@@ -430,10 +418,11 @@ class TlsAgentTestBase : public ::testing::Test {
void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state,
int32_t error_code = 0);
TlsAgent* agent_;
PRFileDesc* fd_;
std::unique_ptr<TlsAgent> agent_;
TlsAgent::Role role_;
Mode mode_;
// This adapter is here just to accept packets from this agent.
std::shared_ptr<DummyPrSocket> sink_adapter_;
};
class TlsAgentTest : public TlsAgentTestBase,
......
......@@ -188,9 +188,6 @@ void TlsConnectTestBase::TearDown() {
}
void TlsConnectTestBase::Init() {
EXPECT_TRUE(client_->Init());
EXPECT_TRUE(server_->Init());
client_->SetPeer(server_);
server_->SetPeer(client_);
......@@ -498,10 +495,6 @@ void TlsConnectTestBase::EnsureModelSockets() {
server_model_.reset(
new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, mode_));
}
// Initialise agents.
ASSERT_TRUE(client_model_->Init());
ASSERT_TRUE(server_model_->Init());
}
void TlsConnectTestBase::CheckAlpn(const std::string& val) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment