Skip to content

Commit

Permalink
Bug 1471126 - Provide a callback for traffic secrets, r=ekr
Browse files Browse the repository at this point in the history
Summary:
Provide updated secrets to a callback function as soon as those secrets are available.

Reviewers: ekr

Reviewed By: ekr

Bug #: 1471126

Differential Revision: https://phabricator.services.mozilla.com/D1824

--HG--
extra : rebase_source : d562e9b41279b115e92e5b229481df53066e6136
extra : amend_source : 517c3f17f4f7623c203a4ca27ab25824d017c0c4
extra : histedit_source : 1b6e508c139cf560c39ad74bd8c8a7046edb9a9e
  • Loading branch information
martinthomson committed Feb 17, 2019
1 parent d0478ef commit 4a3d54c
Show file tree
Hide file tree
Showing 13 changed files with 327 additions and 75 deletions.
1 change: 1 addition & 0 deletions gtests/ssl_gtest/manifest.mn
Expand Up @@ -36,6 +36,7 @@ CPPSRCS = \
ssl_loopback_unittest.cc \
ssl_misc_unittest.cc \
ssl_record_unittest.cc \
ssl_recordsep_unittest.cc \
ssl_recordsize_unittest.cc \
ssl_resumption_unittest.cc \
ssl_renegotiation_unittest.cc \
Expand Down
1 change: 1 addition & 0 deletions gtests/ssl_gtest/ssl_gtest.gyp
Expand Up @@ -37,6 +37,7 @@
'ssl_loopback_unittest.cc',
'ssl_misc_unittest.cc',
'ssl_record_unittest.cc',
'ssl_recordsep_unittest.cc',
'ssl_recordsize_unittest.cc',
'ssl_resumption_unittest.cc',
'ssl_renegotiation_unittest.cc',
Expand Down
165 changes: 165 additions & 0 deletions gtests/ssl_gtest/ssl_recordsep_unittest.cc
@@ -0,0 +1,165 @@
/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this file,
* You can obtain one at http://mozilla.org/MPL/2.0/. */

#include "secerr.h"
#include "ssl.h"
#include "sslerr.h"
#include "sslproto.h"

extern "C" {
// This is not something that should make you happy.
#include "libssl_internals.h"
}

#include "gtest_utils.h"
#include "nss_scoped_ptrs.h"
#include "tls_connect.h"
#include "tls_filter.h"
#include "tls_parser.h"

