Skip to content

Commit

Permalink
Bug 1336851 - Use shared_ptr for packet filters, r=franziskus
Browse files Browse the repository at this point in the history
--HG--
extra : histedit_source : d4e7e31aaa2a5003ddf953a8498c604de5bd1975
  • Loading branch information
martinthomson committed Feb 5, 2017
1 parent c3e914a commit d4d4b45
Show file tree
Hide file tree
Showing 26 changed files with 330 additions and 298 deletions.
4 changes: 2 additions & 2 deletions gtests/ssl_gtest/ssl_agent_unittest.cc
Expand Up @@ -158,8 +158,8 @@ TEST_F(TlsAgentStreamTestClient, Set0RttOptionThenWrite) {
SSL_LIBRARY_VERSION_TLS_1_3);
agent_->StartConnect();
agent_->Set0RttEnabled(true);
auto filter =
new TlsInspectorRecordHandshakeMessage(kTlsHandshakeClientHello);
auto filter = std::make_shared<TlsInspectorRecordHandshakeMessage>(
kTlsHandshakeClientHello);
agent_->SetPacketFilter(filter);
PRInt32 rv = PR_Write(agent_->ssl_fd(), k0RttData, strlen(k0RttData));
EXPECT_EQ(-1, rv);
Expand Down
45 changes: 25 additions & 20 deletions gtests/ssl_gtest/ssl_auth_unittest.cc
Expand Up @@ -77,9 +77,9 @@ TEST_P(TlsConnectGeneric, ClientAuthBigRsa) {
}

