diff --git a/gtests/ssl_gtest/ssl_record_unittest.cc b/gtests/ssl_gtest/ssl_record_unittest.cc index ef81b222c8..54acac9d3f 100644 --- a/gtests/ssl_gtest/ssl_record_unittest.cc +++ b/gtests/ssl_gtest/ssl_record_unittest.cc @@ -10,6 +10,8 @@ #include "databuffer.h" #include "gtest_utils.h" +#include "tls_connect.h" +#include "tls_filter.h" namespace nss_test { @@ -51,8 +53,8 @@ class TlsPaddingTest << " total length=" << plaintext_.len() << std::endl; std::cerr << "Plaintext: " << plaintext_ << std::endl; sslBuffer s; - s.buf = const_cast( - static_cast(plaintext_.data())); + s.buf = const_cast( + static_cast(plaintext_.data())); s.len = plaintext_.len(); SECStatus rv = ssl_RemoveTLSCBCPadding(&s, kMacSize); if (expect_success) { @@ -99,6 +101,75 @@ TEST_P(TlsPaddingTest, LastByteOfPadWrong) { } } +class RecordReplacer : public TlsRecordFilter { + public: + RecordReplacer(size_t size) + : TlsRecordFilter(), enabled_(false), size_(size) {} + + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override { + if (!enabled_) { + return KEEP; + } + + EXPECT_EQ(kTlsApplicationDataType, header.content_type()); + changed->Allocate(size_); + + for (size_t i = 0; i < size_; ++i) { + changed->data()[i] = i & 0xff; + } + + enabled_ = false; + return CHANGE; + } + + void Enable() { enabled_ = true; } + + private: + bool enabled_; + size_t size_; +}; + +TEST_F(TlsConnectStreamTls13, LargeRecord) { + EnsureTlsSetup(); + + const size_t record_limit = 16384; + auto replacer = std::make_shared(record_limit); + client_->SetTlsRecordFilter(replacer); + replacer->EnableDecryption(); + Connect(); + + replacer->Enable(); + client_->SendData(10); + WAIT_(server_->received_bytes() == record_limit, 2000); + ASSERT_EQ(record_limit, server_->received_bytes()); +} + +TEST_F(TlsConnectStreamTls13, TooLargeRecord) { + EnsureTlsSetup(); + + const size_t record_limit = 16384; + auto replacer = std::make_shared(record_limit + 1); + client_->SetTlsRecordFilter(replacer); + replacer->EnableDecryption(); + Connect(); + + replacer->Enable(); + ExpectAlert(server_, kTlsAlertRecordOverflow); + client_->SendData(10); // This is expanded. + + uint8_t buf[record_limit + 2]; + PRInt32 rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf)); + EXPECT_GT(0, rv); + EXPECT_EQ(SSL_ERROR_RX_RECORD_TOO_LONG, PORT_GetError()); + + // Read the server alert. + rv = PR_Read(client_->ssl_fd(), buf, sizeof(buf)); + EXPECT_GT(0, rv); + EXPECT_EQ(SSL_ERROR_RECORD_OVERFLOW_ALERT, PORT_GetError()); +} + const static size_t kContentSizesArr[] = { 1, kMacSize - 1, kMacSize, 30, 31, 32, 36, 256, 257, 287, 288}; diff --git a/lib/ssl/ssl3con.c b/lib/ssl/ssl3con.c index 4fe8c518ca..05752716b6 100644 --- a/lib/ssl/ssl3con.c +++ b/lib/ssl/ssl3con.c @@ -12804,7 +12804,7 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf) /* ** Having completed the decompression, check the length again. */ - if (isTLS && databuf->len > (MAX_FRAGMENT_LENGTH + 1024)) { + if (isTLS && databuf->len > MAX_FRAGMENT_LENGTH) { SSL3_SendAlert(ss, alert_fatal, record_overflow); PORT_SetError(SSL_ERROR_RX_RECORD_TOO_LONG); return SECFailure;