Commit 7755a623 authored by Kevin Jacobs's avatar Kevin Jacobs

Bug 1599514 - Update DTLS 1.3 support to draft-30 r=mt

This patch updates the DTLS 1.3 implementation to draft version 30, including unified header format and sequence number encryption.

Also added are new `SSL_CreateMask` experimental functions.

Differential Revision: https://phabricator.services.mozilla.com/D51014

--HG--
rename : gtests/ssl_gtest/ssl_primitive_unittest.cc => gtests/ssl_gtest/ssl_aead_unittest.cc
extra : moz-landing-system : lando
parent c2f253f4
......@@ -23,6 +23,7 @@ class DataBuffer {
DataBuffer(const DataBuffer& other) : data_(nullptr), len_(0) {
Assign(other);
}
explicit DataBuffer(size_t l) : data_(nullptr), len_(0) { Allocate(l); }
~DataBuffer() { delete[] data_; }
DataBuffer& operator=(const DataBuffer& other) {
......
......@@ -12,6 +12,7 @@
struct ScopedDeleteSSL {
void operator()(SSLAeadContext* ctx) { SSL_DestroyAead(ctx); }
void operator()(SSLMaskingContext* ctx) { SSL_DestroyMaskingContext(ctx); }
void operator()(SSLAntiReplayContext* ctx) {
SSL_ReleaseAntiReplayContext(ctx);
}
......@@ -34,6 +35,7 @@ struct ScopedMaybeDeleteSSL {
SCOPED(SSLAeadContext);
SCOPED(SSLAntiReplayContext);
SCOPED(SSLMaskingContext);
SCOPED(SSLResumptionTokenInfo);
#undef SCOPED
......
......@@ -74,6 +74,11 @@ const uint8_t kTlsFakeChangeCipherSpec[] = {
0x01 // Value
};
const uint8_t kCtDtlsCiphertext = 0x20;
const uint8_t kCtDtlsCiphertextMask = 0xE0;
const uint8_t kCtDtlsCiphertext16bSeqno = 0x08;
const uint8_t kCtDtlsCiphertextLengthPresent = 0x04;
static const uint8_t kTls13PskKe = 0;
static const uint8_t kTls13PskDhKe = 1;
static const uint8_t kTls13PskAuth = 0;
......
......@@ -14,6 +14,7 @@ CSRCS = \
CPPSRCS = \
bloomfilter_unittest.cc \
ssl_0rtt_unittest.cc \
ssl_aead_unittest.cc \
ssl_agent_unittest.cc \
ssl_auth_unittest.cc \
ssl_cert_ext_unittest.cc \
......@@ -35,8 +36,8 @@ CPPSRCS = \
ssl_hrr_unittest.cc \
ssl_keyupdate_unittest.cc \
ssl_loopback_unittest.cc \
ssl_masking_unittest.cc \
ssl_misc_unittest.cc \
ssl_primitive_unittest.cc \
ssl_record_unittest.cc \
ssl_recordsep_unittest.cc \
ssl_recordsize_unittest.cc \
......
......@@ -54,7 +54,7 @@ class AeadTest : public ::testing::Test {
ASSERT_GE(kMaxSize, ciphertext_len);
ASSERT_LT(0U, ciphertext_len);
uint8_t output[kMaxSize];
uint8_t output[kMaxSize] = {0};
unsigned int output_len = 0;
EXPECT_EQ(SECSuccess, SSL_AeadEncrypt(ctx.get(), 0, kAad, sizeof(kAad),
kPlaintext, sizeof(kPlaintext),
......@@ -181,7 +181,7 @@ TEST_F(AeadTest, AeadNoPointer) {
}
TEST_F(AeadTest, AeadAes128Gcm) {
SSLAeadContext *ctxInit;
SSLAeadContext *ctxInit = nullptr;
ASSERT_EQ(SECSuccess,
SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_AES_128_GCM_SHA256,
secret_.get(), kLabel, strlen(kLabel), &ctxInit));
......@@ -203,7 +203,7 @@ TEST_F(AeadTest, AeadAes256Gcm) {
}
TEST_F(AeadTest, AeadChaCha20Poly1305) {
SSLAeadContext *ctxInit;
SSLAeadContext *ctxInit = nullptr;
ASSERT_EQ(
SECSuccess,
SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_CHACHA20_POLY1305_SHA256,
......
......@@ -263,6 +263,7 @@ TEST_P(TlsCipherSuiteTest, ResumeCipherSuite) {
TEST_P(TlsCipherSuiteTest, ReadLimit) {
SetupCertificate();
EnableSingleCipher();
TlsSendCipherSpecCapturer capturer(client_);
ConnectAndCheckCipherSuite();
if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
uint64_t last = last_safe_write();
......@@ -295,9 +296,31 @@ TEST_P(TlsCipherSuiteTest, ReadLimit) {
} else {
epoch = 0;
}
TlsAgentTestBase::MakeRecord(variant_, ssl_ct_application_data, version_,
payload, sizeof(payload), &record,
(epoch << 48) | record_limit());
uint64_t seqno = (epoch << 48) | record_limit();
// DTLS 1.3 masks the sequence number
if (variant_ == ssl_variant_datagram &&
version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
auto spec = capturer.spec(1);
ASSERT_NE(nullptr, spec.get());
ASSERT_EQ(3, spec->epoch());
DataBuffer pt, ct;
uint8_t dtls13_ctype = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
kCtDtlsCiphertextLengthPresent;
TlsRecordHeader hdr(variant_, version_, dtls13_ctype, seqno);
pt.Assign(payload, sizeof(payload));
TlsRecordHeader out_hdr;
spec->Protect(hdr, pt, &ct, &out_hdr);
auto rv = out_hdr.Write(&record, 0, ct);
EXPECT_EQ(out_hdr.header_length() + ct.len(), rv);
} else {
TlsAgentTestBase::MakeRecord(variant_, ssl_ct_application_data, version_,
payload, sizeof(payload), &record, seqno);
}
client_->SendDirect(record);
server_->ExpectReadWriteError();
server_->ReadBytes();
......
......@@ -619,55 +619,6 @@ TEST_P(TlsDropDatagram13, ReorderServerEE) {
// The client sends an out of order non-handshake message
// but with the handshake key.
class TlsSendCipherSpecCapturer {
public:
TlsSendCipherSpecCapturer(const std::shared_ptr<TlsAgent>& agent)
: agent_(agent), send_cipher_specs_() {
EXPECT_EQ(SECSuccess,
SSL_SecretCallback(agent_->ssl_fd(), SecretCallback, this));
}
std::shared_ptr<TlsCipherSpec> spec(size_t i) {
if (i >= send_cipher_specs_.size()) {
return nullptr;
}
return send_cipher_specs_[i];
}
private:
static void SecretCallback(PRFileDesc* fd, PRUint16 epoch,
SSLSecretDirection dir, PK11SymKey* secret,
void* arg) {
auto self = static_cast<TlsSendCipherSpecCapturer*>(arg);
std::cerr << self->agent_->role_str() << ": capture " << dir
<< " secret for epoch " << epoch << std::endl;
if (dir == ssl_secret_read) {
return;
}
SSLPreliminaryChannelInfo preinfo;
EXPECT_EQ(SECSuccess,
SSL_GetPreliminaryChannelInfo(self->agent_->ssl_fd(), &preinfo,
sizeof(preinfo)));
EXPECT_EQ(sizeof(preinfo), preinfo.length);
EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_cipher_suite);
SSLCipherSuiteInfo cipherinfo;
EXPECT_EQ(SECSuccess,
SSL_GetCipherSuiteInfo(preinfo.cipherSuite, &cipherinfo,
sizeof(cipherinfo)));
EXPECT_EQ(sizeof(cipherinfo), cipherinfo.length);
auto spec = std::make_shared<TlsCipherSpec>(true, epoch);
EXPECT_TRUE(spec->SetKeys(&cipherinfo, secret));
self->send_cipher_specs_.push_back(spec);
}
std::shared_ptr<TlsAgent> agent_;
std::vector<std::shared_ptr<TlsCipherSpec>> send_cipher_specs_;
};
TEST_F(TlsConnectDatagram13, SendOutOfOrderAppWithHandshakeKey) {
StartConnect();
// Capturing secrets means that we can't use decrypting filters on the client.
......@@ -684,8 +635,10 @@ TEST_F(TlsConnectDatagram13, SendOutOfOrderAppWithHandshakeKey) {
auto spec = capturer.spec(0);
ASSERT_NE(nullptr, spec.get());
ASSERT_EQ(2, spec->epoch());
ASSERT_TRUE(client_->SendEncryptedRecord(spec, 0x0002000000000002,
ssl_ct_application_data,
uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
kCtDtlsCiphertextLengthPresent;
ASSERT_TRUE(client_->SendEncryptedRecord(spec, 0x0002000000000002, dtls13_ct,
DataBuffer(buf, sizeof(buf))));
// Now have the server consume the bogus message.
......@@ -844,7 +797,7 @@ static void GetCipherAndLimit(uint16_t version, uint16_t* cipher,
// a reasonable amount of time.
*cipher = TLS_CHACHA20_POLY1305_SHA256;
// Assume that we are starting with an expected sequence number of 0.
*limit = (1ULL << 29) - 1;
*limit = (1ULL << 15) - 1;
}
}
......@@ -866,14 +819,14 @@ TEST_P(TlsConnectDatagram, MissLotsOfPackets) {
SendReceive();
}
// Send a sequence number of 0xfffffffd and it should be interpreted as that
// Send a sequence number of 0xfffd and it should be interpreted as that
// (and not -3 or UINT64_MAX - 2).
TEST_F(TlsConnectDatagram13, UnderflowSequenceNumber) {
Connect();
// This is only valid if short headers are disabled.
client_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, PR_FALSE);
EXPECT_EQ(SECSuccess,
SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), (1ULL << 30) - 3));
SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), (1ULL << 16) - 3));
SendReceive();
}
......@@ -918,9 +871,13 @@ class TlsReplaceFirstRecordWithJunk : public TlsRecordFilter {
return KEEP;
}
replaced_ = true;
TlsRecordHeader out_header(header.variant(), header.version(),
ssl_ct_application_data,
header.sequence_number());
uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
kCtDtlsCiphertextLengthPresent;
TlsRecordHeader out_header(
header.variant(), header.version(),
is_dtls13() ? dtls13_ct : ssl_ct_application_data,
header.sequence_number());
static const uint8_t junk[] = {1, 2, 3, 4};
*offset = out_header.Write(output, *offset, DataBuffer(junk, sizeof(junk)));
......
......@@ -15,6 +15,7 @@
'libssl_internals.c',
'selfencrypt_unittest.cc',
'ssl_0rtt_unittest.cc',
'ssl_aead_unittest.cc',
'ssl_agent_unittest.cc',
'ssl_auth_unittest.cc',
'ssl_cert_ext_unittest.cc',
......@@ -36,8 +37,8 @@
'ssl_hrr_unittest.cc',
'ssl_keyupdate_unittest.cc',
'ssl_loopback_unittest.cc',
'ssl_masking_unittest.cc',
'ssl_misc_unittest.cc',
'ssl_primitive_unittest.cc',
'ssl_record_unittest.cc',
'ssl_recordsep_unittest.cc',
'ssl_recordsize_unittest.cc',
......
This diff is collapsed.
......@@ -185,8 +185,8 @@ TEST_F(TlsConnectStreamTls13, TooLargeRecord) {
class ShortHeaderChecker : public PacketFilter {
public:
PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output) {
// The first octet should be 0b001xxxxx.
EXPECT_EQ(1, input.data()[0] >> 5);
// The first octet should be 0b001000xx.
EXPECT_EQ(kCtDtlsCiphertext, (input.data()[0] & ~0x3));
return KEEP;
}
};
......@@ -205,6 +205,35 @@ TEST_F(TlsConnectDatagram13, ShortHeadersServer) {
SendReceive();
}
// Send a DTLSCiphertext header with a 2B sequence number, and no length.
TEST_F(TlsConnectDatagram13, DtlsAlternateShortHeader) {
StartConnect();
TlsSendCipherSpecCapturer capturer(client_);
Connect();
SendReceive(50);
uint8_t buf[] = {0x32, 0x33, 0x34};
auto spec = capturer.spec(1);
ASSERT_NE(nullptr, spec.get());
ASSERT_EQ(3, spec->epoch());
uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno;
TlsRecordHeader header(variant_, SSL_LIBRARY_VERSION_TLS_1_3, dtls13_ct,
0x0003000000000001);
TlsRecordHeader out_header(header);
DataBuffer msg(buf, sizeof(buf));
msg.Write(msg.len(), ssl_ct_application_data, 1);
DataBuffer ciphertext;
EXPECT_TRUE(spec->Protect(header, msg, &ciphertext, &out_header));
DataBuffer record;
auto rv = out_header.Write(&record, 0, ciphertext);
EXPECT_EQ(out_header.header_length() + ciphertext.len(), rv);
client_->SendDirect(record);
server_->ReadBytes(3);
}
TEST_F(TlsConnectStreamTls13, UnencryptedFinishedMessage) {
StartConnect();
client_->Handshake(); // Send ClientHello
......
......@@ -19,7 +19,8 @@ namespace nss_test {
// This class tracks the maximum size of record that was sent, both cleartext
// and plain. It only tracks records that have an outer type of
// application_data. In TLS 1.3, this includes handshake messages.
// application_data or DTLSCiphertext. In TLS 1.3, this includes handshake
// messages.
class TlsRecordMaximum : public TlsRecordFilter {
public:
TlsRecordMaximum(const std::shared_ptr<TlsAgent>& a)
......@@ -34,7 +35,7 @@ class TlsRecordMaximum : public TlsRecordFilter {
DataBuffer* output) override {
std::cerr << "max: " << record << std::endl;
// Ignore unprotected packets.
if (header.content_type() != ssl_ct_application_data) {
if (!header.is_protected()) {
return KEEP;
}
......@@ -195,9 +196,23 @@ class TlsRecordExpander : public TlsRecordFilter {
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& data,
DataBuffer* changed) {
if (header.content_type() != ssl_ct_application_data) {
return KEEP;
if (!header.is_protected()) {
// We're targeting application_data records. If the record is
// |!is_protected()|, we have two possibilities:
if (!decrypting()) {
// 1) We're not decrypting, in which this case this is truly an
// unencrypted record (Keep).
return KEEP;
}
if (header.content_type() != ssl_ct_application_data) {
// 2) We are decrypting, so is_protected() read the internal
// content_type. If the internal ct IS NOT application_data, then
// it's not our target (Keep).
return KEEP;
}
// Otherwise, the the internal ct IS application_data (Change).
}
changed->Allocate(data.len() + expansion_);
changed->Write(0, data.data(), data.len());
return CHANGE;
......@@ -261,30 +276,31 @@ class TlsRecordPadder : public TlsRecordFilter {
PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& record, size_t* offset,
DataBuffer* output) override {
if (header.content_type() != ssl_ct_application_data) {
if (!header.is_protected()) {
return KEEP;
}
uint16_t protection_epoch;
uint8_t inner_content_type;
DataBuffer plaintext;
TlsRecordHeader out_header;
if (!Unprotect(header, record, &protection_epoch, &inner_content_type,
&plaintext)) {
&plaintext, &out_header)) {
return KEEP;
}
if (inner_content_type != ssl_ct_application_data) {
if (decrypting() && inner_content_type != ssl_ct_application_data) {
return KEEP;
}
DataBuffer ciphertext;
bool ok = Protect(spec(protection_epoch), header, inner_content_type,
plaintext, &ciphertext, padding_);
bool ok = Protect(spec(protection_epoch), out_header, inner_content_type,
plaintext, &ciphertext, &out_header, padding_);
EXPECT_TRUE(ok);
if (!ok) {
return KEEP;
}
*offset = header.Write(output, *offset, ciphertext);
*offset = out_header.Write(output, *offset, ciphertext);
return CHANGE;
}
......
......@@ -384,14 +384,16 @@ TEST_F(TlsConnectDatagram13, CompatModeDtlsClient) {
ASSERT_EQ(2U, client_records->count()); // CH, Fin
EXPECT_EQ(ssl_ct_handshake, client_records->record(0).header.content_type());
EXPECT_EQ(ssl_ct_application_data,
client_records->record(1).header.content_type());
EXPECT_EQ(kCtDtlsCiphertext,
(client_records->record(1).header.content_type() &
kCtDtlsCiphertextMask));
ASSERT_EQ(6U, server_records->count()); // SH, EE, CT, CV, Fin, Ack
EXPECT_EQ(ssl_ct_handshake, server_records->record(0).header.content_type());
for (size_t i = 1; i < server_records->count(); ++i) {
EXPECT_EQ(ssl_ct_application_data,
server_records->record(i).header.content_type());
EXPECT_EQ(kCtDtlsCiphertext,
(server_records->record(i).header.content_type() &
kCtDtlsCiphertextMask));
}
}
......@@ -440,8 +442,9 @@ TEST_F(TlsConnectDatagram13, CompatModeDtlsServer) {
ASSERT_EQ(5U, server_records->count()); // SH, EE, CT, CV, Fin
EXPECT_EQ(ssl_ct_handshake, server_records->record(0).header.content_type());
for (size_t i = 1; i < server_records->count(); ++i) {
EXPECT_EQ(ssl_ct_application_data,
server_records->record(i).header.content_type());
EXPECT_EQ(kCtDtlsCiphertext,
(server_records->record(i).header.content_type() &
kCtDtlsCiphertextMask));
}
uint32_t session_id_len = 0;
......
......@@ -1064,21 +1064,28 @@ void TlsAgent::SendBuffer(const DataBuffer& buf) {
bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
uint64_t seq, uint8_t ct,
const DataBuffer& buf) {
LOGV("Encrypting " << buf.len() << " bytes");
// Ensure that we are doing TLS 1.3.
EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3);
TlsRecordHeader header(variant_, expected_version_, ssl_ct_application_data,
seq);
if (variant_ != ssl_variant_datagram) {
ADD_FAILURE();
return false;
}
LOGV("Encrypting " << buf.len() << " bytes");
uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
kCtDtlsCiphertextLengthPresent;
TlsRecordHeader header(variant_, expected_version_, dtls13_ct, seq);
TlsRecordHeader out_header(header);
DataBuffer padded = buf;
padded.Write(padded.len(), ct, 1);
DataBuffer ciphertext;
if (!spec->Protect(header, padded, &ciphertext)) {
if (!spec->Protect(header, padded, &ciphertext, &out_header)) {
return false;
}
DataBuffer record;
auto rv = header.Write(&record, 0, ciphertext);
EXPECT_EQ(header.header_length() + ciphertext.len(), rv);
auto rv = out_header.Write(&record, 0, ciphertext);
EXPECT_EQ(out_header.header_length() + ciphertext.len(), rv);
SendDirect(record);
return true;
}
......@@ -1202,16 +1209,26 @@ void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type,
uint16_t version, const uint8_t* buf,
size_t len, DataBuffer* out,
uint64_t sequence_number) {
// Fixup the content type for DTLSCiphertext
if (variant == ssl_variant_datagram &&
version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
type == ssl_ct_application_data) {
type = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
kCtDtlsCiphertextLengthPresent;
}
size_t index = 0;
index = out->Write(index, type, 1);
if (variant == ssl_variant_stream) {
index = out->Write(index, type, 1);
index = out->Write(index, version, 2);
} else if (version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
type == ssl_ct_application_data) {
(type & kCtDtlsCiphertextMask) == kCtDtlsCiphertext) {
uint32_t epoch = (sequence_number >> 48) & 0x3;
uint32_t seqno = sequence_number & ((1ULL << 30) - 1);
index = out->Write(index, (epoch << 30) | seqno, 4);
index = out->Write(index, type | epoch, 1);
uint32_t seqno = sequence_number & ((1ULL << 16) - 1);
index = out->Write(index, seqno, 2);
} else {
index = out->Write(index, type, 1);
index = out->Write(index, TlsVersionToDtlsVersion(version), 2);
index = out->Write(index, sequence_number >> 32, 4);
index = out->Write(index, sequence_number & PR_UINT32_MAX, 4);
......
This diff is collapsed.
......@@ -12,6 +12,7 @@
#include <set>
#include <vector>
#include "sslt.h"
#include "sslproto.h"
#include "test_io.h"
#include "tls_agent.h"
#include "tls_parser.h"
......@@ -25,6 +26,59 @@ namespace nss_test {
class TlsCipherSpec;
class TlsSendCipherSpecCapturer {
public:
TlsSendCipherSpecCapturer(const std::shared_ptr<TlsAgent>& agent)
: agent_(agent), send_cipher_specs_() {
EXPECT_EQ(SECSuccess,
SSL_SecretCallback(agent_->ssl_fd(), SecretCallback, this));
}
std::shared_ptr<TlsCipherSpec> spec(size_t i) {
if (i >= send_cipher_specs_.size()) {
return nullptr;
}
return send_cipher_specs_[i];
}
private:
static void SecretCallback(PRFileDesc* fd, PRUint16 epoch,
SSLSecretDirection dir, PK11SymKey* secret,
void* arg) {
auto self = static_cast<TlsSendCipherSpecCapturer*>(arg);
std::cerr << self->agent_->role_str() << ": capture " << dir
<< " secret for epoch " << epoch << std::endl;
if (dir == ssl_secret_read) {
return;
}
SSLPreliminaryChannelInfo preinfo;
EXPECT_EQ(SECSuccess,
SSL_GetPreliminaryChannelInfo(self->agent_->ssl_fd(), &preinfo,
sizeof(preinfo)));
EXPECT_EQ(sizeof(preinfo), preinfo.length);
EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_cipher_suite);
// Check the version:
EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_version);
ASSERT_GE(SSL_LIBRARY_VERSION_TLS_1_3, preinfo.protocolVersion);
SSLCipherSuiteInfo cipherinfo;
EXPECT_EQ(SECSuccess,
SSL_GetCipherSuiteInfo(preinfo.cipherSuite, &cipherinfo,
sizeof(cipherinfo)));
EXPECT_EQ(sizeof(cipherinfo), cipherinfo.length);
auto spec = std::make_shared<TlsCipherSpec>(true, epoch);
EXPECT_TRUE(spec->SetKeys(&cipherinfo, secret));
self->send_cipher_specs_.push_back(spec);
}
std::shared_ptr<TlsAgent> agent_;
std::vector<std::shared_ptr<TlsCipherSpec>> send_cipher_specs_;
};
class TlsVersioned {
public:
TlsVersioned() : variant_(ssl_variant_stream), version_(0) {}
......@@ -45,22 +99,57 @@ class TlsVersioned {
class TlsRecordHeader : public TlsVersioned {
public:
TlsRecordHeader()
: TlsVersioned(), content_type_(0), sequence_number_(0), header_() {}
: TlsVersioned(),
content_type_(0),
guess_seqno_(0),
seqno_is_masked_(false),
sequence_number_(0),
header_() {}
TlsRecordHeader(SSLProtocolVariant var, uint16_t ver, uint8_t ct,
uint64_t seqno)
: TlsVersioned(var, ver),
content_type_(ct),
guess_seqno_(0),
seqno_is_masked_(false),
sequence_number_(seqno),
header_() {}
header_(),
sn_mask_() {}
bool is_protected() const {
// *TLS < 1.3
if (version() < SSL_LIBRARY_VERSION_TLS_1_3 &&
content_type() == ssl_ct_application_data) {
return true;
}
// TLS 1.3
if (!is_dtls() && version() >= SSL_LIBRARY_VERSION_TLS_1_3 &&
content_type() == ssl_ct_application_data) {
return true;
}
// DTLS 1.3
return is_dtls13_ciphertext();
}
uint8_t content_type() const { return content_type_; }
uint64_t sequence_number() const { return sequence_number_; }
uint16_t epoch() const {
return static_cast<uint16_t>(sequence_number_ >> 48);
}
uint64_t sequence_number() const { return sequence_number_; }
void sequence_number(uint64_t seqno) { sequence_number_ = seqno; }
const DataBuffer& sn_mask() const { return sn_mask_; }
bool is_dtls13_ciphertext() const {
return is_dtls() && (version() >= SSL_LIBRARY_VERSION_TLS_1_3) &&
(content_type() & kCtDtlsCiphertextMask) == kCtDtlsCiphertext;
}
size_t header_length() const;
const DataBuffer& header() const { return header_; }
bool MaskSequenceNumber();
bool MaskSequenceNumber(const DataBuffer& mask);
// Parse the header; return true if successful; body in an outparam if OK.
bool Parse(bool is_dtls13, uint64_t sequence_number, TlsParser* parser,
DataBuffer* body);
......@@ -70,14 +159,17 @@ class TlsRecordHeader : public TlsVersioned {
size_t WriteHeader(DataBuffer* buffer, size_t offset, size_t body_len) const;
private:
static uint64_t RecoverSequenceNumber(uint64_t expected, uint32_t partial,
static uint64_t RecoverSequenceNumber(uint64_t guess_seqno, uint32_t partial,
size_t partial_bits);
static uint64_t ParseSequenceNumber(uint64_t expected, uint32_t raw,
size_t seq_no_bits, size_t epoch_bits);
uint64_t ParseSequenceNumber(uint64_t expected, uint64_t raw,
size_t seq_no_bits, size_t epoch_bits);
uint8_t content_type_;
uint64_t guess_seqno_;
bool seqno_is_masked_;
uint64_t sequence_number_;
DataBuffer header_;
DataBuffer sn_mask_;
};
struct TlsRecord {
......@@ -111,12 +203,14 @@ class TlsRecordFilter : public PacketFilter {
// Enabling it for lower version tests will cause undefined
// behavior.
void EnableDecryption();
bool decrypting() const { return decrypting_; };
bool Unprotect(const TlsRecordHeader& header, const DataBuffer& cipherText,
uint16_t* protection_epoch, uint8_t* inner_content_type,
DataBuffer* plaintext);
DataBuffer* plaintext, TlsRecordHeader* out_header);
bool Protect(TlsCipherSpec& protection_spec, const TlsRecordHeader& header,
uint8_t inner_content_type, const DataBuffer& plaintext,
DataBuffer* ciphertext, size_t padding = 0);
DataBuffer* ciphertext, TlsRecordHeader* out_header,
size_t padding = 0);
protected:
// There are two filter functions which can be overriden. Both are
......@@ -141,6 +235,7 @@ class TlsRecordFilter : public PacketFilter {
}
bool is_dtls13() const;
bool is_dtls13_ciphertext(uint8_t ct) const;
TlsCipherSpec& spec(uint16_t epoch);
private:
......@@ -471,8 +566,9 @@ class TlsEncryptedHandshakeMessageReplacer : public TlsRecordFilter {
uint16_t protection_epoch = 0;
uint8_t inner_content_type;
DataBuffer plaintext;
TlsRecordHeader out_header;
if (!Unprotect(header, record, &protection_epoch, &inner_content_type,
&plaintext) ||
&plaintext, &out_header) ||
!plaintext.len()) {
return KEEP;
}
......@@ -501,12 +597,12 @@ class TlsEncryptedHandshakeMessageReplacer : public TlsRecordFilter {
}
DataBuffer ciphertext;
bool ok = Protect(spec(protection_epoch), header, inner_content_type,
plaintext, &ciphertext, 0);
bool ok = Protect(spec(protection_epoch), out_header, inner_content_type,
plaintext, &ciphertext, &out_header);
if (!ok) {
return KEEP;
}
*offset = header.Write(output, *offset, ciphertext);
*offset = out_header.Write(output, *offset, ciphertext);
return CHANGE;
}
......
......@@ -25,39 +25,66 @@ TlsCipherSpec::TlsCipherSpec(bool dtls, uint16_t epoc)
bool TlsCipherSpec::SetKeys(SSLCipherSuiteInfo* cipherinfo,
PK11SymKey* secret) {
SSLAeadContext* ctx;
SSLAeadContext* aead_ctx;
SECStatus rv = SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3,
cipherinfo->cipherSuite, secret, "",
0, // Use the default labels.
&ctx);
&aead_ctx);
if (rv != SECSuccess) {
return false;
}
aead_.reset(ctx);
aead_.reset(aead_ctx);
SSLMaskingContext* mask_ctx;
const char kHkdfPurposeSn[] = "sn";
rv = SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3,
cipherinfo->cipherSuite, secret, kHkdfPurposeSn,
strlen(kHkdfPurposeSn), &mask_ctx);