namespace nss_test {

class HandshakeSecretTracker {
public:
HandshakeSecretTracker(const std::shared_ptr<TlsAgent>& agent,
uint16_t first_read_epoch, uint16_t first_write_epoch)
: agent_(agent),
next_read_epoch_(first_read_epoch),
next_write_epoch_(first_write_epoch) {
EXPECT_EQ(SECSuccess,
SSL_SecretCallback(agent_->ssl_fd(),
HandshakeSecretTracker::SecretCb, this));
}

void CheckComplete() const {
EXPECT_EQ(0, next_read_epoch_);
EXPECT_EQ(0, next_write_epoch_);
}

private:
static void SecretCb(PRFileDesc* fd, PRUint16 epoch, SSLSecretDirection dir,
PK11SymKey* secret, void* arg) {
HandshakeSecretTracker* t = reinterpret_cast<HandshakeSecretTracker*>(arg);
t->SecretUpdated(epoch, dir, secret);
}

void SecretUpdated(PRUint16 epoch, SSLSecretDirection dir,
PK11SymKey* secret) {
if (g_ssl_gtest_verbose) {
std::cerr << agent_->role_str() << ": secret callback for "
<< (dir == ssl_secret_read ? "read" : "write") << " epoch "
<< epoch << std::endl;
}

EXPECT_TRUE(secret);
uint16_t* p;
if (dir == ssl_secret_read) {
p = &next_read_epoch_;
} else {
ASSERT_EQ(ssl_secret_write, dir);
p = &next_write_epoch_;
}
EXPECT_EQ(*p, epoch);
switch (*p) {
case 1: // 1 == 0-RTT, next should be handshake.
case 2: // 2 == handshake, next should be application data.
(*p)++;
break;

case 3: // 3 == application data, there should be no more.
// Use 0 as a sentinel value.
*p = 0;
break;

default:
ADD_FAILURE() << "Unexpected next epoch: " << *p;
}
}

std::shared_ptr<TlsAgent> agent_;
uint16_t next_read_epoch_;
uint16_t next_write_epoch_;
};

TEST_F(TlsConnectTest, HandshakeSecrets) {
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
EnsureTlsSetup();

HandshakeSecretTracker c(client_, 2, 2);
HandshakeSecretTracker s(server_, 2, 2);

Connect();
SendReceive();

c.CheckComplete();
s.CheckComplete();
}

TEST_F(TlsConnectTest, ZeroRttSecrets) {
SetupForZeroRtt();

HandshakeSecretTracker c(client_, 2, 1);
HandshakeSecretTracker s(server_, 1, 2);

client_->Set0RttEnabled(true);
server_->Set0RttEnabled(true);
ExpectResumption(RESUME_TICKET);
ZeroRttSendReceive(true, true);
Handshake();
ExpectEarlyDataAccepted(true);
CheckConnected();
SendReceive();

c.CheckComplete();
s.CheckComplete();
}

class KeyUpdateTracker {
public:
KeyUpdateTracker(const std::shared_ptr<TlsAgent>& agent,
bool expect_read_secret)
: agent_(agent), expect_read_secret_(expect_read_secret), called_(false) {
EXPECT_EQ(SECSuccess, SSL_SecretCallback(agent_->ssl_fd(),
KeyUpdateTracker::SecretCb, this));
}

void CheckCalled() const { EXPECT_TRUE(called_); }

private:
static void SecretCb(PRFileDesc* fd, PRUint16 epoch, SSLSecretDirection dir,
PK11SymKey* secret, void* arg) {
KeyUpdateTracker* t = reinterpret_cast<KeyUpdateTracker*>(arg);
t->SecretUpdated(epoch, dir, secret);
}

void SecretUpdated(PRUint16 epoch, SSLSecretDirection dir,
PK11SymKey* secret) {
EXPECT_EQ(4U, epoch);
EXPECT_EQ(expect_read_secret_, dir == ssl_secret_read);
EXPECT_TRUE(secret);
called_ = true;
}

std::shared_ptr<TlsAgent> agent_;
bool expect_read_secret_;
bool called_;
};

TEST_F(TlsConnectTest, KeyUpdateSecrets) {
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
Connect();
// The update is to the client write secret; the server read secret.
KeyUpdateTracker c(client_, false);
KeyUpdateTracker s(server_, true);
EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE));
SendReceive(50);
SendReceive(60);
CheckEpochs(4, 3);
c.CheckCalled();
s.CheckCalled();
}

} // namespace nss_test
4 changes: 2 additions & 2 deletions lib/ssl/dtls13con.c
Expand Up @@ -482,7 +482,7 @@ dtls13_HandleAck(sslSocket *ss, sslBuffer *databuf)
* for the holddown period to process retransmitted Finisheds.
*/
if (!ss->sec.isServer && (ss->ssl3.hs.ws == idle_handshake)) {
ssl_CipherSpecReleaseByEpoch(ss, CipherSpecRead,
ssl_CipherSpecReleaseByEpoch(ss, ssl_secret_read,
TrafficKeyHandshake);
}
}
Expand All @@ -509,6 +509,6 @@ dtls13_HolddownTimerCb(sslSocket *ss)
{
SSL_TRC(10, ("%d: SSL3[%d]: holddown timer fired",
SSL_GETPID(), ss->fd));
ssl_CipherSpecReleaseByEpoch(ss, CipherSpecRead, TrafficKeyHandshake);
ssl_CipherSpecReleaseByEpoch(ss, ssl_secret_read, TrafficKeyHandshake);
ssl_ClearPRCList(&ss->ssl3.hs.dtlsRcvdHandshake, NULL);
}
22 changes: 11 additions & 11 deletions lib/ssl/ssl3con.c
Expand Up @@ -1394,14 +1394,14 @@ ssl3_ComputeDHKeyHash(sslSocket *ss, SSLHashType hashAlg, SSL3Hashes *hashes,
}