// Offset is the position in the captured buffer where the signature sits.
static void CheckSigScheme(TlsInspectorRecordHandshakeMessage* capture,
size_t offset, TlsAgent* peer,
uint16_t expected_scheme, size_t expected_size) {
static void CheckSigScheme(
std::shared_ptr<TlsInspectorRecordHandshakeMessage>& capture, size_t offset,
TlsAgent* peer, uint16_t expected_scheme, size_t expected_size) {
EXPECT_LT(offset + 2U, capture->buffer().len());

uint32_t scheme = 0;
Expand All @@ -95,8 +95,8 @@ static void CheckSigScheme(TlsInspectorRecordHandshakeMessage* capture,
// in the default certificate.
TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) {
EnsureTlsSetup();
auto capture_ske =
new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
auto capture_ske = std::make_shared<TlsInspectorRecordHandshakeMessage>(
kTlsHandshakeServerKeyExchange);
server_->SetPacketFilter(capture_ske);
Connect();
CheckKeys();
Expand All @@ -114,7 +114,8 @@ TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) {
TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) {
EnsureTlsSetup();
auto capture_cert_verify =
new TlsInspectorRecordHandshakeMessage(kTlsHandshakeCertificateVerify);
std::make_shared<TlsInspectorRecordHandshakeMessage>(
kTlsHandshakeCertificateVerify);
client_->SetPacketFilter(capture_cert_verify);
client_->SetupClientAuth();
server_->RequestClientAuth(true);
Expand All @@ -127,7 +128,8 @@ TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) {
TEST_P(TlsConnectTls12, ClientAuthBigRsaCheckSigAlg) {
Reset(TlsAgent::kServerRsa, TlsAgent::kRsa2048);
auto capture_cert_verify =
new TlsInspectorRecordHandshakeMessage(kTlsHandshakeCertificateVerify);
std::make_shared<TlsInspectorRecordHandshakeMessage>(
kTlsHandshakeCertificateVerify);
client_->SetPacketFilter(capture_cert_verify);
client_->SetupClientAuth();
server_->RequestClientAuth(true);
Expand Down Expand Up @@ -186,10 +188,11 @@ class TlsZeroCertificateRequestSigAlgsFilter : public TlsHandshakeFilter {
// supported_signature_algorithms in the CertificateRequest message.
TEST_P(TlsConnectTls12, ClientAuthNoSigAlgsFallback) {
EnsureTlsSetup();
auto filter = new TlsZeroCertificateRequestSigAlgsFilter();
auto filter = std::make_shared<TlsZeroCertificateRequestSigAlgsFilter>();
server_->SetPacketFilter(filter);
auto capture_cert_verify =
new TlsInspectorRecordHandshakeMessage(kTlsHandshakeCertificateVerify);
std::make_shared<TlsInspectorRecordHandshakeMessage>(
kTlsHandshakeCertificateVerify);
client_->SetPacketFilter(capture_cert_verify);
client_->SetupClientAuth();
server_->RequestClientAuth(true);
Expand Down Expand Up @@ -339,7 +342,7 @@ TEST_P(TlsConnectPre12, SignatureAlgorithmNoOverlapEcdsa) {
// The signature_algorithms extension is mandatory in TLS 1.3.
TEST_P(TlsConnectTls13, SignatureAlgorithmDrop) {
client_->SetPacketFilter(
new TlsExtensionDropper(ssl_signature_algorithms_xtn));
std::make_shared<TlsExtensionDropper>(ssl_signature_algorithms_xtn));
ConnectExpectFail();
client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
server_->CheckErrorCode(SSL_ERROR_MISSING_SIGNATURE_ALGORITHMS_EXTENSION);
Expand All @@ -349,7 +352,7 @@ TEST_P(TlsConnectTls13, SignatureAlgorithmDrop) {
// only fails when the Finished is checked.
TEST_P(TlsConnectTls12, SignatureAlgorithmDrop) {
client_->SetPacketFilter(
new TlsExtensionDropper(ssl_signature_algorithms_xtn));
std::make_shared<TlsExtensionDropper>(ssl_signature_algorithms_xtn));
ConnectExpectFail();
client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
Expand Down Expand Up @@ -490,9 +493,11 @@ static SECStatus AuthCompleteBlock(TlsAgent*, PRBool, PRBool) {
// processed by the client, SSL_AuthCertificateComplete() is called.
TEST_F(TlsConnectDatagram13, AuthCompleteBeforeFinished) {
client_->SetAuthCertificateCallback(AuthCompleteBlock);
server_->SetPacketFilter(new BeforeFinished13(client_, server_, [this]() {
EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
}));
server_->SetPacketFilter(
std::make_shared<BeforeFinished13>(client_, server_, [this]() {
EXPECT_EQ(SECSuccess,
SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
}));
Connect();
}

Expand Down Expand Up @@ -520,7 +525,7 @@ TEST_F(TlsConnectDatagram13, AuthCompleteAfterFinished) {

TEST_P(TlsConnectGenericPre13, ClientWriteBetweenCCSAndFinishedWithFalseStart) {
client_->EnableFalseStart();
server_->SetPacketFilter(new BeforeFinished(
server_->SetPacketFilter(std::make_shared<BeforeFinished>(
client_, server_,
[this]() { EXPECT_TRUE(client_->can_falsestart_hook_called()); },
[this]() {
Expand All @@ -536,7 +541,7 @@ TEST_P(TlsConnectGenericPre13, ClientWriteBetweenCCSAndFinishedWithFalseStart) {
TEST_P(TlsConnectGenericPre13, AuthCompleteBeforeFinishedWithFalseStart) {
client_->EnableFalseStart();
client_->SetAuthCertificateCallback(AuthCompleteBlock);
server_->SetPacketFilter(new BeforeFinished(
server_->SetPacketFilter(std::make_shared<BeforeFinished>(
client_, server_,
[]() {
// Do nothing before CCS
Expand Down Expand Up @@ -583,7 +588,7 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) {
EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());

// The client should send nothing from here on.
client_->SetPacketFilter(new EnforceNoActivity());
client_->SetPacketFilter(std::make_shared<EnforceNoActivity>());
client_->Handshake();
EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());

Expand All @@ -610,7 +615,7 @@ TEST_P(TlsConnectTls13, AuthCompleteDelayed) {
EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());

// The client will send nothing until AuthCertificateComplete is called.
client_->SetPacketFilter(new EnforceNoActivity());
client_->SetPacketFilter(std::make_shared<EnforceNoActivity>());
client_->Handshake();
EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());

Expand Down Expand Up @@ -744,8 +749,8 @@ TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigServer) {

TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigClient) {
Reset(certificate_);
TlsExtensionCapture* capture =
new TlsExtensionCapture(ssl_signature_algorithms_xtn);
auto capture =
std::make_shared<TlsExtensionCapture>(ssl_signature_algorithms_xtn);
client_->SetPacketFilter(capture);
TestSignatureSchemeConfig(client_);

Expand Down
7 changes: 4 additions & 3 deletions gtests/ssl_gtest/ssl_cert_ext_unittest.cc
Expand Up @@ -185,8 +185,8 @@ TEST_P(TlsConnectGenericPre13, OcspMangled) {
server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kOcspExtraData));

static const uint8_t val[] = {1};
auto replacer = new TlsExtensionReplacer(ssl_cert_status_xtn,
DataBuffer(val, sizeof(val)));
auto replacer = std::make_shared<TlsExtensionReplacer>(
ssl_cert_status_xtn, DataBuffer(val, sizeof(val)));
server_->SetPacketFilter(replacer);
ConnectExpectFail();
client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
Expand All @@ -197,7 +197,8 @@ TEST_P(TlsConnectGeneric, OcspSuccess) {
EnsureTlsSetup();
EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
SSL_ENABLE_OCSP_STAPLING, PR_TRUE));
auto capture_ocsp = new TlsExtensionCapture(ssl_cert_status_xtn);
auto capture_ocsp =
std::make_shared<TlsExtensionCapture>(ssl_cert_status_xtn);
server_->SetPacketFilter(capture_ocsp);

// The value should be available during the AuthCertificateCallback
Expand Down
2 changes: 1 addition & 1 deletion gtests/ssl_gtest/ssl_damage_unittest.cc
Expand Up @@ -49,7 +49,7 @@ TEST_F(TlsConnectTest, DamageSecretHandleServerFinished) {
SSL_LIBRARY_VERSION_TLS_1_3);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
SSL_LIBRARY_VERSION_TLS_1_3);
server_->SetPacketFilter(new AfterRecordN(
server_->SetPacketFilter(std::make_shared<AfterRecordN>(
server_, client_,
0, // ServerHello.
[this]() { SSLInt_DamageServerHsTrafficSecret(client_->ssl_fd()); }));
Expand Down
59 changes: 32 additions & 27 deletions gtests/ssl_gtest/ssl_dhe_unittest.cc
Expand Up @@ -31,12 +31,13 @@ TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) {
EnsureTlsSetup();
client_->ConfigNamedGroups(kAllDHEGroups);

auto groups_capture = new TlsExtensionCapture(ssl_supported_groups_xtn);
auto shares_capture = new TlsExtensionCapture(ssl_tls13_key_share_xtn);
std::vector<PacketFilter*> captures;
captures.push_back(groups_capture);
captures.push_back(shares_capture);
client_->SetPacketFilter(new ChainedPacketFilter(captures));
auto groups_capture =
std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn);
auto shares_capture =
std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
std::vector<std::shared_ptr<PacketFilter>> captures = {groups_capture,
shares_capture};
client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures));

Connect();

Expand All @@ -60,12 +61,13 @@ TEST_P(TlsConnectGeneric, ConnectFfdheClient) {
EnableOnlyDheCiphers();
EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
auto groups_capture = new TlsExtensionCapture(ssl_supported_groups_xtn);
auto shares_capture = new TlsExtensionCapture(ssl_tls13_key_share_xtn);
std::vector<PacketFilter*> captures;
captures.push_back(groups_capture);
captures.push_back(shares_capture);
client_->SetPacketFilter(new ChainedPacketFilter(captures));
auto groups_capture =
std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn);
auto shares_capture =
std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
std::vector<std::shared_ptr<PacketFilter>> captures = {groups_capture,
shares_capture};
client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures));

Connect();

Expand Down Expand Up @@ -126,7 +128,7 @@ TEST_P(TlsConnectGenericPre13, DamageServerKeyShare) {
EnableOnlyDheCiphers();
EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
server_->SetPacketFilter(new TlsDheServerKeyExchangeDamager());
server_->SetPacketFilter(std::make_shared<TlsDheServerKeyExchangeDamager>());

ConnectExpectFail();

Expand Down Expand Up @@ -249,8 +251,9 @@ class TlsDheSkeChangeYServer : public TlsDheSkeChangeY {

class TlsDheSkeChangeYClient : public TlsDheSkeChangeY {
public:
TlsDheSkeChangeYClient(ChangeYTo change,
const TlsDheSkeChangeYServer* server_filter)
TlsDheSkeChangeYClient(
ChangeYTo change,
std::shared_ptr<const TlsDheSkeChangeYServer> server_filter)
: TlsDheSkeChangeY(change), server_filter_(server_filter) {}

protected:
Expand All @@ -266,7 +269,7 @@ class TlsDheSkeChangeYClient : public TlsDheSkeChangeY {
}

private:
const TlsDheSkeChangeYServer* server_filter_;
std::shared_ptr<const TlsDheSkeChangeYServer> server_filter_;
};

/* This matrix includes: mode (stream/datagram), TLS version, what change to
Expand All @@ -289,7 +292,8 @@ TEST_P(TlsDamageDHYTest, DamageServerY) {
SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
}
TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam());
server_->SetPacketFilter(new TlsDheSkeChangeYServer(change, true));
server_->SetPacketFilter(
std::make_shared<TlsDheSkeChangeYServer>(change, true));

ConnectExpectFail();
if (change == TlsDheSkeChangeY::kYZeroPad) {
Expand All @@ -314,13 +318,14 @@ TEST_P(TlsDamageDHYTest, DamageClientY) {
SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
}
// The filter on the server is required to capture the prime.
TlsDheSkeChangeYServer* server_filter =
new TlsDheSkeChangeYServer(TlsDheSkeChangeY::kYZero, false);
auto server_filter =
std::make_shared<TlsDheSkeChangeYServer>(TlsDheSkeChangeY::kYZero, false);
server_->SetPacketFilter(server_filter);

// The client filter does the damage.
TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam());
client_->SetPacketFilter(new TlsDheSkeChangeYClient(change, server_filter));
client_->SetPacketFilter(
std::make_shared<TlsDheSkeChangeYClient>(change, server_filter));

ConnectExpectFail();
if (change == TlsDheSkeChangeY::kYZeroPad) {
Expand Down Expand Up @@ -378,7 +383,7 @@ class TlsDheSkeMakePEven : public TlsHandshakeFilter {
// Even without requiring named groups, an even value for p is bad news.
TEST_P(TlsConnectGenericPre13, MakeDhePEven) {
EnableOnlyDheCiphers();
server_->SetPacketFilter(new TlsDheSkeMakePEven());
server_->SetPacketFilter(std::make_shared<TlsDheSkeMakePEven>());

ConnectExpectFail();

Expand Down Expand Up @@ -409,7 +414,7 @@ class TlsDheSkeZeroPadP : public TlsHandshakeFilter {
// Zero padding only causes signature failure.
TEST_P(TlsConnectGenericPre13, PadDheP) {
EnableOnlyDheCiphers();
server_->SetPacketFilter(new TlsDheSkeZeroPadP());
server_->SetPacketFilter(std::make_shared<TlsDheSkeZeroPadP>());

ConnectExpectFail();

Expand Down Expand Up @@ -533,11 +538,11 @@ TEST_P(TlsConnectTls13, ResumeFfdhe) {
Reset();
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
EnableOnlyDheCiphers();
TlsExtensionCapture* clientCapture =
new TlsExtensionCapture(ssl_tls13_pre_shared_key_xtn);
auto clientCapture =
std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
client_->SetPacketFilter(clientCapture);
TlsExtensionCapture* serverCapture =
new TlsExtensionCapture(ssl_tls13_pre_shared_key_xtn);
auto serverCapture =
std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
server_->SetPacketFilter(serverCapture);
ExpectResumption(RESUME_TICKET);
Connect();
Expand Down Expand Up @@ -599,7 +604,7 @@ TEST_P(TlsConnectGenericPre13, InvalidDERSignatureFfdhe) {
const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ffdhe_2048};
client_->ConfigNamedGroups(client_groups);

server_->SetPacketFilter(new TlsDheSkeChangeSignature(
server_->SetPacketFilter(std::make_shared<TlsDheSkeChangeSignature>(
version_, kBogusDheSignature, sizeof(kBogusDheSignature)));

ConnectExpectFail();
Expand Down
16 changes: 8 additions & 8 deletions gtests/ssl_gtest/ssl_drop_unittest.cc
Expand Up @@ -21,13 +21,13 @@ extern "C" {
namespace nss_test {

TEST_P(TlsConnectDatagram, DropClientFirstFlightOnce) {
client_->SetPacketFilter(new SelectiveDropFilter(0x1));
client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x1));
Connect();
SendReceive();
}

TEST_P(TlsConnectDatagram, DropServerFirstFlightOnce) {
server_->SetPacketFilter(new SelectiveDropFilter(0x1));
server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x1));
Connect();
SendReceive();
}
Expand All @@ -36,32 +36,32 @@ TEST_P(TlsConnectDatagram, DropServerFirstFlightOnce) {
// flights that they send. Note: In DTLS 1.3, the shorter handshake means that
// this will also drop some application data, so we can't call SendReceive().
TEST_P(TlsConnectDatagram, DropAllFirstTransmissions) {
client_->SetPacketFilter(new SelectiveDropFilter(0x15));
server_->SetPacketFilter(new SelectiveDropFilter(0x5));
client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x15));
server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x5));
Connect();
}

// This drops the server's first flight three times.
TEST_P(TlsConnectDatagram, DropServerFirstFlightThrice) {
server_->SetPacketFilter(new SelectiveDropFilter(0x7));
server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x7));
Connect();
}

// This drops the client's second flight once
TEST_P(TlsConnectDatagram, DropClientSecondFlightOnce) {
client_->SetPacketFilter(new SelectiveDropFilter(0x2));
client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x2));
Connect();
}

// This drops the client's second flight three times.
TEST_P(TlsConnectDatagram, DropClientSecondFlightThrice) {
client_->SetPacketFilter(new SelectiveDropFilter(0xe));
client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0xe));
Connect();
}

// This drops the server's second flight three times.
TEST_P(TlsConnectDatagram, DropServerSecondFlightThrice) {
server_->SetPacketFilter(new SelectiveDropFilter(0xe));
server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0xe));
Connect();
}

Expand Down

0 comments on commit d4d4b45

Please sign in to comment.