From 7755a623cb85643382d6b34d86300709209621ac Mon Sep 17 00:00:00 2001 From: Kevin Jacobs Date: Mon, 6 Jan 2020 21:26:20 +0000 Subject: [PATCH] Bug 1599514 - Update DTLS 1.3 support to draft-30 r=mt This patch updates the DTLS 1.3 implementation to draft version 30, including unified header format and sequence number encryption. Also added are new `SSL_CreateMask` experimental functions. Differential Revision: https://phabricator.services.mozilla.com/D51014 --HG-- rename : gtests/ssl_gtest/ssl_primitive_unittest.cc => gtests/ssl_gtest/ssl_aead_unittest.cc extra : moz-landing-system : lando --- cpputil/databuffer.h | 1 + cpputil/scoped_ptrs_ssl.h | 2 + cpputil/tls_parser.h | 5 + gtests/ssl_gtest/manifest.mn | 3 +- ...itive_unittest.cc => ssl_aead_unittest.cc} | 6 +- gtests/ssl_gtest/ssl_ciphersuite_unittest.cc | 29 +- gtests/ssl_gtest/ssl_drop_unittest.cc | 71 +--- gtests/ssl_gtest/ssl_gtest.gyp | 3 +- gtests/ssl_gtest/ssl_masking_unittest.cc | 337 ++++++++++++++++++ gtests/ssl_gtest/ssl_record_unittest.cc | 33 +- gtests/ssl_gtest/ssl_recordsize_unittest.cc | 36 +- gtests/ssl_gtest/ssl_tls13compat_unittest.cc | 15 +- gtests/ssl_gtest/tls_agent.cc | 37 +- gtests/ssl_gtest/tls_filter.cc | 189 +++++++--- gtests/ssl_gtest/tls_filter.h | 120 ++++++- gtests/ssl_gtest/tls_protect.cc | 89 +++-- gtests/ssl_gtest/tls_protect.h | 5 +- lib/ssl/dtls13con.c | 92 ++++- lib/ssl/dtls13con.h | 6 +- lib/ssl/dtlscon.c | 40 ++- lib/ssl/dtlscon.h | 1 + lib/ssl/ssl3con.c | 45 ++- lib/ssl/ssl3gthr.c | 25 +- lib/ssl/ssl3prot.h | 2 +- lib/ssl/sslexp.h | 50 +++ lib/ssl/sslimpl.h | 26 +- lib/ssl/sslprimitive.c | 205 +++++++++-- lib/ssl/sslsock.c | 3 + lib/ssl/sslspec.c | 3 + lib/ssl/sslspec.h | 3 + lib/ssl/tls13con.c | 40 +++ lib/ssl/tls13con.h | 6 +- 32 files changed, 1260 insertions(+), 268 deletions(-) rename gtests/ssl_gtest/{ssl_primitive_unittest.cc => ssl_aead_unittest.cc} (98%) create mode 100644 gtests/ssl_gtest/ssl_masking_unittest.cc diff --git a/cpputil/databuffer.h b/cpputil/databuffer.h index e981a7c223..4bedd075db 100644 --- a/cpputil/databuffer.h +++ b/cpputil/databuffer.h @@ -23,6 +23,7 @@ class DataBuffer { DataBuffer(const DataBuffer& other) : data_(nullptr), len_(0) { Assign(other); } + explicit DataBuffer(size_t l) : data_(nullptr), len_(0) { Allocate(l); } ~DataBuffer() { delete[] data_; } DataBuffer& operator=(const DataBuffer& other) { diff --git a/cpputil/scoped_ptrs_ssl.h b/cpputil/scoped_ptrs_ssl.h index 474187540e..682ebab829 100644 --- a/cpputil/scoped_ptrs_ssl.h +++ b/cpputil/scoped_ptrs_ssl.h @@ -12,6 +12,7 @@ struct ScopedDeleteSSL { void operator()(SSLAeadContext* ctx) { SSL_DestroyAead(ctx); } + void operator()(SSLMaskingContext* ctx) { SSL_DestroyMaskingContext(ctx); } void operator()(SSLAntiReplayContext* ctx) { SSL_ReleaseAntiReplayContext(ctx); } @@ -34,6 +35,7 @@ struct ScopedMaybeDeleteSSL { SCOPED(SSLAeadContext); SCOPED(SSLAntiReplayContext); +SCOPED(SSLMaskingContext); SCOPED(SSLResumptionTokenInfo); #undef SCOPED diff --git a/cpputil/tls_parser.h b/cpputil/tls_parser.h index 05dd99fc84..6636b3c6a7 100644 --- a/cpputil/tls_parser.h +++ b/cpputil/tls_parser.h @@ -74,6 +74,11 @@ const uint8_t kTlsFakeChangeCipherSpec[] = { 0x01 // Value }; +const uint8_t kCtDtlsCiphertext = 0x20; +const uint8_t kCtDtlsCiphertextMask = 0xE0; +const uint8_t kCtDtlsCiphertext16bSeqno = 0x08; +const uint8_t kCtDtlsCiphertextLengthPresent = 0x04; + static const uint8_t kTls13PskKe = 0; static const uint8_t kTls13PskDhKe = 1; static const uint8_t kTls13PskAuth = 0; diff --git a/gtests/ssl_gtest/manifest.mn b/gtests/ssl_gtest/manifest.mn index ed1128f7cb..d5e96a4901 100644 --- a/gtests/ssl_gtest/manifest.mn +++ b/gtests/ssl_gtest/manifest.mn @@ -14,6 +14,7 @@ CSRCS = \ CPPSRCS = \ bloomfilter_unittest.cc \ ssl_0rtt_unittest.cc \ + ssl_aead_unittest.cc \ ssl_agent_unittest.cc \ ssl_auth_unittest.cc \ ssl_cert_ext_unittest.cc \ @@ -35,8 +36,8 @@ CPPSRCS = \ ssl_hrr_unittest.cc \ ssl_keyupdate_unittest.cc \ ssl_loopback_unittest.cc \ + ssl_masking_unittest.cc \ ssl_misc_unittest.cc \ - ssl_primitive_unittest.cc \ ssl_record_unittest.cc \ ssl_recordsep_unittest.cc \ ssl_recordsize_unittest.cc \ diff --git a/gtests/ssl_gtest/ssl_primitive_unittest.cc b/gtests/ssl_gtest/ssl_aead_unittest.cc similarity index 98% rename from gtests/ssl_gtest/ssl_primitive_unittest.cc rename to gtests/ssl_gtest/ssl_aead_unittest.cc index 66ecdeb12f..d94683be30 100644 --- a/gtests/ssl_gtest/ssl_primitive_unittest.cc +++ b/gtests/ssl_gtest/ssl_aead_unittest.cc @@ -54,7 +54,7 @@ class AeadTest : public ::testing::Test { ASSERT_GE(kMaxSize, ciphertext_len); ASSERT_LT(0U, ciphertext_len); - uint8_t output[kMaxSize]; + uint8_t output[kMaxSize] = {0}; unsigned int output_len = 0; EXPECT_EQ(SECSuccess, SSL_AeadEncrypt(ctx.get(), 0, kAad, sizeof(kAad), kPlaintext, sizeof(kPlaintext), @@ -181,7 +181,7 @@ TEST_F(AeadTest, AeadNoPointer) { } TEST_F(AeadTest, AeadAes128Gcm) { - SSLAeadContext *ctxInit; + SSLAeadContext *ctxInit = nullptr; ASSERT_EQ(SECSuccess, SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_AES_128_GCM_SHA256, secret_.get(), kLabel, strlen(kLabel), &ctxInit)); @@ -203,7 +203,7 @@ TEST_F(AeadTest, AeadAes256Gcm) { } TEST_F(AeadTest, AeadChaCha20Poly1305) { - SSLAeadContext *ctxInit; + SSLAeadContext *ctxInit = nullptr; ASSERT_EQ( SECSuccess, SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_CHACHA20_POLY1305_SHA256, diff --git a/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc b/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc index 7739fe76f3..86cb02d73f 100644 --- a/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc +++ b/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc @@ -263,6 +263,7 @@ TEST_P(TlsCipherSuiteTest, ResumeCipherSuite) { TEST_P(TlsCipherSuiteTest, ReadLimit) { SetupCertificate(); EnableSingleCipher(); + TlsSendCipherSpecCapturer capturer(client_); ConnectAndCheckCipherSuite(); if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { uint64_t last = last_safe_write(); @@ -295,9 +296,31 @@ TEST_P(TlsCipherSuiteTest, ReadLimit) { } else { epoch = 0; } - TlsAgentTestBase::MakeRecord(variant_, ssl_ct_application_data, version_, - payload, sizeof(payload), &record, - (epoch << 48) | record_limit()); + + uint64_t seqno = (epoch << 48) | record_limit(); + + // DTLS 1.3 masks the sequence number + if (variant_ == ssl_variant_datagram && + version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + auto spec = capturer.spec(1); + ASSERT_NE(nullptr, spec.get()); + ASSERT_EQ(3, spec->epoch()); + + DataBuffer pt, ct; + uint8_t dtls13_ctype = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno | + kCtDtlsCiphertextLengthPresent; + TlsRecordHeader hdr(variant_, version_, dtls13_ctype, seqno); + pt.Assign(payload, sizeof(payload)); + TlsRecordHeader out_hdr; + spec->Protect(hdr, pt, &ct, &out_hdr); + + auto rv = out_hdr.Write(&record, 0, ct); + EXPECT_EQ(out_hdr.header_length() + ct.len(), rv); + } else { + TlsAgentTestBase::MakeRecord(variant_, ssl_ct_application_data, version_, + payload, sizeof(payload), &record, seqno); + } + client_->SendDirect(record); server_->ExpectReadWriteError(); server_->ReadBytes(); diff --git a/gtests/ssl_gtest/ssl_drop_unittest.cc b/gtests/ssl_gtest/ssl_drop_unittest.cc index b441b5c10d..05b38e381b 100644 --- a/gtests/ssl_gtest/ssl_drop_unittest.cc +++ b/gtests/ssl_gtest/ssl_drop_unittest.cc @@ -619,55 +619,6 @@ TEST_P(TlsDropDatagram13, ReorderServerEE) { // The client sends an out of order non-handshake message // but with the handshake key. -class TlsSendCipherSpecCapturer { - public: - TlsSendCipherSpecCapturer(const std::shared_ptr& agent) - : agent_(agent), send_cipher_specs_() { - EXPECT_EQ(SECSuccess, - SSL_SecretCallback(agent_->ssl_fd(), SecretCallback, this)); - } - - std::shared_ptr spec(size_t i) { - if (i >= send_cipher_specs_.size()) { - return nullptr; - } - return send_cipher_specs_[i]; - } - - private: - static void SecretCallback(PRFileDesc* fd, PRUint16 epoch, - SSLSecretDirection dir, PK11SymKey* secret, - void* arg) { - auto self = static_cast(arg); - std::cerr << self->agent_->role_str() << ": capture " << dir - << " secret for epoch " << epoch << std::endl; - - if (dir == ssl_secret_read) { - return; - } - - SSLPreliminaryChannelInfo preinfo; - EXPECT_EQ(SECSuccess, - SSL_GetPreliminaryChannelInfo(self->agent_->ssl_fd(), &preinfo, - sizeof(preinfo))); - EXPECT_EQ(sizeof(preinfo), preinfo.length); - EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_cipher_suite); - - SSLCipherSuiteInfo cipherinfo; - EXPECT_EQ(SECSuccess, - SSL_GetCipherSuiteInfo(preinfo.cipherSuite, &cipherinfo, - sizeof(cipherinfo))); - EXPECT_EQ(sizeof(cipherinfo), cipherinfo.length); - - auto spec = std::make_shared(true, epoch); - EXPECT_TRUE(spec->SetKeys(&cipherinfo, secret)); - self->send_cipher_specs_.push_back(spec); - } - - std::shared_ptr agent_; - std::vector> send_cipher_specs_; -}; - TEST_F(TlsConnectDatagram13, SendOutOfOrderAppWithHandshakeKey) { StartConnect(); // Capturing secrets means that we can't use decrypting filters on the client. @@ -684,8 +635,10 @@ TEST_F(TlsConnectDatagram13, SendOutOfOrderAppWithHandshakeKey) { auto spec = capturer.spec(0); ASSERT_NE(nullptr, spec.get()); ASSERT_EQ(2, spec->epoch()); - ASSERT_TRUE(client_->SendEncryptedRecord(spec, 0x0002000000000002, - ssl_ct_application_data, + + uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno | + kCtDtlsCiphertextLengthPresent; + ASSERT_TRUE(client_->SendEncryptedRecord(spec, 0x0002000000000002, dtls13_ct, DataBuffer(buf, sizeof(buf)))); // Now have the server consume the bogus message. @@ -844,7 +797,7 @@ static void GetCipherAndLimit(uint16_t version, uint16_t* cipher, // a reasonable amount of time. *cipher = TLS_CHACHA20_POLY1305_SHA256; // Assume that we are starting with an expected sequence number of 0. - *limit = (1ULL << 29) - 1; + *limit = (1ULL << 15) - 1; } } @@ -866,14 +819,14 @@ TEST_P(TlsConnectDatagram, MissLotsOfPackets) { SendReceive(); } -// Send a sequence number of 0xfffffffd and it should be interpreted as that +// Send a sequence number of 0xfffd and it should be interpreted as that // (and not -3 or UINT64_MAX - 2). TEST_F(TlsConnectDatagram13, UnderflowSequenceNumber) { Connect(); // This is only valid if short headers are disabled. client_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, PR_FALSE); EXPECT_EQ(SECSuccess, - SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), (1ULL << 30) - 3)); + SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), (1ULL << 16) - 3)); SendReceive(); } @@ -918,9 +871,13 @@ class TlsReplaceFirstRecordWithJunk : public TlsRecordFilter { return KEEP; } replaced_ = true; - TlsRecordHeader out_header(header.variant(), header.version(), - ssl_ct_application_data, - header.sequence_number()); + + uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno | + kCtDtlsCiphertextLengthPresent; + TlsRecordHeader out_header( + header.variant(), header.version(), + is_dtls13() ? dtls13_ct : ssl_ct_application_data, + header.sequence_number()); static const uint8_t junk[] = {1, 2, 3, 4}; *offset = out_header.Write(output, *offset, DataBuffer(junk, sizeof(junk))); diff --git a/gtests/ssl_gtest/ssl_gtest.gyp b/gtests/ssl_gtest/ssl_gtest.gyp index 6cff0fc9d6..ae79c41fe5 100644 --- a/gtests/ssl_gtest/ssl_gtest.gyp +++ b/gtests/ssl_gtest/ssl_gtest.gyp @@ -15,6 +15,7 @@ 'libssl_internals.c', 'selfencrypt_unittest.cc', 'ssl_0rtt_unittest.cc', + 'ssl_aead_unittest.cc', 'ssl_agent_unittest.cc', 'ssl_auth_unittest.cc', 'ssl_cert_ext_unittest.cc', @@ -36,8 +37,8 @@ 'ssl_hrr_unittest.cc', 'ssl_keyupdate_unittest.cc', 'ssl_loopback_unittest.cc', + 'ssl_masking_unittest.cc', 'ssl_misc_unittest.cc', - 'ssl_primitive_unittest.cc', 'ssl_record_unittest.cc', 'ssl_recordsep_unittest.cc', 'ssl_recordsize_unittest.cc', diff --git a/gtests/ssl_gtest/ssl_masking_unittest.cc b/gtests/ssl_gtest/ssl_masking_unittest.cc new file mode 100644 index 0000000000..5b63b945b4 --- /dev/null +++ b/gtests/ssl_gtest/ssl_masking_unittest.cc @@ -0,0 +1,337 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include + +#include "keyhi.h" +#include "pk11pub.h" +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslexp.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "scoped_ptrs_ssl.h" +#include "tls_connect.h" + +namespace nss_test { + +// From tls_hkdf_unittest.cc: +extern size_t GetHashLength(SSLHashType ht); + +const std::string kLabel = "sn"; + +class MaskingTest : public ::testing::Test { + public: + MaskingTest() : slot_(PK11_GetInternalSlot()) {} + + void InitSecret(SSLHashType hash_type) { + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + PK11SymKey *s = PK11_KeyGen(slot_.get(), CKM_GENERIC_SECRET_KEY_GEN, + nullptr, AES_128_KEY_LENGTH, nullptr); + ASSERT_NE(nullptr, s); + secret_.reset(s); + } + + void SetUp() override { + InitSecret(ssl_hash_sha256); + PORT_SetError(0); + } + + protected: + ScopedPK11SymKey secret_; + ScopedPK11SlotInfo slot_; + void CreateMask(PRUint16 ciphersuite, std::string label, + const std::vector &sample, + std::vector *out_mask) { + ASSERT_NE(nullptr, out_mask); + SSLMaskingContext *ctx_init = nullptr; + EXPECT_EQ(SECSuccess, + SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite, + secret_.get(), label.c_str(), + label.size(), &ctx_init)); + EXPECT_EQ(0, PORT_GetError()); + ASSERT_NE(nullptr, ctx_init); + ScopedSSLMaskingContext ctx(ctx_init); + + EXPECT_EQ(SECSuccess, + SSL_CreateMask(ctx.get(), sample.data(), sample.size(), + out_mask->data(), out_mask->size())); + EXPECT_EQ(0, PORT_GetError()); + bool all_zeros = std::all_of(out_mask->begin(), out_mask->end(), + [](uint8_t v) { return v == 0; }); + + // If out_mask is short, |all_zeros| will be (expectedly) true often enough + // to fail tests. + // In this case, just retry to make sure we're not outputting zeros + // continuously. + if (all_zeros && out_mask->size() < 3) { + unsigned int tries = 2; + std::vector tmp_sample = sample; + std::vector tmp_mask(out_mask->size()); + while (tries--) { + tmp_sample.data()[0]++; // Tweak something to get a new mask. + EXPECT_EQ(SECSuccess, SSL_CreateMask(ctx.get(), tmp_sample.data(), + tmp_sample.size(), tmp_mask.data(), + tmp_mask.size())); + EXPECT_EQ(0, PORT_GetError()); + bool retry_zero = std::all_of(tmp_mask.begin(), tmp_mask.end(), + [](uint8_t v) { return v == 0; }); + if (!retry_zero) { + all_zeros = false; + break; + } + } + } + EXPECT_FALSE(all_zeros); + } +}; + +TEST_F(MaskingTest, MaskContextNoLabel) { + std::vector sample(AES_BLOCK_SIZE); + std::vector mask(AES_BLOCK_SIZE); + CreateMask(TLS_AES_128_GCM_SHA256, std::string(""), sample, &mask); +} + +TEST_F(MaskingTest, MaskContextUnsupportedMech) { + std::vector sample(AES_BLOCK_SIZE); + std::vector mask(AES_BLOCK_SIZE); + SSLMaskingContext *ctx_init = nullptr; + EXPECT_EQ(SECFailure, + SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, + TLS_RSA_WITH_AES_128_CBC_SHA256, + secret_.get(), nullptr, 0, &ctx_init)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + EXPECT_EQ(nullptr, ctx_init); +} + +TEST_F(MaskingTest, MaskNullSample) { + std::vector mask(AES_BLOCK_SIZE); + SSLMaskingContext *ctx_init = nullptr; + EXPECT_EQ(SECSuccess, + SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, + TLS_AES_128_GCM_SHA256, secret_.get(), + kLabel.c_str(), kLabel.size(), &ctx_init)); + EXPECT_EQ(0, PORT_GetError()); + ASSERT_NE(nullptr, ctx_init); + ScopedSSLMaskingContext ctx(ctx_init); + + EXPECT_EQ(SECFailure, + SSL_CreateMask(ctx.get(), nullptr, 0, mask.data(), mask.size())); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), nullptr, mask.size(), + mask.data(), mask.size())); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(MaskingTest, MaskContextUnsupportedVersion) { + std::vector sample(AES_BLOCK_SIZE); + std::vector mask(AES_BLOCK_SIZE); + SSLMaskingContext *ctx_init = nullptr; + EXPECT_EQ(SECFailure, SSL_CreateMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_2, TLS_AES_128_GCM_SHA256, + secret_.get(), nullptr, 0, &ctx_init)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + EXPECT_EQ(nullptr, ctx_init); +} + +TEST_F(MaskingTest, MaskTooMuchOutput) { + // Max internally-supported length for AES + std::vector sample(AES_BLOCK_SIZE); + std::vector mask(AES_BLOCK_SIZE + 1); + SSLMaskingContext *ctx_init = nullptr; + EXPECT_EQ(SECSuccess, + SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, + TLS_AES_128_GCM_SHA256, secret_.get(), + kLabel.c_str(), kLabel.size(), &ctx_init)); + EXPECT_EQ(0, PORT_GetError()); + ASSERT_NE(nullptr, ctx_init); + ScopedSSLMaskingContext ctx(ctx_init); + + EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), + mask.data(), mask.size())); + EXPECT_EQ(SEC_ERROR_OUTPUT_LEN, PORT_GetError()); +} + +TEST_F(MaskingTest, MaskShortOutput) { + std::vector sample(16); + std::vector mask(16); // Don't pass a null + + SSLMaskingContext *ctx_init = nullptr; + EXPECT_EQ(SECSuccess, + SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, + TLS_AES_128_GCM_SHA256, secret_.get(), + kLabel.c_str(), kLabel.size(), &ctx_init)); + EXPECT_EQ(0, PORT_GetError()); + ASSERT_NE(nullptr, ctx_init); + ScopedSSLMaskingContext ctx(ctx_init); + EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), + mask.data(), 0)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(MaskingTest, MaskRotateLabel) { + std::vector sample(AES_BLOCK_SIZE); + std::vector mask1(AES_BLOCK_SIZE); + std::vector mask2(AES_BLOCK_SIZE); + EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), + sample.size())); + + CreateMask(TLS_AES_128_GCM_SHA256, kLabel, sample, &mask1); + CreateMask(TLS_AES_128_GCM_SHA256, std::string("sn1"), sample, &mask2); + EXPECT_FALSE(mask1 == mask2); +} + +TEST_F(MaskingTest, MaskRotateSample) { + std::vector sample(AES_BLOCK_SIZE); + std::vector mask1(AES_BLOCK_SIZE); + std::vector mask2(AES_BLOCK_SIZE); + + EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), + sample.size())); + CreateMask(TLS_AES_128_GCM_SHA256, kLabel, sample, &mask1); + + EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), + sample.size())); + CreateMask(TLS_AES_128_GCM_SHA256, kLabel, sample, &mask2); + EXPECT_FALSE(mask1 == mask2); +} + +TEST_F(MaskingTest, MaskAesRederive) { + std::vector sample(AES_BLOCK_SIZE); + std::vector mask1(AES_BLOCK_SIZE); + std::vector mask2(AES_BLOCK_SIZE); + + SECStatus rv = + PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), sample.size()); + EXPECT_EQ(SECSuccess, rv); + + // Check that re-using inputs with a new context produces the same mask. + CreateMask(TLS_AES_128_GCM_SHA256, kLabel, sample, &mask1); + CreateMask(TLS_AES_128_GCM_SHA256, kLabel, sample, &mask2); + EXPECT_TRUE(mask1 == mask2); +} + +TEST_F(MaskingTest, MaskAesTooLong) { + std::vector sample(AES_BLOCK_SIZE + 1); + std::vector mask(AES_BLOCK_SIZE + 1); + SSLMaskingContext *ctx_init = nullptr; + EXPECT_EQ(SECSuccess, + SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, + TLS_AES_128_GCM_SHA256, secret_.get(), + kLabel.c_str(), kLabel.size(), &ctx_init)); + EXPECT_EQ(0, PORT_GetError()); + ASSERT_NE(nullptr, ctx_init); + ScopedSSLMaskingContext ctx(ctx_init); + EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), + mask.data(), mask.size())); + EXPECT_EQ(SEC_ERROR_OUTPUT_LEN, PORT_GetError()); +} + +TEST_F(MaskingTest, MaskAesShortSample) { + std::vector sample(AES_BLOCK_SIZE - 1); + std::vector mask(AES_BLOCK_SIZE - 1); + SSLMaskingContext *ctx_init = nullptr; + EXPECT_EQ(SECSuccess, + SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, + TLS_AES_128_GCM_SHA256, secret_.get(), + kLabel.c_str(), kLabel.size(), &ctx_init)); + EXPECT_EQ(0, PORT_GetError()); + ASSERT_NE(nullptr, ctx_init); + ScopedSSLMaskingContext ctx(ctx_init); + + EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), + mask.data(), mask.size())); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(MaskingTest, MaskAesShortValid) { + std::vector sample(AES_BLOCK_SIZE); + std::vector mask(1); + EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), + sample.size())); + CreateMask(TLS_AES_128_GCM_SHA256, kLabel, sample, &mask); +} + +TEST_F(MaskingTest, MaskChaChaRederive) { + // Block-aligned. + std::vector sample(32); + std::vector mask1(32); + std::vector mask2(32); + EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), + sample.size())); + CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask1); + CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask2); + EXPECT_TRUE(mask1 == mask2); +} + +TEST_F(MaskingTest, MaskChaChaRederiveOddSizes) { + // Non-block-aligned. + std::vector sample(27); + std::vector mask1(26); + std::vector mask2(25); + EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), + sample.size())); + CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask1); + CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask2); + mask1.pop_back(); + EXPECT_TRUE(mask1 == mask2); +} + +TEST_F(MaskingTest, MaskChaChaLongValid) { + // Max internally-supported length for ChaCha + std::vector sample(128); + std::vector mask(128); + EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), + sample.size())); + CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask); +} + +TEST_F(MaskingTest, MaskChaChaTooLong) { + // Max internally-supported length for ChaCha + std::vector sample(128 + 1); + std::vector mask(128 + 1); + SSLMaskingContext *ctx_init = nullptr; + EXPECT_EQ(SECSuccess, SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, + TLS_CHACHA20_POLY1305_SHA256, + secret_.get(), kLabel.c_str(), + kLabel.size(), &ctx_init)); + EXPECT_EQ(0, PORT_GetError()); + ASSERT_NE(nullptr, ctx_init); + ScopedSSLMaskingContext ctx(ctx_init); + EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), + mask.data(), mask.size())); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(MaskingTest, MaskChaChaShortSample) { + std::vector sample(15); // Should have 4B ctr, 12B nonce. + std::vector mask(15); + SSLMaskingContext *ctx_init = nullptr; + EXPECT_EQ(SECSuccess, SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, + TLS_CHACHA20_POLY1305_SHA256, + secret_.get(), kLabel.c_str(), + kLabel.size(), &ctx_init)); + EXPECT_EQ(0, PORT_GetError()); + ASSERT_NE(nullptr, ctx_init); + ScopedSSLMaskingContext ctx(ctx_init); + EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), + mask.data(), mask.size())); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(MaskingTest, MaskChaChaShortValid) { + std::vector sample(16); + std::vector mask(1); + EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), + sample.size())); + CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask); +} + +} // namespace nss_test diff --git a/gtests/ssl_gtest/ssl_record_unittest.cc b/gtests/ssl_gtest/ssl_record_unittest.cc index 86783b86e8..ca4fc96f86 100644 --- a/gtests/ssl_gtest/ssl_record_unittest.cc +++ b/gtests/ssl_gtest/ssl_record_unittest.cc @@ -185,8 +185,8 @@ TEST_F(TlsConnectStreamTls13, TooLargeRecord) { class ShortHeaderChecker : public PacketFilter { public: PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output) { - // The first octet should be 0b001xxxxx. - EXPECT_EQ(1, input.data()[0] >> 5); + // The first octet should be 0b001000xx. + EXPECT_EQ(kCtDtlsCiphertext, (input.data()[0] & ~0x3)); return KEEP; } }; @@ -205,6 +205,35 @@ TEST_F(TlsConnectDatagram13, ShortHeadersServer) { SendReceive(); } +// Send a DTLSCiphertext header with a 2B sequence number, and no length. +TEST_F(TlsConnectDatagram13, DtlsAlternateShortHeader) { + StartConnect(); + TlsSendCipherSpecCapturer capturer(client_); + Connect(); + SendReceive(50); + + uint8_t buf[] = {0x32, 0x33, 0x34}; + auto spec = capturer.spec(1); + ASSERT_NE(nullptr, spec.get()); + ASSERT_EQ(3, spec->epoch()); + + uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno; + TlsRecordHeader header(variant_, SSL_LIBRARY_VERSION_TLS_1_3, dtls13_ct, + 0x0003000000000001); + TlsRecordHeader out_header(header); + DataBuffer msg(buf, sizeof(buf)); + msg.Write(msg.len(), ssl_ct_application_data, 1); + DataBuffer ciphertext; + EXPECT_TRUE(spec->Protect(header, msg, &ciphertext, &out_header)); + + DataBuffer record; + auto rv = out_header.Write(&record, 0, ciphertext); + EXPECT_EQ(out_header.header_length() + ciphertext.len(), rv); + client_->SendDirect(record); + + server_->ReadBytes(3); +} + TEST_F(TlsConnectStreamTls13, UnencryptedFinishedMessage) { StartConnect(); client_->Handshake(); // Send ClientHello diff --git a/gtests/ssl_gtest/ssl_recordsize_unittest.cc b/gtests/ssl_gtest/ssl_recordsize_unittest.cc index f2003a3589..8926b5551e 100644 --- a/gtests/ssl_gtest/ssl_recordsize_unittest.cc +++ b/gtests/ssl_gtest/ssl_recordsize_unittest.cc @@ -19,7 +19,8 @@ namespace nss_test { // This class tracks the maximum size of record that was sent, both cleartext // and plain. It only tracks records that have an outer type of -// application_data. In TLS 1.3, this includes handshake messages. +// application_data or DTLSCiphertext. In TLS 1.3, this includes handshake +// messages. class TlsRecordMaximum : public TlsRecordFilter { public: TlsRecordMaximum(const std::shared_ptr& a) @@ -34,7 +35,7 @@ class TlsRecordMaximum : public TlsRecordFilter { DataBuffer* output) override { std::cerr << "max: " << record << std::endl; // Ignore unprotected packets. - if (header.content_type() != ssl_ct_application_data) { + if (!header.is_protected()) { return KEEP; } @@ -195,9 +196,23 @@ class TlsRecordExpander : public TlsRecordFilter { virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& data, DataBuffer* changed) { - if (header.content_type() != ssl_ct_application_data) { - return KEEP; + if (!header.is_protected()) { + // We're targeting application_data records. If the record is + // |!is_protected()|, we have two possibilities: + if (!decrypting()) { + // 1) We're not decrypting, in which this case this is truly an + // unencrypted record (Keep). + return KEEP; + } + if (header.content_type() != ssl_ct_application_data) { + // 2) We are decrypting, so is_protected() read the internal + // content_type. If the internal ct IS NOT application_data, then + // it's not our target (Keep). + return KEEP; + } + // Otherwise, the the internal ct IS application_data (Change). } + changed->Allocate(data.len() + expansion_); changed->Write(0, data.data(), data.len()); return CHANGE; @@ -261,30 +276,31 @@ class TlsRecordPadder : public TlsRecordFilter { PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& record, size_t* offset, DataBuffer* output) override { - if (header.content_type() != ssl_ct_application_data) { + if (!header.is_protected()) { return KEEP; } uint16_t protection_epoch; uint8_t inner_content_type; DataBuffer plaintext; + TlsRecordHeader out_header; if (!Unprotect(header, record, &protection_epoch, &inner_content_type, - &plaintext)) { + &plaintext, &out_header)) { return KEEP; } - if (inner_content_type != ssl_ct_application_data) { + if (decrypting() && inner_content_type != ssl_ct_application_data) { return KEEP; } DataBuffer ciphertext; - bool ok = Protect(spec(protection_epoch), header, inner_content_type, - plaintext, &ciphertext, padding_); + bool ok = Protect(spec(protection_epoch), out_header, inner_content_type, + plaintext, &ciphertext, &out_header, padding_); EXPECT_TRUE(ok); if (!ok) { return KEEP; } - *offset = header.Write(output, *offset, ciphertext); + *offset = out_header.Write(output, *offset, ciphertext); return CHANGE; } diff --git a/gtests/ssl_gtest/ssl_tls13compat_unittest.cc b/gtests/ssl_gtest/ssl_tls13compat_unittest.cc index ecb63d4764..6905ed0c0f 100644 --- a/gtests/ssl_gtest/ssl_tls13compat_unittest.cc +++ b/gtests/ssl_gtest/ssl_tls13compat_unittest.cc @@ -384,14 +384,16 @@ TEST_F(TlsConnectDatagram13, CompatModeDtlsClient) { ASSERT_EQ(2U, client_records->count()); // CH, Fin EXPECT_EQ(ssl_ct_handshake, client_records->record(0).header.content_type()); - EXPECT_EQ(ssl_ct_application_data, - client_records->record(1).header.content_type()); + EXPECT_EQ(kCtDtlsCiphertext, + (client_records->record(1).header.content_type() & + kCtDtlsCiphertextMask)); ASSERT_EQ(6U, server_records->count()); // SH, EE, CT, CV, Fin, Ack EXPECT_EQ(ssl_ct_handshake, server_records->record(0).header.content_type()); for (size_t i = 1; i < server_records->count(); ++i) { - EXPECT_EQ(ssl_ct_application_data, - server_records->record(i).header.content_type()); + EXPECT_EQ(kCtDtlsCiphertext, + (server_records->record(i).header.content_type() & + kCtDtlsCiphertextMask)); } } @@ -440,8 +442,9 @@ TEST_F(TlsConnectDatagram13, CompatModeDtlsServer) { ASSERT_EQ(5U, server_records->count()); // SH, EE, CT, CV, Fin EXPECT_EQ(ssl_ct_handshake, server_records->record(0).header.content_type()); for (size_t i = 1; i < server_records->count(); ++i) { - EXPECT_EQ(ssl_ct_application_data, - server_records->record(i).header.content_type()); + EXPECT_EQ(kCtDtlsCiphertext, + (server_records->record(i).header.content_type() & + kCtDtlsCiphertextMask)); } uint32_t session_id_len = 0; diff --git a/gtests/ssl_gtest/tls_agent.cc b/gtests/ssl_gtest/tls_agent.cc index 88640481e5..b52306961c 100644 --- a/gtests/ssl_gtest/tls_agent.cc +++ b/gtests/ssl_gtest/tls_agent.cc @@ -1064,21 +1064,28 @@ void TlsAgent::SendBuffer(const DataBuffer& buf) { bool TlsAgent::SendEncryptedRecord(const std::shared_ptr& spec, uint64_t seq, uint8_t ct, const DataBuffer& buf) { - LOGV("Encrypting " << buf.len() << " bytes"); // Ensure that we are doing TLS 1.3. EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3); - TlsRecordHeader header(variant_, expected_version_, ssl_ct_application_data, - seq); + if (variant_ != ssl_variant_datagram) { + ADD_FAILURE(); + return false; + } + + LOGV("Encrypting " << buf.len() << " bytes"); + uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno | + kCtDtlsCiphertextLengthPresent; + TlsRecordHeader header(variant_, expected_version_, dtls13_ct, seq); + TlsRecordHeader out_header(header); DataBuffer padded = buf; padded.Write(padded.len(), ct, 1); DataBuffer ciphertext; - if (!spec->Protect(header, padded, &ciphertext)) { + if (!spec->Protect(header, padded, &ciphertext, &out_header)) { return false; } DataBuffer record; - auto rv = header.Write(&record, 0, ciphertext); - EXPECT_EQ(header.header_length() + ciphertext.len(), rv); + auto rv = out_header.Write(&record, 0, ciphertext); + EXPECT_EQ(out_header.header_length() + ciphertext.len(), rv); SendDirect(record); return true; } @@ -1202,16 +1209,26 @@ void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type, uint16_t version, const uint8_t* buf, size_t len, DataBuffer* out, uint64_t sequence_number) { + // Fixup the content type for DTLSCiphertext + if (variant == ssl_variant_datagram && + version >= SSL_LIBRARY_VERSION_TLS_1_3 && + type == ssl_ct_application_data) { + type = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno | + kCtDtlsCiphertextLengthPresent; + } + size_t index = 0; - index = out->Write(index, type, 1); if (variant == ssl_variant_stream) { + index = out->Write(index, type, 1); index = out->Write(index, version, 2); } else if (version >= SSL_LIBRARY_VERSION_TLS_1_3 && - type == ssl_ct_application_data) { + (type & kCtDtlsCiphertextMask) == kCtDtlsCiphertext) { uint32_t epoch = (sequence_number >> 48) & 0x3; - uint32_t seqno = sequence_number & ((1ULL << 30) - 1); - index = out->Write(index, (epoch << 30) | seqno, 4); + index = out->Write(index, type | epoch, 1); + uint32_t seqno = sequence_number & ((1ULL << 16) - 1); + index = out->Write(index, seqno, 2); } else { + index = out->Write(index, type, 1); index = out->Write(index, TlsVersionToDtlsVersion(version), 2); index = out->Write(index, sequence_number >> 32, 4); index = out->Write(index, sequence_number & PR_UINT32_MAX, 4); diff --git a/gtests/ssl_gtest/tls_filter.cc b/gtests/ssl_gtest/tls_filter.cc index b2917274b4..d47ee71ab3 100644 --- a/gtests/ssl_gtest/tls_filter.cc +++ b/gtests/ssl_gtest/tls_filter.cc @@ -120,6 +120,10 @@ bool TlsRecordFilter::is_dtls13() const { info.canSendEarlyData; } +bool TlsRecordFilter::is_dtls13_ciphertext(uint8_t ct) const { + return is_dtls13() && (ct & kCtDtlsCiphertextMask) == kCtDtlsCiphertext; +} + // Gets the cipher spec that matches the specified epoch. TlsCipherSpec& TlsRecordFilter::spec(uint16_t write_epoch) { for (auto& sp : cipher_specs_) { @@ -196,23 +200,24 @@ PacketFilter::Action TlsRecordFilter::FilterRecord( uint8_t inner_content_type; DataBuffer plaintext; uint16_t protection_epoch = 0; + TlsRecordHeader out_header(header); if (!Unprotect(header, record, &protection_epoch, &inner_content_type, - &plaintext)) { + &plaintext, &out_header)) { std::cerr << agent()->role_str() << ": unprotect failed: " << header << ":" << record << std::endl; return KEEP; } auto& protection_spec = spec(protection_epoch); - TlsRecordHeader real_header(header.variant(), header.version(), - inner_content_type, header.sequence_number()); + TlsRecordHeader real_header(out_header.variant(), out_header.version(), + inner_content_type, out_header.sequence_number()); PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered); // In stream mode, even if something doesn't change we need to re-encrypt if // previous packets were dropped. if (action == KEEP) { - if (header.is_dtls() || !protection_spec.record_dropped()) { + if (out_header.is_dtls() || !protection_spec.record_dropped()) { // Count every outgoing packet. protection_spec.RecordProtected(); return KEEP; @@ -221,7 +226,7 @@ PacketFilter::Action TlsRecordFilter::FilterRecord( } if (action == DROP) { - std::cerr << "record drop: " << header << ":" << record << std::endl; + std::cerr << "record drop: " << out_header << ":" << record << std::endl; protection_spec.RecordDropped(); return DROP; } @@ -233,17 +238,15 @@ PacketFilter::Action TlsRecordFilter::FilterRecord( } uint64_t seq_num = protection_spec.next_out_seqno(); - if (!decrypting_ && header.is_dtls()) { + if (!decrypting_ && out_header.is_dtls()) { // Copy over the epoch, which isn't tracked when not decrypting. - seq_num |= header.sequence_number() & (0xffffULL << 48); + seq_num |= out_header.sequence_number() & (0xffffULL << 48); } - - TlsRecordHeader out_header(header.variant(), header.version(), - header.content_type(), seq_num); + out_header.sequence_number(seq_num); DataBuffer ciphertext; bool rv = Protect(protection_spec, out_header, inner_content_type, filtered, - &ciphertext); + &ciphertext, &out_header); if (!rv) { return KEEP; } @@ -262,19 +265,67 @@ size_t TlsRecordHeader::header_length() const { return WriteHeader(&buf, 0, 0); } -uint64_t TlsRecordHeader::RecoverSequenceNumber(uint64_t expected, +bool TlsRecordHeader::MaskSequenceNumber() { + return MaskSequenceNumber(sn_mask()); +} + +bool TlsRecordHeader::MaskSequenceNumber(const DataBuffer& mask) { + if (mask.empty()) { + return false; + } + + if (is_dtls13_ciphertext()) { + uint64_t seqno = sequence_number(); + uint8_t len = content_type() & kCtDtlsCiphertext16bSeqno ? 2 : 1; + uint16_t seqno_bitmask = (1 << len * 8) - 1; + DataBuffer val; + if (val.Write(0, seqno & seqno_bitmask, len) != len) { + return false; + } + + val.data()[0] ^= mask.data()[0]; + if (len == 2 && mask.len() > 1) { + val.data()[1] ^= mask.data()[1]; + } + + uint32_t tmp; + if (!val.Read(0, len, &tmp)) { + return false; + } + + seqno = (seqno & ~seqno_bitmask) | tmp; + seqno_is_masked_ = !seqno_is_masked_; + if (!seqno_is_masked_) { + seqno = ParseSequenceNumber(guess_seqno_, seqno, len * 8, 2); + } + sequence_number_ = seqno; + + // Now update the header bytes + if (header_.len() > 1) { + header_.data()[1] ^= mask.data()[0]; + if ((content_type() & kCtDtlsCiphertext16bSeqno) && header().len() > 2) { + header_.data()[2] ^= mask.data()[1]; + } + } + } + + sn_mask_ = mask; + return true; +} + +uint64_t TlsRecordHeader::RecoverSequenceNumber(uint64_t guess_seqno, uint32_t partial, size_t partial_bits) { EXPECT_GE(32U, partial_bits); uint64_t mask = (1ULL << partial_bits) - 1; // First we determine the highest possible value. This is half the - // expressible range above the expected value, less 1. + // expressible range above the expected value (|guess_seqno|), less 1. // // We subtract the extra 1 from the cap so that when given a choice between // the equidistant expected+N and expected-N we want to chose the lower. With // 0-RTT, we sometimes have to recover an epoch of 1 when we expect an epoch // of 3 and with 2 partial bits, the alternative result of 5 is wrong. - uint64_t cap = expected + (1ULL << (partial_bits - 1)) - 1; + uint64_t cap = guess_seqno + (1ULL << (partial_bits - 1)) - 1; // Add the partial piece in. e.g., xxxx789a and 1234 becomes xxxx1234. uint64_t seq_no = (cap & ~mask) | partial; // If the partial value is higher than the same partial piece from the cap, @@ -286,15 +337,18 @@ uint64_t TlsRecordHeader::RecoverSequenceNumber(uint64_t expected, } // Determine the full epoch and sequence number from an expected and raw value. -// The expected and output values are packed as they are in DTLS 1.2 and -// earlier: with 16 bits of epoch and 48 bits of sequence number. -uint64_t TlsRecordHeader::ParseSequenceNumber(uint64_t expected, uint32_t raw, +// The expected, raw, and output values are packed as they are in DTLS 1.2 and +// earlier: with 16 bits of epoch and 48 bits of sequence number. The raw value +// is packed this way (even before recovery) so that we don't need to track a +// moving value between two calls (one to recover the epoch, and one after +// unmasking to recover the sequence number). +uint64_t TlsRecordHeader::ParseSequenceNumber(uint64_t expected, uint64_t raw, size_t seq_no_bits, size_t epoch_bits) { uint64_t epoch_mask = (1ULL << epoch_bits) - 1; - uint64_t epoch = RecoverSequenceNumber( - expected >> 48, (raw >> seq_no_bits) & epoch_mask, epoch_bits); - if (epoch > (expected >> 48)) { + uint64_t ep = RecoverSequenceNumber(expected >> 48, (raw >> 48) & epoch_mask, + epoch_bits); + if (ep > (expected >> 48)) { // If the epoch has changed, reset the expected sequence number. expected = 0; } else { @@ -302,9 +356,12 @@ uint64_t TlsRecordHeader::ParseSequenceNumber(uint64_t expected, uint32_t raw, expected &= (1ULL << 48) - 1; } uint64_t seq_no_mask = (1ULL << seq_no_bits) - 1; - uint64_t seq_no = - RecoverSequenceNumber(expected, raw & seq_no_mask, seq_no_bits); - return (epoch << 48) | seq_no; + uint64_t seq_no = (raw & seq_no_mask); + if (!seqno_is_masked_) { + seq_no = RecoverSequenceNumber(expected, seq_no, seq_no_bits); + } + + return (ep << 48) | seq_no; } bool TlsRecordHeader::Parse(bool is_dtls13, uint64_t seqno, TlsParser* parser, @@ -320,38 +377,47 @@ bool TlsRecordHeader::Parse(bool is_dtls13, uint64_t seqno, TlsParser* parser, version_ = SSL_LIBRARY_VERSION_TLS_1_3; #ifndef UNSAFE_FUZZER_MODE - // Deal with the 7 octet header. - if (content_type_ == ssl_ct_application_data) { + // Deal with the DTLSCipherText header. + if (is_dtls13_ciphertext()) { + uint8_t seq_no_bytes = + (content_type_ & kCtDtlsCiphertext16bSeqno) ? 2 : 1; uint32_t tmp; - if (!parser->Read(&tmp, 4)) { - return false; - } - sequence_number_ = ParseSequenceNumber(seqno, tmp, 30, 2); - if (!parser->ReadFromMark(&header_, parser->consumed() + 2 - mark, - mark)) { + + if (!parser->Read(&tmp, seq_no_bytes)) { return false; } - return parser->ReadVariable(body, 2); - } - // The short, 2 octet header. - if ((content_type_ & 0xe0) == 0x20) { - uint32_t tmp; - if (!parser->Read(&tmp, 1)) { - return false; + // Store the guess if masked. If and when seqno_bytesenceNumber is called, + // the value will be unmasked and recovered. This assumes we only call + // Parse() on headers containing masked values. + seqno_is_masked_ = true; + guess_seqno_ = seqno; + uint64_t ep = content_type_ & 0x03; + sequence_number_ = (ep << 48) | tmp; + + // Recover the full epoch. Note the sequence number portion holds the + // masked value until a call to Mask() reveals it (as indicated by + // |seqno_is_masked_|). + sequence_number_ = + ParseSequenceNumber(seqno, sequence_number_, seq_no_bytes * 8, 2); + + uint32_t len_bytes = + (content_type_ & kCtDtlsCiphertextLengthPresent) ? 2 : 0; + if (len_bytes) { + if (!parser->Read(&tmp, 2)) { + return false; + } } - // Need to use the low 5 bits of the first octet too. - tmp |= (content_type_ & 0x1f) << 8; - content_type_ = ssl_ct_application_data; - sequence_number_ = ParseSequenceNumber(seqno, tmp, 12, 1); if (!parser->ReadFromMark(&header_, parser->consumed() - mark, mark)) { return false; } - return parser->Read(body, parser->remaining()); + + return len_bytes ? parser->Read(body, tmp) + : parser->Read(body, parser->remaining()); } - // The full 13 octet header can only be used for a few types. + // The full DTLSPlainText header can only be used for a few types. EXPECT_TRUE(content_type_ == ssl_ct_alert || content_type_ == ssl_ct_handshake || content_type_ == ssl_ct_ack); @@ -389,15 +455,20 @@ bool TlsRecordHeader::Parse(bool is_dtls13, uint64_t seqno, TlsParser* parser, size_t TlsRecordHeader::WriteHeader(DataBuffer* buffer, size_t offset, size_t body_len) const { - offset = buffer->Write(offset, content_type_, 1); - if (is_dtls() && version_ >= SSL_LIBRARY_VERSION_TLS_1_3 && - content_type() == ssl_ct_application_data) { + if (is_dtls13_ciphertext()) { + uint8_t seq_no_bytes = (content_type_ & kCtDtlsCiphertext16bSeqno) ? 2 : 1; // application_data records in TLS 1.3 have a different header format. - // Always use the long header here for simplicity. uint32_t e = (sequence_number_ >> 48) & 0x3; - uint32_t seqno = sequence_number_ & ((1ULL << 30) - 1); - offset = buffer->Write(offset, (e << 30) | seqno, 4); + uint32_t seqno = sequence_number_ & ((1ULL << seq_no_bytes * 8) - 1); + uint8_t new_content_type_ = content_type_ | e; + offset = buffer->Write(offset, new_content_type_, 1); + offset = buffer->Write(offset, seqno, seq_no_bytes); + + if (content_type_ & kCtDtlsCiphertextLengthPresent) { + offset = buffer->Write(offset, body_len, 2); + } } else { + offset = buffer->Write(offset, content_type_, 1); uint16_t v = is_dtls() ? TlsVersionToDtlsVersion(version_) : version_; offset = buffer->Write(offset, v, 2); if (is_dtls()) { @@ -405,8 +476,9 @@ size_t TlsRecordHeader::WriteHeader(DataBuffer* buffer, size_t offset, offset = buffer->Write(offset, sequence_number_ >> 32, 4); offset = buffer->Write(offset, sequence_number_ & 0xffffffff, 4); } + offset = buffer->Write(offset, body_len, 2); } - offset = buffer->Write(offset, body_len, 2); + return offset; } @@ -421,8 +493,9 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header, const DataBuffer& ciphertext, uint16_t* protection_epoch, uint8_t* inner_content_type, - DataBuffer* plaintext) { - if (!decrypting_ || header.content_type() != ssl_ct_application_data) { + DataBuffer* plaintext, + TlsRecordHeader* out_header) { + if (!decrypting_ || !header.is_protected()) { // Maintain the epoch and sequence number for plaintext records. uint16_t ep = 0; if (agent()->variant() == ssl_variant_datagram) { @@ -438,7 +511,7 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header, uint16_t ep = 0; if (agent()->variant() == ssl_variant_datagram) { ep = static_cast(header.sequence_number() >> 48); - if (!spec(ep).Unprotect(header, ciphertext, plaintext)) { + if (!spec(ep).Unprotect(header, ciphertext, plaintext, out_header)) { return false; } } else { @@ -446,7 +519,8 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header, // can't just use the newest keys because the same flight of messages can // contain multiple epochs. So... trial decrypt! for (size_t i = cipher_specs_.size() - 1; i > 0; --i) { - if (cipher_specs_[i].Unprotect(header, ciphertext, plaintext)) { + if (cipher_specs_[i].Unprotect(header, ciphertext, plaintext, + out_header)) { ep = cipher_specs_[i].epoch(); break; } @@ -481,7 +555,8 @@ bool TlsRecordFilter::Protect(TlsCipherSpec& protection_spec, const TlsRecordHeader& header, uint8_t inner_content_type, const DataBuffer& plaintext, - DataBuffer* ciphertext, size_t padding) { + DataBuffer* ciphertext, + TlsRecordHeader* out_header, size_t padding) { if (!protection_spec.is_protected()) { // Not protected, just keep the sequence numbers updated. protection_spec.RecordProtected(); @@ -494,7 +569,7 @@ bool TlsRecordFilter::Protect(TlsCipherSpec& protection_spec, size_t offset = padded.Write(0, plaintext.data(), plaintext.len()); padded.Write(offset, inner_content_type, 1); - bool ok = protection_spec.Protect(header, padded, ciphertext); + bool ok = protection_spec.Protect(header, padded, ciphertext, out_header); if (!ok) { ADD_FAILURE() << "protect fail"; } else if (g_ssl_gtest_verbose) { diff --git a/gtests/ssl_gtest/tls_filter.h b/gtests/ssl_gtest/tls_filter.h index 64ee71c890..8cf558f9c5 100644 --- a/gtests/ssl_gtest/tls_filter.h +++ b/gtests/ssl_gtest/tls_filter.h @@ -12,6 +12,7 @@ #include #include #include "sslt.h" +#include "sslproto.h" #include "test_io.h" #include "tls_agent.h" #include "tls_parser.h" @@ -25,6 +26,59 @@ namespace nss_test { class TlsCipherSpec; +class TlsSendCipherSpecCapturer { + public: + TlsSendCipherSpecCapturer(const std::shared_ptr& agent) + : agent_(agent), send_cipher_specs_() { + EXPECT_EQ(SECSuccess, + SSL_SecretCallback(agent_->ssl_fd(), SecretCallback, this)); + } + + std::shared_ptr spec(size_t i) { + if (i >= send_cipher_specs_.size()) { + return nullptr; + } + return send_cipher_specs_[i]; + } + + private: + static void SecretCallback(PRFileDesc* fd, PRUint16 epoch, + SSLSecretDirection dir, PK11SymKey* secret, + void* arg) { + auto self = static_cast(arg); + std::cerr << self->agent_->role_str() << ": capture " << dir + << " secret for epoch " << epoch << std::endl; + + if (dir == ssl_secret_read) { + return; + } + + SSLPreliminaryChannelInfo preinfo; + EXPECT_EQ(SECSuccess, + SSL_GetPreliminaryChannelInfo(self->agent_->ssl_fd(), &preinfo, + sizeof(preinfo))); + EXPECT_EQ(sizeof(preinfo), preinfo.length); + EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_cipher_suite); + + // Check the version: + EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_version); + ASSERT_GE(SSL_LIBRARY_VERSION_TLS_1_3, preinfo.protocolVersion); + + SSLCipherSuiteInfo cipherinfo; + EXPECT_EQ(SECSuccess, + SSL_GetCipherSuiteInfo(preinfo.cipherSuite, &cipherinfo, + sizeof(cipherinfo))); + EXPECT_EQ(sizeof(cipherinfo), cipherinfo.length); + + auto spec = std::make_shared(true, epoch); + EXPECT_TRUE(spec->SetKeys(&cipherinfo, secret)); + self->send_cipher_specs_.push_back(spec); + } + + std::shared_ptr agent_; + std::vector> send_cipher_specs_; +}; + class TlsVersioned { public: TlsVersioned() : variant_(ssl_variant_stream), version_(0) {} @@ -45,22 +99,57 @@ class TlsVersioned { class TlsRecordHeader : public TlsVersioned { public: TlsRecordHeader() - : TlsVersioned(), content_type_(0), sequence_number_(0), header_() {} + : TlsVersioned(), + content_type_(0), + guess_seqno_(0), + seqno_is_masked_(false), + sequence_number_(0), + header_() {} TlsRecordHeader(SSLProtocolVariant var, uint16_t ver, uint8_t ct, uint64_t seqno) : TlsVersioned(var, ver), content_type_(ct), + guess_seqno_(0), + seqno_is_masked_(false), sequence_number_(seqno), - header_() {} + header_(), + sn_mask_() {} + + bool is_protected() const { + // *TLS < 1.3 + if (version() < SSL_LIBRARY_VERSION_TLS_1_3 && + content_type() == ssl_ct_application_data) { + return true; + } + + // TLS 1.3 + if (!is_dtls() && version() >= SSL_LIBRARY_VERSION_TLS_1_3 && + content_type() == ssl_ct_application_data) { + return true; + } + + // DTLS 1.3 + return is_dtls13_ciphertext(); + } uint8_t content_type() const { return content_type_; } - uint64_t sequence_number() const { return sequence_number_; } uint16_t epoch() const { return static_cast(sequence_number_ >> 48); } + uint64_t sequence_number() const { return sequence_number_; } + void sequence_number(uint64_t seqno) { sequence_number_ = seqno; } + const DataBuffer& sn_mask() const { return sn_mask_; } + bool is_dtls13_ciphertext() const { + return is_dtls() && (version() >= SSL_LIBRARY_VERSION_TLS_1_3) && + (content_type() & kCtDtlsCiphertextMask) == kCtDtlsCiphertext; + } + size_t header_length() const; const DataBuffer& header() const { return header_; } + bool MaskSequenceNumber(); + bool MaskSequenceNumber(const DataBuffer& mask); + // Parse the header; return true if successful; body in an outparam if OK. bool Parse(bool is_dtls13, uint64_t sequence_number, TlsParser* parser, DataBuffer* body); @@ -70,14 +159,17 @@ class TlsRecordHeader : public TlsVersioned { size_t WriteHeader(DataBuffer* buffer, size_t offset, size_t body_len) const; private: - static uint64_t RecoverSequenceNumber(uint64_t expected, uint32_t partial, + static uint64_t RecoverSequenceNumber(uint64_t guess_seqno, uint32_t partial, size_t partial_bits); - static uint64_t ParseSequenceNumber(uint64_t expected, uint32_t raw, - size_t seq_no_bits, size_t epoch_bits); + uint64_t ParseSequenceNumber(uint64_t expected, uint64_t raw, + size_t seq_no_bits, size_t epoch_bits); uint8_t content_type_; + uint64_t guess_seqno_; + bool seqno_is_masked_; uint64_t sequence_number_; DataBuffer header_; + DataBuffer sn_mask_; }; struct TlsRecord { @@ -111,12 +203,14 @@ class TlsRecordFilter : public PacketFilter { // Enabling it for lower version tests will cause undefined // behavior. void EnableDecryption(); + bool decrypting() const { return decrypting_; }; bool Unprotect(const TlsRecordHeader& header, const DataBuffer& cipherText, uint16_t* protection_epoch, uint8_t* inner_content_type, - DataBuffer* plaintext); + DataBuffer* plaintext, TlsRecordHeader* out_header); bool Protect(TlsCipherSpec& protection_spec, const TlsRecordHeader& header, uint8_t inner_content_type, const DataBuffer& plaintext, - DataBuffer* ciphertext, size_t padding = 0); + DataBuffer* ciphertext, TlsRecordHeader* out_header, + size_t padding = 0); protected: // There are two filter functions which can be overriden. Both are @@ -141,6 +235,7 @@ class TlsRecordFilter : public PacketFilter { } bool is_dtls13() const; + bool is_dtls13_ciphertext(uint8_t ct) const; TlsCipherSpec& spec(uint16_t epoch); private: @@ -471,8 +566,9 @@ class TlsEncryptedHandshakeMessageReplacer : public TlsRecordFilter { uint16_t protection_epoch = 0; uint8_t inner_content_type; DataBuffer plaintext; + TlsRecordHeader out_header; if (!Unprotect(header, record, &protection_epoch, &inner_content_type, - &plaintext) || + &plaintext, &out_header) || !plaintext.len()) { return KEEP; } @@ -501,12 +597,12 @@ class TlsEncryptedHandshakeMessageReplacer : public TlsRecordFilter { } DataBuffer ciphertext; - bool ok = Protect(spec(protection_epoch), header, inner_content_type, - plaintext, &ciphertext, 0); + bool ok = Protect(spec(protection_epoch), out_header, inner_content_type, + plaintext, &ciphertext, &out_header); if (!ok) { return KEEP; } - *offset = header.Write(output, *offset, ciphertext); + *offset = out_header.Write(output, *offset, ciphertext); return CHANGE; } diff --git a/gtests/ssl_gtest/tls_protect.cc b/gtests/ssl_gtest/tls_protect.cc index de91982f74..7737fe5eaa 100644 --- a/gtests/ssl_gtest/tls_protect.cc +++ b/gtests/ssl_gtest/tls_protect.cc @@ -25,39 +25,66 @@ TlsCipherSpec::TlsCipherSpec(bool dtls, uint16_t epoc) bool TlsCipherSpec::SetKeys(SSLCipherSuiteInfo* cipherinfo, PK11SymKey* secret) { - SSLAeadContext* ctx; + SSLAeadContext* aead_ctx; SECStatus rv = SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, cipherinfo->cipherSuite, secret, "", 0, // Use the default labels. - &ctx); + &aead_ctx); if (rv != SECSuccess) { return false; } - aead_.reset(ctx); + aead_.reset(aead_ctx); + + SSLMaskingContext* mask_ctx; + const char kHkdfPurposeSn[] = "sn"; + rv = SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, + cipherinfo->cipherSuite, secret, kHkdfPurposeSn, + strlen(kHkdfPurposeSn), &mask_ctx); + if (rv != SECSuccess) { + return false; + } + mask_.reset(mask_ctx); return true; } bool TlsCipherSpec::Unprotect(const TlsRecordHeader& header, const DataBuffer& ciphertext, - DataBuffer* plaintext) { - if (aead_ == nullptr) { + DataBuffer* plaintext, + TlsRecordHeader* out_header) { + if (!aead_ || !out_header) { return false; } + *out_header = header; + // Make space. plaintext->Allocate(ciphertext.len()); - auto header_bytes = header.header(); unsigned int len; - uint64_t seqno; - if (dtls_) { - seqno = header.sequence_number(); - } else { - seqno = in_seqno_; + uint64_t seqno = dtls_ ? header.sequence_number() : in_seqno_; + SECStatus rv; + + if (header.is_dtls13_ciphertext()) { + if (!mask_ || !out_header) { + return false; + } + PORT_Assert(ciphertext.len() >= 16); + DataBuffer mask(2); + rv = SSL_CreateMask(mask_.get(), ciphertext.data(), ciphertext.len(), + mask.data(), mask.len()); + if (rv != SECSuccess) { + return false; + } + + if (!out_header->MaskSequenceNumber(mask)) { + return false; + } + seqno = out_header->sequence_number(); } - SECStatus rv = - SSL_AeadDecrypt(aead_.get(), seqno, header_bytes.data(), - header_bytes.len(), ciphertext.data(), ciphertext.len(), - plaintext->data(), &len, plaintext->len()); + + auto header_bytes = out_header->header(); + rv = SSL_AeadDecrypt(aead_.get(), seqno, header_bytes.data(), + header_bytes.len(), ciphertext.data(), ciphertext.len(), + plaintext->data(), &len, plaintext->len()); if (rv != SECSuccess) { return false; } @@ -69,11 +96,14 @@ bool TlsCipherSpec::Unprotect(const TlsRecordHeader& header, } bool TlsCipherSpec::Protect(const TlsRecordHeader& header, - const DataBuffer& plaintext, - DataBuffer* ciphertext) { - if (aead_ == nullptr) { + const DataBuffer& plaintext, DataBuffer* ciphertext, + TlsRecordHeader* out_header) { + if (!aead_ || !out_header) { return false; } + + *out_header = header; + // Make a padded buffer. ciphertext->Allocate(plaintext.len() + 32); // Room for any plausible auth tag @@ -81,12 +111,7 @@ bool TlsCipherSpec::Protect(const TlsRecordHeader& header, DataBuffer header_bytes; (void)header.WriteHeader(&header_bytes, 0, plaintext.len() + 16); - uint64_t seqno; - if (dtls_) { - seqno = header.sequence_number(); - } else { - seqno = out_seqno_; - } + uint64_t seqno = dtls_ ? header.sequence_number() : out_seqno_; SECStatus rv = SSL_AeadEncrypt(aead_.get(), seqno, header_bytes.data(), @@ -96,6 +121,22 @@ bool TlsCipherSpec::Protect(const TlsRecordHeader& header, return false; } + if (header.is_dtls13_ciphertext()) { + if (!mask_ || !out_header) { + return false; + } + PORT_Assert(ciphertext->len() >= 16); + DataBuffer mask(2); + rv = SSL_CreateMask(mask_.get(), ciphertext->data(), ciphertext->len(), + mask.data(), mask.len()); + if (rv != SECSuccess) { + return false; + } + if (!out_header->MaskSequenceNumber(mask)) { + return false; + } + } + RecordProtected(); ciphertext->Truncate(len); diff --git a/gtests/ssl_gtest/tls_protect.h b/gtests/ssl_gtest/tls_protect.h index b1febf8870..d7ea2aa128 100644 --- a/gtests/ssl_gtest/tls_protect.h +++ b/gtests/ssl_gtest/tls_protect.h @@ -27,9 +27,9 @@ class TlsCipherSpec { bool SetKeys(SSLCipherSuiteInfo* cipherinfo, PK11SymKey* secret); bool Protect(const TlsRecordHeader& header, const DataBuffer& plaintext, - DataBuffer* ciphertext); + DataBuffer* ciphertext, TlsRecordHeader* out_header); bool Unprotect(const TlsRecordHeader& header, const DataBuffer& ciphertext, - DataBuffer* plaintext); + DataBuffer* plaintext, TlsRecordHeader* out_header); uint16_t epoch() const { return epoch_; } uint64_t next_in_seqno() const { return in_seqno_; } @@ -52,6 +52,7 @@ class TlsCipherSpec { uint64_t out_seqno_; bool record_dropped_ = false; ScopedSSLAeadContext aead_; + ScopedSSLMaskingContext mask_; }; } // namespace nss_test diff --git a/lib/ssl/dtls13con.c b/lib/ssl/dtls13con.c index 0c4fc7fcd1..c87e0907a5 100644 --- a/lib/ssl/dtls13con.c +++ b/lib/ssl/dtls13con.c @@ -10,38 +10,52 @@ #include "ssl.h" #include "sslimpl.h" #include "sslproto.h" +#include "keyhi.h" +#include "pk11func.h" +/* + * 0 1 2 3 4 5 6 7 + * +-+-+-+-+-+-+-+-+ + * |0|0|1|C|S|L|E E| + * +-+-+-+-+-+-+-+-+ + * | Connection ID | Legend: + * | (if any, | + * / length as / C - CID present + * | negotiated) | S - Sequence number length + * +-+-+-+-+-+-+-+-+ L - Length present + * | 8 or 16 bit | E - Epoch + * |Sequence Number| + * +-+-+-+-+-+-+-+-+ + * | 16 bit Length | + * | (if present) | + * +-+-+-+-+-+-+-+-+ + */ SECStatus -dtls13_InsertCipherTextHeader(const sslSocket *ss, ssl3CipherSpec *cwSpec, +dtls13_InsertCipherTextHeader(const sslSocket *ss, const ssl3CipherSpec *cwSpec, sslBuffer *wrBuf, PRBool *needsLength) { - PRUint32 seq; - SECStatus rv; - /* Avoid using short records for the handshake. We pack multiple records * into the one datagram for the handshake. */ if (ss->opt.enableDtlsShortHeader && - cwSpec->epoch != TrafficKeyHandshake) { + cwSpec->epoch > TrafficKeyHandshake) { *needsLength = PR_FALSE; /* The short header is comprised of two octets in the form - * 0b001essssssssssss where 'e' is the low bit of the epoch and 's' is - * the low 12 bits of the sequence number. */ - seq = 0x2000 | - (((uint64_t)cwSpec->epoch & 1) << 12) | - (cwSpec->nextSeqNum & 0xfff); - return sslBuffer_AppendNumber(wrBuf, seq, 2); + * 0b001000eessssssss where 'e' is the low two bits of the + * epoch and 's' is the low 8 bits of the sequence number. */ + PRUint8 ct = 0x20 | ((uint64_t)cwSpec->epoch & 0x3); + if (sslBuffer_AppendNumber(wrBuf, ct, 1) != SECSuccess) { + return SECFailure; + } + PRUint8 seq = cwSpec->nextSeqNum & 0xff; + return sslBuffer_AppendNumber(wrBuf, seq, 1); } - rv = sslBuffer_AppendNumber(wrBuf, ssl_ct_application_data, 1); - if (rv != SECSuccess) { + PRUint8 ct = 0x2c | ((PRUint8)cwSpec->epoch & 0x3); + if (sslBuffer_AppendNumber(wrBuf, ct, 1) != SECSuccess) { return SECFailure; } - - /* The epoch and sequence number are encoded on 4 octets, with the epoch - * consuming the first two bits. */ - seq = (((uint64_t)cwSpec->epoch & 3) << 30) | (cwSpec->nextSeqNum & 0x3fffffff); - rv = sslBuffer_AppendNumber(wrBuf, seq, 4); - if (rv != SECSuccess) { + if (sslBuffer_AppendNumber(wrBuf, + (cwSpec->nextSeqNum & 0xffff), 2) != SECSuccess) { return SECFailure; } *needsLength = PR_TRUE; @@ -512,3 +526,43 @@ dtls13_HolddownTimerCb(sslSocket *ss) ssl_CipherSpecReleaseByEpoch(ss, ssl_secret_read, TrafficKeyHandshake); ssl_ClearPRCList(&ss->ssl3.hs.dtlsRcvdHandshake, NULL); } + +SECStatus +dtls13_MaskSequenceNumber(sslSocket *ss, ssl3CipherSpec *spec, + PRUint8 *hdr, PRUint8 *cipherText, PRUint32 cipherTextLen) +{ + PORT_Assert(IS_DTLS(ss)); + if (spec->version < SSL_LIBRARY_VERSION_TLS_1_3) { + return SECSuccess; + } + + if (spec->maskContext) { + PRUint8 mask[2]; + SECStatus rv = ssl_CreateMaskInner(spec->maskContext, cipherText, cipherTextLen, mask, sizeof(mask)); + + if (rv != SECSuccess) { + return SECFailure; + } + + hdr[1] ^= mask[0]; + if (hdr[0] & 0x08) { + hdr[2] ^= mask[1]; + } + } + + return SECSuccess; +} + +CK_MECHANISM_TYPE +tls13_SequenceNumberEncryptionMechanism(SSLCipherAlgorithm bulkAlgorithm) +{ + switch (bulkAlgorithm) { + case ssl_calg_aes_gcm: + return CKM_AES_ECB; + case ssl_calg_chacha20: + return CKM_NSS_CHACHA20_CTR; + default: + PORT_Assert(PR_FALSE); + } + return CKM_INVALID_MECHANISM; +} diff --git a/lib/ssl/dtls13con.h b/lib/ssl/dtls13con.h index ce92a8a55b..057d63efb4 100644 --- a/lib/ssl/dtls13con.h +++ b/lib/ssl/dtls13con.h @@ -10,7 +10,7 @@ #define __dtls13con_h_ SECStatus dtls13_InsertCipherTextHeader(const sslSocket *ss, - ssl3CipherSpec *cwSpec, + const ssl3CipherSpec *cwSpec, sslBuffer *wrBuf, PRBool *needsLength); SECStatus dtls13_RememberFragment(sslSocket *ss, PRCList *list, @@ -29,5 +29,9 @@ SECStatus dtls13_SendAck(sslSocket *ss); void dtls13_SendAckCb(sslSocket *ss); void dtls13_HolddownTimerCb(sslSocket *ss); void dtls_ReceivedFirstMessageInFlight(sslSocket *ss); +SECStatus dtls13_MaskSequenceNumber(sslSocket *ss, ssl3CipherSpec *spec, + PRUint8 *hdr, PRUint8 *cipherText, PRUint32 cipherTextLen); + +CK_MECHANISM_TYPE tls13_SequenceNumberEncryptionMechanism(SSLCipherAlgorithm bulkAlgorithm); #endif diff --git a/lib/ssl/dtlscon.c b/lib/ssl/dtlscon.c index 9417063f12..ae84b81d9e 100644 --- a/lib/ssl/dtlscon.c +++ b/lib/ssl/dtlscon.c @@ -1335,6 +1335,14 @@ dtls_IsLongHeader(SSL3ProtocolVersion version, PRUint8 firstOctet) #endif } +PRBool +dtls_IsDtls13Ciphertext(SSL3ProtocolVersion version, PRUint8 firstOctet) +{ + // Allow no version in case we haven't negotiated one yet. + return (version == 0 || version >= SSL_LIBRARY_VERSION_TLS_1_3) && + (firstOctet & 0xe0) == 0x20; +} + DTLSEpoch dtls_ReadEpoch(const ssl3CipherSpec *crSpec, const PRUint8 *hdr) { @@ -1349,13 +1357,12 @@ dtls_ReadEpoch(const ssl3CipherSpec *crSpec, const PRUint8 *hdr) /* A lot of how we recover the epoch here will depend on how we plan to * manage KeyUpdate. In the case that we decide to install a new read spec * as a KeyUpdate is handled, crSpec will always be the highest epoch we can - * possibly receive. That makes this easier to manage. */ - if ((hdr[0] & 0xe0) == 0x20) { + * possibly receive. That makes this easier to manage. + */ + if (dtls_IsDtls13Ciphertext(crSpec->version, hdr[0])) { + /* TODO(ekr@rtfm.com: do something with the two-bit epoch. */ /* Use crSpec->epoch, or crSpec->epoch - 1 if the last bit differs. */ - if (((hdr[0] >> 4) & 1) == (crSpec->epoch & 1)) { - return crSpec->epoch; - } - return crSpec->epoch - 1; + return crSpec->epoch - ((hdr[0] ^ crSpec->epoch) & 0x3); } /* dtls_GatherData should ensure that this works. */ @@ -1398,20 +1405,15 @@ dtls_ReadSequenceNumber(const ssl3CipherSpec *spec, const PRUint8 *hdr) * sequence number is replaced. If that causes the value to exceed the * maximum, subtract an entire range. */ - if ((hdr[0] & 0xe0) == 0x20) { - /* A 12-bit sequence number. */ - cap = spec->nextSeqNum + (1ULL << 11); - partial = (((sslSequenceNumber)hdr[0] & 0xf) << 8) | - (sslSequenceNumber)hdr[1]; - mask = (1ULL << 12) - 1; + if (hdr[0] & 0x08) { + cap = spec->nextSeqNum + (1ULL << 15); + partial = (((sslSequenceNumber)hdr[1]) << 8) | + (sslSequenceNumber)hdr[2]; + mask = (1ULL << 16) - 1; } else { - /* A 30-bit sequence number. */ - cap = spec->nextSeqNum + (1ULL << 29); - partial = (((sslSequenceNumber)hdr[1] & 0x3f) << 24) | - ((sslSequenceNumber)hdr[2] << 16) | - ((sslSequenceNumber)hdr[3] << 8) | - (sslSequenceNumber)hdr[4]; - mask = (1ULL << 30) - 1; + cap = spec->nextSeqNum + (1ULL << 7); + partial = (sslSequenceNumber)hdr[1]; + mask = (1ULL << 8) - 1; } seqNum = (cap & ~mask) | partial; /* The second check prevents the value from underflowing if we get a large diff --git a/lib/ssl/dtlscon.h b/lib/ssl/dtlscon.h index 4ede3c2ca9..9d10aa248d 100644 --- a/lib/ssl/dtlscon.h +++ b/lib/ssl/dtlscon.h @@ -47,4 +47,5 @@ extern PRBool dtls_IsRelevant(sslSocket *ss, const ssl3CipherSpec *spec, sslSequenceNumber *seqNum); void dtls_ReceivedFirstMessageInFlight(sslSocket *ss); PRBool dtls_IsLongHeader(SSL3ProtocolVersion version, PRUint8 firstOctet); +PRBool dtls_IsDtls13Ciphertext(SSL3ProtocolVersion version, PRUint8 firstOctet); #endif diff --git a/lib/ssl/ssl3con.c b/lib/ssl/ssl3con.c index 60b247fd7c..e8ea99d827 100644 --- a/lib/ssl/ssl3con.c +++ b/lib/ssl/ssl3con.c @@ -2406,7 +2406,6 @@ ssl_ProtectRecord(sslSocket *ss, ssl3CipherSpec *cwSpec, SSLContentType ct, PORT_Assert(cwSpec->cipherDef->max_records <= RECORD_SEQ_MAX); if (cwSpec->nextSeqNum >= cwSpec->cipherDef->max_records) { - /* We should have automatically updated before here in TLS 1.3. */ PORT_Assert(cwSpec->version < SSL_LIBRARY_VERSION_TLS_1_3); SSL_TRC(3, ("%d: SSL[-]: write sequence number at limit 0x%0llx", SSL_GETPID(), cwSpec->nextSeqNum)); @@ -2438,7 +2437,28 @@ ssl_ProtectRecord(sslSocket *ss, ssl3CipherSpec *cwSpec, SSLContentType ct, } #else if (cwSpec->version >= SSL_LIBRARY_VERSION_TLS_1_3) { + PRUint8 *cipherText = SSL_BUFFER_NEXT(wrBuf); + unsigned int bufLen = SSL_BUFFER_LEN(wrBuf); rv = tls13_ProtectRecord(ss, cwSpec, ct, pIn, contentLen, wrBuf); + if (rv != SECSuccess) { + return SECFailure; + } + if (IS_DTLS(ss)) { + bufLen = SSL_BUFFER_LEN(wrBuf) - bufLen; +#ifdef UNSAFE_FUZZER_MODE + /* The null cipher doesn't add a tag. Make sure the "ciphertext" + * is long enough for mask creation. */ + unsigned char tmpCt[AES_BLOCK_SIZE] = { 0 }; + if (bufLen < 16) { + memcpy(tmpCt, cipherText, bufLen); + bufLen = sizeof(tmpCt); + cipherText = tmpCt; + } +#endif + rv = dtls13_MaskSequenceNumber(ss, cwSpec, + SSL_BUFFER_BASE(wrBuf), + cipherText, bufLen); + } } else { rv = ssl3_MACEncryptRecord(cwSpec, ss->sec.isServer, IS_DTLS(ss), ct, pIn, contentLen, wrBuf); @@ -12899,6 +12919,24 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText) } isTLS = (PRBool)(spec->version > SSL_LIBRARY_VERSION_3_0); if (IS_DTLS(ss)) { + unsigned int bufLen = SSL_BUFFER_LEN(cText->buf); + unsigned char *cipherText = SSL_BUFFER_BASE(cText->buf); +#ifdef UNSAFE_FUZZER_MODE + /* The null cipher doesn't add a tag. Make sure the "ciphertext" + * is long enough for mask creation. */ + unsigned char tmpCt[AES_BLOCK_SIZE] = { 0 }; + if (bufLen < 16) { + memcpy(tmpCt, cipherText, bufLen); + bufLen = sizeof(tmpCt); + cipherText = tmpCt; + } +#endif + if (dtls13_MaskSequenceNumber(ss, spec, cText->hdr, + cipherText, bufLen) != SECSuccess) { + ssl_ReleaseSpecReadLock(ss); /*****************************/ + PORT_SetError(SSL_ERROR_DECRYPTION_FAILURE); + return SECFailure; + } if (!dtls_IsRelevant(ss, spec, cText, &cText->seqNum)) { ssl_ReleaseSpecReadLock(ss); /*****************************/ return SECSuccess; @@ -12940,7 +12978,10 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText) /* Encrypted application data records could arrive before the handshake * completes in DTLS 1.3. These can look like valid TLS 1.2 application_data * records in epoch 0, which is never valid. Pretend they didn't decrypt. */ - if (spec->epoch == 0 && rType == ssl_ct_application_data) { + + if (spec->epoch == 0 && ((IS_DTLS(ss) && + dtls_IsDtls13Ciphertext(0, rType)) || + rType == ssl_ct_application_data)) { PORT_SetError(SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA); alert = unexpected_message; rv = SECFailure; diff --git a/lib/ssl/ssl3gthr.c b/lib/ssl/ssl3gthr.c index f9c741746f..3bc6e8edcb 100644 --- a/lib/ssl/ssl3gthr.c +++ b/lib/ssl/ssl3gthr.c @@ -268,6 +268,7 @@ dtls_GatherData(sslSocket *ss, sslGather *gs, int flags) PRUint8 contentType; unsigned int headerLen; SECStatus rv; + PRBool dtlsLengthPresent = PR_TRUE; SSL_TRC(30, ("dtls_GatherData")); @@ -316,8 +317,20 @@ dtls_GatherData(sslSocket *ss, sslGather *gs, int flags) headerLen = 13; } else if (contentType == ssl_ct_application_data) { headerLen = 7; - } else if ((contentType & 0xe0) == 0x20) { - headerLen = 2; + } else if (dtls_IsDtls13Ciphertext(ss->version, contentType)) { + /* We don't support CIDs. */ + if (contentType & 0x10) { + PORT_Assert(PR_FALSE); + PORT_SetError(SSL_ERROR_RX_UNKNOWN_RECORD_TYPE); + gs->dtlsPacketOffset = 0; + gs->dtlsPacket.len = 0; + return -1; + } + + dtlsLengthPresent = (contentType & 0x04) == 0x04; + PRUint8 dtlsSeqNoSize = (contentType & 0x08) ? 2 : 1; + PRUint8 dtlsLengthBytes = dtlsLengthPresent ? 2 : 0; + headerLen = 1 + dtlsSeqNoSize + dtlsLengthBytes; } else { SSL_DBG(("%d: SSL3[%d]: invalid first octet (%d) for DTLS", SSL_GETPID(), ss->fd, contentType)); @@ -345,12 +358,10 @@ dtls_GatherData(sslSocket *ss, sslGather *gs, int flags) gs->dtlsPacketOffset += headerLen; /* Have received SSL3 record header in gs->hdr. */ - if (headerLen == 13) { - gs->remainder = (gs->hdr[11] << 8) | gs->hdr[12]; - } else if (headerLen == 7) { - gs->remainder = (gs->hdr[5] << 8) | gs->hdr[6]; + if (dtlsLengthPresent) { + gs->remainder = (gs->hdr[headerLen - 2] << 8) | + gs->hdr[headerLen - 1]; } else { - PORT_Assert(headerLen == 2); gs->remainder = gs->dtlsPacket.len - gs->dtlsPacketOffset; } diff --git a/lib/ssl/ssl3prot.h b/lib/ssl/ssl3prot.h index ffe8373011..b180931e93 100644 --- a/lib/ssl/ssl3prot.h +++ b/lib/ssl/ssl3prot.h @@ -14,7 +14,7 @@ typedef PRUint16 SSL3ProtocolVersion; /* version numbers are defined in sslproto.h */ /* DTLS 1.3 is still a draft. */ -#define DTLS_1_3_DRAFT_VERSION 28 +#define DTLS_1_3_DRAFT_VERSION 30 typedef PRUint16 ssl3CipherSuite; /* The cipher suites are defined in sslproto.h */ diff --git a/lib/ssl/sslexp.h b/lib/ssl/sslexp.h index b734d86ca3..61b1fc088a 100644 --- a/lib/ssl/sslexp.h +++ b/lib/ssl/sslexp.h @@ -826,6 +826,56 @@ typedef PRTime(PR_CALLBACK *SSLTimeFunc)(void *arg); PRUint16 _numCiphers), \ (fd, cipherOrder, numCiphers)) +/* + * The following functions expose a masking primitive that uses ciphersuite and + * version information to set paramaters for the masking key and mask generation + * logic. This is only supported for TLS 1.3. + * + * The key and IV are generated using the TLS KDF with a custom label. That is + * HKDF-Expand-Label(secret, label, "", L), where |label| is an input to + * SSL_CreateMaskingContext. + * + * The mask generation logic in SSL_CreateMask is determined by the underlying + * symmetric cipher: + * - For AES-ECB, mask = AES-ECB(mask_key, sample). |len| must be <= 16 as + * the output is limited to a single block. + * - For CHACHA20, mask = ChaCha20(mask_key, sample[0..3], sample[4..15], {0}.len) + * That is, the low 4 bytes of |sample| used as the counter, the remaining 12 bytes + * the nonce. We encrypt |len| bytes of zeros, returning the raw key stream. + * + * The caller must pre-allocate at least |len| bytes for output. If the underlying + * cipher cannot produce the requested amount of data, SECFailure is returned. + */ + +typedef struct SSLMaskingContextStr { + CK_MECHANISM_TYPE mech; + PRUint16 version; + PRUint16 cipherSuite; + PK11SymKey *secret; +} SSLMaskingContext; + +#define SSL_CreateMaskingContext(version, cipherSuite, secret, \ + label, labelLen, ctx) \ + SSL_EXPERIMENTAL_API("SSL_CreateMaskingContext", \ + (PRUint16 _version, PRUint16 _cipherSuite, \ + PK11SymKey * _secret, \ + const char *_label, \ + unsigned int _labelLen, \ + SSLMaskingContext **_ctx), \ + (version, cipherSuite, secret, label, labelLen, ctx)) + +#define SSL_DestroyMaskingContext(ctx) \ + SSL_EXPERIMENTAL_API("SSL_DestroyMaskingContext", \ + (SSLMaskingContext * _ctx), \ + (ctx)) + +#define SSL_CreateMask(ctx, sample, sampleLen, mask, maskLen) \ + SSL_EXPERIMENTAL_API("SSL_CreateMask", \ + (SSLMaskingContext * _ctx, const PRUint8 *_sample, \ + unsigned int _sampleLen, PRUint8 *_mask, \ + unsigned int _maskLen), \ + (ctx, sample, sampleLen, mask, maskLen)) + /* Deprecated experimental APIs */ #define SSL_UseAltServerHelloType(fd, enable) SSL_DEPRECATED_EXPERIMENTAL_API #define SSL_SetupAntiReplay(a, b, c) SSL_DEPRECATED_EXPERIMENTAL_API diff --git a/lib/ssl/sslimpl.h b/lib/ssl/sslimpl.h index 4a393b281c..af789c73e1 100644 --- a/lib/ssl/sslimpl.h +++ b/lib/ssl/sslimpl.h @@ -810,7 +810,7 @@ typedef struct { /* |seqNum| eventually contains the reconstructed sequence number. */ sslSequenceNumber seqNum; /* The header of the cipherText. */ - const PRUint8 *hdr; + PRUint8 *hdr; unsigned int hdrLen; /* |buf| is the payload of the ciphertext. */ @@ -1849,6 +1849,30 @@ SSLExp_HkdfExpandLabelWithMech(PRUint16 version, PRUint16 cipherSuite, PK11SymKe SECStatus SSLExp_SetTimeFunc(PRFileDesc *fd, SSLTimeFunc f, void *arg); +extern SECStatus ssl_CreateMaskingContextInner(PRUint16 version, PRUint16 cipherSuite, + PK11SymKey *secret, + const char *label, + unsigned int labelLen, + SSLMaskingContext **ctx); + +extern SECStatus ssl_CreateMaskInner(SSLMaskingContext *ctx, const PRUint8 *sample, + unsigned int sampleLen, PRUint8 *outMask, + unsigned int maskLen); + +extern SECStatus ssl_DestroyMaskingContextInner(SSLMaskingContext *ctx); + +SECStatus SSLExp_CreateMaskingContext(PRUint16 version, PRUint16 cipherSuite, + PK11SymKey *secret, + const char *label, + unsigned int labelLen, + SSLMaskingContext **ctx); + +SECStatus SSLExp_CreateMask(SSLMaskingContext *ctx, const PRUint8 *sample, + unsigned int sampleLen, PRUint8 *mask, + unsigned int len); + +SECStatus SSLExp_DestroyMaskingContext(SSLMaskingContext *ctx); + SEC_END_PROTOS #if defined(XP_UNIX) || defined(XP_OS2) || defined(XP_BEOS) diff --git a/lib/ssl/sslprimitive.c b/lib/ssl/sslprimitive.c index 540c178400..5522f96fd3 100644 --- a/lib/ssl/sslprimitive.c +++ b/lib/ssl/sslprimitive.c @@ -6,6 +6,7 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ +#include "blapit.h" #include "keyhi.h" #include "pk11pub.h" #include "sechash.h" @@ -23,34 +24,6 @@ struct SSLAeadContextStr { ssl3KeyMaterial keys; }; -static SECStatus -tls13_GetHashAndCipher(PRUint16 version, PRUint16 cipherSuite, - SSLHashType *hash, const ssl3BulkCipherDef **cipher) -{ - if (version < SSL_LIBRARY_VERSION_TLS_1_3) { - PORT_SetError(SEC_ERROR_INVALID_ARGS); - return SECFailure; - } - - // Lookup and check the suite. - SSLVersionRange vrange = { version, version }; - if (!ssl3_CipherSuiteAllowedForVersionRange(cipherSuite, &vrange)) { - PORT_SetError(SEC_ERROR_INVALID_ARGS); - return SECFailure; - } - const ssl3CipherSuiteDef *suiteDef = ssl_LookupCipherSuiteDef(cipherSuite); - const ssl3BulkCipherDef *cipherDef = ssl_GetBulkCipherDef(suiteDef); - if (cipherDef->type != type_aead) { - PORT_SetError(SEC_ERROR_INVALID_ARGS); - return SECFailure; - } - *hash = suiteDef->prf_hash; - if (cipher != NULL) { - *cipher = cipherDef; - } - return SECSuccess; -} - SECStatus SSLExp_MakeAead(PRUint16 version, PRUint16 cipherSuite, PK11SymKey *secret, const char *labelPrefix, unsigned int labelPrefixLen, @@ -272,3 +245,179 @@ SSLExp_HkdfExpandLabelWithMech(PRUint16 version, PRUint16 cipherSuite, PK11SymKe return tls13_HkdfExpandLabel(prk, hash, hsHash, hsHashLen, label, labelLen, mech, keySize, keyp); } + +SECStatus +ssl_CreateMaskingContextInner(PRUint16 version, PRUint16 cipherSuite, + PK11SymKey *secret, + const char *label, + unsigned int labelLen, + SSLMaskingContext **ctx) +{ + if (!secret || !ctx || (!label && labelLen)) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + + SSLMaskingContext *out = PORT_ZNew(SSLMaskingContext); + if (out == NULL) { + goto loser; + } + + SSLHashType hash; + const ssl3BulkCipherDef *cipher; + SECStatus rv = tls13_GetHashAndCipher(version, cipherSuite, + &hash, &cipher); + if (rv != SECSuccess) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + goto loser; /* Code already set. */ + } + + out->mech = tls13_SequenceNumberEncryptionMechanism(cipher->calg); + if (out->mech == CKM_INVALID_MECHANISM) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + goto loser; + } + + // Derive the masking key + rv = tls13_HkdfExpandLabel(secret, hash, + NULL, 0, // Handshake hash. + label, labelLen, + out->mech, + cipher->key_size, &out->secret); + if (rv != SECSuccess) { + goto loser; + } + + out->version = version; + out->cipherSuite = cipherSuite; + + *ctx = out; + return SECSuccess; +loser: + SSLExp_DestroyMaskingContext(out); + return SECFailure; +} + +SECStatus +ssl_CreateMaskInner(SSLMaskingContext *ctx, const PRUint8 *sample, + unsigned int sampleLen, PRUint8 *outMask, + unsigned int maskLen) +{ + if (!ctx || !sample || !sampleLen || !outMask || !maskLen) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + + if (ctx->secret == NULL) { + PORT_SetError(SEC_ERROR_NO_KEY); + return SECFailure; + } + + SECStatus rv = SECFailure; + unsigned int outMaskLen = 0; + + /* Internal output len/buf, for use if the caller allocated and requested + * less than one block of output. |oneBlock| should have size equal to the + * largest block size supported below. */ + PRUint8 oneBlock[AES_BLOCK_SIZE]; + PRUint8 *outMask_ = outMask; + unsigned int maskLen_ = maskLen; + + switch (ctx->mech) { + case CKM_AES_ECB: + if (sampleLen < AES_BLOCK_SIZE) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + if (maskLen_ < AES_BLOCK_SIZE) { + outMask_ = oneBlock; + maskLen_ = sizeof(oneBlock); + } + rv = PK11_Encrypt(ctx->secret, + ctx->mech, + NULL, + outMask_, &outMaskLen, maskLen_, + sample, AES_BLOCK_SIZE); + if (rv == SECSuccess && + maskLen < AES_BLOCK_SIZE) { + memcpy(outMask, outMask_, maskLen); + } + break; + case CKM_NSS_CHACHA20_CTR: + if (sampleLen < 16) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + + SECItem param; + param.type = siBuffer; + param.len = 16; + param.data = (PRUint8 *)sample; // const-cast :( + unsigned char zeros[128] = { 0 }; + + if (maskLen > sizeof(zeros)) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + + rv = PK11_Encrypt(ctx->secret, + ctx->mech, + ¶m, + outMask, &outMaskLen, + maskLen, + zeros, maskLen); + break; + default: + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + + if (rv != SECSuccess) { + PORT_SetError(SEC_ERROR_PKCS11_FUNCTION_FAILED); + return SECFailure; + } + + // Ensure we produced at least as much material as requested. + if (outMaskLen < maskLen) { + PORT_SetError(SEC_ERROR_OUTPUT_LEN); + return SECFailure; + } + + return SECSuccess; +} + +SECStatus +ssl_DestroyMaskingContextInner(SSLMaskingContext *ctx) +{ + if (!ctx) { + return SECSuccess; + } + + PK11_FreeSymKey(ctx->secret); + PORT_ZFree(ctx, sizeof(*ctx)); + return SECSuccess; +} + +SECStatus +SSLExp_CreateMask(SSLMaskingContext *ctx, const PRUint8 *sample, + unsigned int sampleLen, PRUint8 *outMask, + unsigned int maskLen) +{ + return ssl_CreateMaskInner(ctx, sample, sampleLen, outMask, maskLen); +} + +SECStatus +SSLExp_CreateMaskingContext(PRUint16 version, PRUint16 cipherSuite, + PK11SymKey *secret, + const char *label, + unsigned int labelLen, + SSLMaskingContext **ctx) +{ + return ssl_CreateMaskingContextInner(version, cipherSuite, secret, label, labelLen, ctx); +} + +SECStatus +SSLExp_DestroyMaskingContext(SSLMaskingContext *ctx) +{ + return ssl_DestroyMaskingContextInner(ctx); +} diff --git a/lib/ssl/sslsock.c b/lib/ssl/sslsock.c index aa0e76e3ce..581f0c467d 100644 --- a/lib/ssl/sslsock.c +++ b/lib/ssl/sslsock.c @@ -4220,8 +4220,11 @@ struct { EXP(CipherSuiteOrderGet), EXP(CipherSuiteOrderSet), EXP(CreateAntiReplayContext), + EXP(CreateMask), + EXP(CreateMaskingContext), EXP(DelegateCredential), EXP(DestroyAead), + EXP(DestroyMaskingContext), EXP(DestroyResumptionTokenInfo), EXP(EnableESNI), EXP(EncodeESNIKeys), diff --git a/lib/ssl/sslspec.c b/lib/ssl/sslspec.c index def3c67505..c5bedad7a6 100644 --- a/lib/ssl/sslspec.c +++ b/lib/ssl/sslspec.c @@ -7,6 +7,8 @@ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ #include "ssl.h" +#include "sslexp.h" +#include "sslimpl.h" #include "sslproto.h" #include "pk11func.h" #include "secitem.h" @@ -227,6 +229,7 @@ ssl_FreeCipherSpec(ssl3CipherSpec *spec) } PK11_FreeSymKey(spec->masterSecret); ssl_DestroyKeyMaterial(&spec->keyMaterial); + ssl_DestroyMaskingContextInner(spec->maskContext); PORT_ZFree(spec, sizeof(*spec)); } diff --git a/lib/ssl/sslspec.h b/lib/ssl/sslspec.h index ca9ef540fb..d00b20d760 100644 --- a/lib/ssl/sslspec.h +++ b/lib/ssl/sslspec.h @@ -169,6 +169,9 @@ struct ssl3CipherSpecStr { * negotiated value for TLS 1.3; it is reduced by one to account for the * content type octet. */ PRUint16 recordSizeLimit; + + /* Masking context used for DTLS 1.3 */ + SSLMaskingContext *maskContext; }; typedef void (*sslCipherSpecChangedFunc)(void *arg, diff --git a/lib/ssl/tls13con.c b/lib/ssl/tls13con.c index c3528a52f8..97c1918725 100644 --- a/lib/ssl/tls13con.c +++ b/lib/ssl/tls13con.c @@ -131,6 +131,7 @@ const char kHkdfLabelExporterMasterSecret[] = "exp master"; const char kHkdfLabelResumption[] = "resumption"; const char kHkdfLabelTrafficUpdate[] = "traffic upd"; const char kHkdfPurposeKey[] = "key"; +const char kHkdfPurposeSn[] = "sn"; const char kHkdfPurposeIv[] = "iv"; const char keylogLabelClientEarlyTrafficSecret[] = "CLIENT_EARLY_TRAFFIC_SECRET"; @@ -286,6 +287,34 @@ tls13_GetHash(const sslSocket *ss) return ss->ssl3.hs.suite_def->prf_hash; } +SECStatus +tls13_GetHashAndCipher(PRUint16 version, PRUint16 cipherSuite, + SSLHashType *hash, const ssl3BulkCipherDef **cipher) +{ + if (version < SSL_LIBRARY_VERSION_TLS_1_3) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + + // Lookup and check the suite. + SSLVersionRange vrange = { version, version }; + if (!ssl3_CipherSuiteAllowedForVersionRange(cipherSuite, &vrange)) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + const ssl3CipherSuiteDef *suiteDef = ssl_LookupCipherSuiteDef(cipherSuite); + const ssl3BulkCipherDef *cipherDef = ssl_GetBulkCipherDef(suiteDef); + if (cipherDef->type != type_aead) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + *hash = suiteDef->prf_hash; + if (cipher != NULL) { + *cipher = cipherDef; + } + return SECSuccess; +} + unsigned int tls13_GetHashSizeForHash(SSLHashType hash) { @@ -3474,6 +3503,17 @@ tls13_DeriveTrafficKeys(sslSocket *ss, ssl3CipherSpec *spec, goto loser; } + if (IS_DTLS(ss) && spec->epoch > 0) { + rv = ssl_CreateMaskingContextInner(spec->version, + ss->ssl3.hs.cipher_suite, prk, kHkdfPurposeSn, + strlen(kHkdfPurposeSn), &spec->maskContext); + if (rv != SECSuccess) { + LOG_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE); + PORT_Assert(0); + goto loser; + } + } + rv = tls13_HkdfExpandLabelRaw(prk, tls13_GetHash(ss), NULL, 0, kHkdfPurposeIv, strlen(kHkdfPurposeIv), diff --git a/lib/ssl/tls13con.h b/lib/ssl/tls13con.h index bd309419fe..0160740a49 100644 --- a/lib/ssl/tls13con.h +++ b/lib/ssl/tls13con.h @@ -44,10 +44,12 @@ PRBool tls13_InHsState(sslSocket *ss, ...); PRBool tls13_IsPostHandshake(const sslSocket *ss); -SSLHashType tls13_GetHashForCipherSuite(ssl3CipherSuite suite); SSLHashType tls13_GetHash(const sslSocket *ss); -unsigned int tls13_GetHashSizeForHash(SSLHashType hash); +SECStatus tls13_GetHashAndCipher(PRUint16 version, PRUint16 cipherSuite, + SSLHashType *hash, const ssl3BulkCipherDef **cipher); +SSLHashType tls13_GetHashForCipherSuite(ssl3CipherSuite suite); unsigned int tls13_GetHashSize(const sslSocket *ss); +unsigned int tls13_GetHashSizeForHash(SSLHashType hash); CK_MECHANISM_TYPE tls13_GetHkdfMechanism(sslSocket *ss); CK_MECHANISM_TYPE tls13_GetHkdfMechanismForHash(SSLHashType hash); SECStatus tls13_ComputeHash(sslSocket *ss, SSL3Hashes *hashes,