static SECStatus
ssl3_SetupPendingCipherSpec(sslSocket *ss, CipherSpecDirection direction,
ssl3_SetupPendingCipherSpec(sslSocket *ss, SSLSecretDirection direction,
const ssl3CipherSuiteDef *suiteDef,
ssl3CipherSpec **specp)
{
ssl3CipherSpec *spec;
const ssl3CipherSpec *prev;

prev = (direction == CipherSpecWrite) ? ss->ssl3.cwSpec : ss->ssl3.crSpec;
prev = (direction == ssl_secret_write) ? ss->ssl3.cwSpec : ss->ssl3.crSpec;
if (prev->epoch == PR_UINT16_MAX) {
PORT_SetError(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED);
return SECFailure;
Expand All @@ -1417,7 +1417,7 @@ ssl3_SetupPendingCipherSpec(sslSocket *ss, CipherSpecDirection direction,

spec->epoch = prev->epoch + 1;
spec->nextSeqNum = 0;
if (IS_DTLS(ss) && direction == CipherSpecRead) {
if (IS_DTLS(ss) && direction == ssl_secret_read) {
dtls_InitRecvdRecords(&spec->recvdRecords);
}
ssl_SetSpecVersions(ss, spec);
Expand Down Expand Up @@ -1471,12 +1471,12 @@ ssl3_SetupBothPendingCipherSpecs(sslSocket *ss)
ss->ssl3.hs.kea_def = &kea_defs[kea];
PORT_Assert(ss->ssl3.hs.kea_def->kea == kea);

rv = ssl3_SetupPendingCipherSpec(ss, CipherSpecRead, suiteDef,
rv = ssl3_SetupPendingCipherSpec(ss, ssl_secret_read, suiteDef,
&ss->ssl3.prSpec);
if (rv != SECSuccess) {
goto loser;
}
rv = ssl3_SetupPendingCipherSpec(ss, CipherSpecWrite, suiteDef,
rv = ssl3_SetupPendingCipherSpec(ss, ssl_secret_write, suiteDef,
&ss->ssl3.pwSpec);
if (rv != SECSuccess) {
goto loser;
Expand Down Expand Up @@ -1727,7 +1727,7 @@ ssl3_InitPendingContexts(sslSocket *ss, ssl3CipherSpec *spec)

spec->cipher = (SSLCipher)PK11_CipherOp;
encMechanism = ssl3_Alg2Mech(calg);
encMode = (spec->direction == CipherSpecWrite) ? CKA_ENCRYPT : CKA_DECRYPT;
encMode = (spec->direction == ssl_secret_write) ? CKA_ENCRYPT : CKA_DECRYPT;

/*
* build the context
Expand Down Expand Up @@ -2215,7 +2215,7 @@ ssl_ProtectRecord(sslSocket *ss, ssl3CipherSpec *cwSpec, SSLContentType ct,
unsigned int lenOffset;
SECStatus rv;

PORT_Assert(cwSpec->direction == CipherSpecWrite);
PORT_Assert(cwSpec->direction == ssl_secret_write);
PORT_Assert(SSL_BUFFER_LEN(wrBuf) == 0);
PORT_Assert(cwSpec->cipherDef->max_records <= RECORD_SEQ_MAX);

Expand Down Expand Up @@ -12183,7 +12183,7 @@ ssl3_UnprotectRecord(sslSocket *ss,
unsigned int hashBytes = MAX_MAC_LENGTH + 1;
SECStatus rv;

PORT_Assert(spec->direction == CipherSpecRead);
PORT_Assert(spec->direction == ssl_secret_read);

good = ~0U;
minLength = spec->macDef->mac_size;
Expand Down Expand Up @@ -12429,7 +12429,7 @@ ssl3_GetCipherSpec(sslSocket *ss, SSL3Ciphertext *cText)
}
if (ss->version >= SSL_LIBRARY_VERSION_TLS_1_3) {
/* Try to find the cipher spec. */
newSpec = ssl_FindCipherSpecByEpoch(ss, CipherSpecRead,
newSpec = ssl_FindCipherSpecByEpoch(ss, ssl_secret_read,
epoch);
if (newSpec != NULL) {
return newSpec;
Expand Down Expand Up @@ -12694,8 +12694,8 @@ ssl3_InitState(sslSocket *ss)

ssl_GetSpecWriteLock(ss);
PR_INIT_CLIST(&ss->ssl3.hs.cipherSpecs);
rv = ssl_SetupNullCipherSpec(ss, CipherSpecRead);
rv |= ssl_SetupNullCipherSpec(ss, CipherSpecWrite);
rv = ssl_SetupNullCipherSpec(ss, ssl_secret_read);
rv |= ssl_SetupNullCipherSpec(ss, ssl_secret_write);
ss->ssl3.pwSpec = ss->ssl3.prSpec = NULL;
ssl_ReleaseSpecWriteLock(ss);
if (rv != SECSuccess) {
Expand Down
25 changes: 25 additions & 0 deletions lib/ssl/sslexp.h
Expand Up @@ -511,6 +511,31 @@ typedef SECStatus(PR_CALLBACK *SSLResumptionTokenCallback)(
group, pubKey, pad, notBefore, notAfter, \
out, outlen, maxlen))

/* SSL_SetSecretCallback installs a callback that TLS calls when it installs new
* traffic secrets.
*
* SSLSecretCallback is called with the current epoch and the corresponding
* secret; this matches the epoch used in DTLS 1.3, even if the socket is
* operating in stream mode:
*
* - client_early_traffic_secret corresponds to epoch 1
* - {client|server}_handshake_traffic_secret is epoch 2
* - {client|server}_application_traffic_secret_{N} is epoch 3+N
*
* The callback is invoked separately for read secrets (client secrets on the
* server; server secrets on the client), and write secrets.
*
* This callback is only called if (D)TLS 1.3 is negotiated.
*/
typedef void(PR_CALLBACK *SSLSecretCallback)(
PRFileDesc *fd, PRUint16 epoch, SSLSecretDirection dir, PK11SymKey *secret,
void *arg);

#define SSL_SecretCallback(fd, cb, arg) \
SSL_EXPERIMENTAL_API("SSL_SecretCallback", \
(PRFileDesc * _fd, SSLSecretCallback _cb, void *_arg), \
(fd, cb, arg))

/* Deprecated experimental APIs */
#define SSL_UseAltServerHelloType(fd, enable) SSL_DEPRECATED_EXPERIMENTAL_API

Expand Down
4 changes: 4 additions & 0 deletions lib/ssl/sslimpl.h
Expand Up @@ -994,6 +994,8 @@ struct sslSocketStr {
PRCList extensionHooks;
SSLResumptionTokenCallback resumptionTokenCallback;
void *resumptionTokenContext;
SSLSecretCallback secretCallback;
void *secretCallbackArg;

PRIntervalTime rTimeout; /* timeout for NSPR I/O */
PRIntervalTime wTimeout; /* timeout for NSPR I/O */
Expand Down Expand Up @@ -1742,6 +1744,8 @@ SECStatus SSLExp_GetResumptionTokenInfo(const PRUint8 *tokenData, unsigned int t

SECStatus SSLExp_DestroyResumptionTokenInfo(SSLResumptionTokenInfo *token);

SECStatus SSLExp_SecretCallback(PRFileDesc *fd, SSLSecretCallback cb, void *arg);

#define SSLResumptionTokenVersion 2

SEC_END_PROTOS
Expand Down
14 changes: 7 additions & 7 deletions lib/ssl/sslsecur.c
Expand Up @@ -741,7 +741,7 @@ ssl_SecureShutdown(sslSocket *ss, int nsprHow)
/************************************************************************/

static SECStatus
tls13_CheckKeyUpdate(sslSocket *ss, CipherSpecDirection dir)
tls13_CheckKeyUpdate(sslSocket *ss, SSLSecretDirection dir)
{
PRBool keyUpdate;
ssl3CipherSpec *spec;
Expand All @@ -765,7 +765,7 @@ tls13_CheckKeyUpdate(sslSocket *ss, CipherSpecDirection dir)
* having the write margin larger reduces the number of times that a
* KeyUpdate is sent by a reader. */
ssl_GetSpecReadLock(ss);
if (dir == CipherSpecRead) {
if (dir == ssl_secret_read) {
spec = ss->ssl3.crSpec;
margin = spec->cipherDef->max_records / 8;
} else {
Expand All @@ -781,10 +781,10 @@ tls13_CheckKeyUpdate(sslSocket *ss, CipherSpecDirection dir)

SSL_TRC(5, ("%d: SSL[%d]: automatic key update at %llx for %s cipher spec",
SSL_GETPID(), ss->fd, seqNum,
(dir == CipherSpecRead) ? "read" : "write"));
(dir == ssl_secret_read) ? "read" : "write"));
ssl_GetSSL3HandshakeLock(ss);
rv = tls13_SendKeyUpdate(ss, (dir == CipherSpecRead) ? update_requested : update_not_requested,
dir == CipherSpecWrite /* buffer */);
rv = tls13_SendKeyUpdate(ss, (dir == ssl_secret_read) ? update_requested : update_not_requested,
dir == ssl_secret_write /* buffer */);
ssl_ReleaseSSL3HandshakeLock(ss);
return rv;
}
Expand Down Expand Up @@ -829,7 +829,7 @@ ssl_SecureRecv(sslSocket *ss, unsigned char *buf, int len, int flags)
}
ssl_Release1stHandshakeLock(ss);
} else {
if (tls13_CheckKeyUpdate(ss, CipherSpecRead) != SECSuccess) {
if (tls13_CheckKeyUpdate(ss, ssl_secret_read) != SECSuccess) {
rv = PR_FAILURE;
}
}
Expand Down Expand Up @@ -955,7 +955,7 @@ ssl_SecureSend(sslSocket *ss, const unsigned char *buf, int len, int flags)
}

if (ss->firstHsDone) {
if (tls13_CheckKeyUpdate(ss, CipherSpecWrite) != SECSuccess) {
if (tls13_CheckKeyUpdate(ss, ssl_secret_write) != SECSuccess) {
rv = PR_FAILURE;
goto done;
}
Expand Down

0 comments on commit 4a3d54c

Please sign in to comment.