/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */

/*
 * BoGo test shim for NSS.
 *
 * Command-line flags are parsed into a Config (see ReadConfig()). The shim
 * connects over TCP to the test runner on localhost, announces its shim id,
 * wraps the connection in an NSS SSL socket configured from the flags,
 * performs a handshake (plus "resume-count" resumption cycles), runs a
 * bit-flipping echo exchange, and finally compares negotiated parameters
 * against the "expect-*" flags.
 */

#include "config.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

#include "nspr.h"
#include "nss.h"
#include "prio.h"
#include "prnetdb.h"
#include "secerr.h"
#include "ssl.h"
#include "ssl3prot.h"
#include "sslerr.h"
#include "sslproto.h"
#include "nss_scoped_ptrs.h"
#include "sslimpl.h"
#include "tls13ech.h"
#include "base64.h"

#include "nsskeys.h"

// Flags that disable individual protocol versions, ordered oldest to newest.
// GetVersionRange() depends on this order: entry i (i >= 1) corresponds to
// SSL_LIBRARY_VERSION_TLS_1_0 + (i - 1).
static const char* kVersionDisableFlags[] = {"no-ssl3", "no-tls1", "no-tls11",
                                             "no-tls12", "no-tls13"};

/* Default EarlyData dummy data determined by Bogo implementation. */
const unsigned char kBogoDummyData[] = {'h', 'e', 'l', 'l', 'o'};

// Set by ReadConfig() when an unknown flag is seen; GetExitCode() then
// reports exit code 89.
bool exitCodeUnimplemented = false;

// Formats an error code as ":<ERROR_NAME>::<error description>".
std::string FormatError(PRErrorCode code) {
  return std::string(":") + PORT_ErrorToName(code) + ":" + ":" +
         PORT_ErrorToString(code);
}

// Removes every '\n' and '\r' character from |str| in place.
static void StringRemoveNewlines(std::string& str) {
  str.erase(std::remove_if(str.begin(), str.end(),
                           [](char c) { return c == '\n' || c == '\r'; }),
            str.end());
}

// One TLS endpoint (client or server, depending on the "server" flag) bound
// to a single TCP connection to the test runner.
class TestAgent {
 public:
  // |cfg| is stored by reference and must outlive the agent.
  explicit TestAgent(const Config& cfg) : cfg_(cfg) {}

  ~TestAgent() {}

  // Constructs an agent and runs Init(). Returns nullptr if any setup step
  // (TCP connect, keys/certs, socket options, handshake reset) fails.
  static std::unique_ptr<TestAgent> Create(const Config& cfg) {
    std::unique_ptr<TestAgent> agent(new TestAgent(cfg));

    if (!agent->Init()) return nullptr;

    return agent;
  }

  // Connects, loads keys/certs, applies socket options and resets the
  // handshake into client or server mode. Returns false on any failure.
  bool Init() {
    if (!ConnectTcp()) {
      return false;
    }

    if (!SetupKeys()) {
      std::cerr << "Couldn't set up keys/certs\n";
      return false;
    }

    if (!SetupOptions()) {
      std::cerr << "Couldn't configure socket\n";
      return false;
    }

    SECStatus rv = SSL_ResetHandshake(ssl_fd_.get(), cfg_.get<bool>("server"));
    if (rv != SECSuccess) return false;

    return true;
  }

  // Opens the TCP connection (IPv6 loopback first when "ipv6" is set, else or
  // on failure IPv4 loopback) and layers an NSS SSL socket on top of it.
  bool ConnectTcp() {
    if (!(cfg_.get<bool>("ipv6") && OpenConnection("::1")) &&
        !OpenConnection("127.0.0.1")) {
      return false;
    }

    ssl_fd_ = ScopedPRFileDesc(SSL_ImportFD(NULL, pr_fd_.get()));
    if (!ssl_fd_) {
      return false;
    }
    // The SSL layer now owns the underlying TCP descriptor.
    pr_fd_.release();

    return true;
  }

  // Connects to |ip| on the configured port and sends the 8-byte shim id,
  // least-significant byte first. Returns false on any failure.
  bool OpenConnection(const char* ip) {
    PRStatus prv;
    PRNetAddr addr;

    prv = PR_StringToNetAddr(ip, &addr);

    if (prv != PR_SUCCESS) {
      return false;
    }

    addr.inet.port = PR_htons(cfg_.get<int>("port"));

    pr_fd_ = ScopedPRFileDesc(PR_OpenTCPSocket(addr.raw.family));
    if (!pr_fd_) return false;

    prv = PR_Connect(pr_fd_.get(), &addr, PR_INTERVAL_NO_TIMEOUT);
    if (prv != PR_SUCCESS) {
      return false;
    }

    uint64_t shim_id = cfg_.get<int>("shim-id");
    uint8_t buf[8] = {0};
    for (size_t i = 0; i < 8; i++) {
      buf[i] = shim_id & 0xff;
      shim_id >>= 8;
    }
    int sent = PR_Write(pr_fd_.get(), buf, sizeof(buf));
    if (sent != sizeof(buf)) {
      return false;
    }

    return true;
  }

  // Loads the optional key/cert files, installs the permissive certificate
  // auth hook, and configures the server cert (server) or the client-auth
  // hook (client with both a key and a cert).
  bool SetupKeys() {
    SECStatus rv;

    if (!cfg_.get<std::string>("key-file").empty()) {
      key_ = ScopedSECKEYPrivateKey(
          ReadPrivateKey(cfg_.get<std::string>("key-file")));
      if (!key_) return false;
    }
    if (!cfg_.get<std::string>("cert-file").empty()) {
      cert_ = ScopedCERTCertificate(
          ReadCertificate(cfg_.get<std::string>("cert-file")));
      if (!cert_) return false;
    }

    // Needed because certs are not entirely valid.
    rv = SSL_AuthCertificateHook(ssl_fd_.get(), AuthCertificateHook, this);
    if (rv != SECSuccess) return false;

    if (cfg_.get<bool>("server")) {
      // Server
      rv = SSL_ConfigServerCert(ssl_fd_.get(), cert_.get(), key_.get(),
                                nullptr, 0);
      if (rv != SECSuccess) {
        std::cerr << "Couldn't configure server cert\n";
        return false;
      }

    } else if (key_ && cert_) {
      // Client.
      rv =
          SSL_GetClientAuthDataHook(ssl_fd_.get(), GetClientAuthDataHook, this);
      if (rv != SECSuccess) return false;
    }

    return true;
  }

  // Converts a protocol version as written on the wire into the NSS library
  // constant for |variant|. 0 and 0xffff pass through unchanged (they are the
  // defaults used when min/max-version is not given). Returns false and logs
  // for unrecognized versions.
  static bool ConvertFromWireVersion(SSLProtocolVariant variant,
                                     int wire_version, uint16_t* lib_version) {
    // These default values are used when {min,max}-version isn't given.
    if (wire_version == 0 || wire_version == 0xffff) {
      *lib_version = static_cast<uint16_t>(wire_version);
      return true;
    }

#ifdef TLS_1_3_DRAFT_VERSION
    if (wire_version == (0x7f00 | TLS_1_3_DRAFT_VERSION)) {
      // N.B. SSL_LIBRARY_VERSION_DTLS_1_3_WIRE == SSL_LIBRARY_VERSION_TLS_1_3
      wire_version = SSL_LIBRARY_VERSION_TLS_1_3;
    }
#endif

    if (variant == ssl_variant_datagram) {
      switch (wire_version) {
        case SSL_LIBRARY_VERSION_DTLS_1_0_WIRE:
          *lib_version = SSL_LIBRARY_VERSION_DTLS_1_0;
          break;
        case SSL_LIBRARY_VERSION_DTLS_1_2_WIRE:
          *lib_version = SSL_LIBRARY_VERSION_DTLS_1_2;
          break;
        case SSL_LIBRARY_VERSION_DTLS_1_3_WIRE:
          *lib_version = SSL_LIBRARY_VERSION_DTLS_1_3;
          break;
        default:
          std::cerr << "Unrecognized DTLS version " << wire_version << ".\n";
          return false;
      }
    } else {
      if (wire_version < SSL_LIBRARY_VERSION_3_0 ||
          wire_version > SSL_LIBRARY_VERSION_TLS_1_3) {
        std::cerr << "Unrecognized TLS version " << wire_version << ".\n";
        return false;
      }
      *lib_version = static_cast<uint16_t>(wire_version);
    }
    return true;
  }

  // Computes the enabled version range from min/max-version, the library's
  // supported range and the -no-tlsN flags. Fails if every version is
  // disabled or if the enabled versions do not form a contiguous range.
  bool GetVersionRange(SSLVersionRange* range_out,
                       SSLProtocolVariant variant) {
    SSLVersionRange supported;
    if (SSL_VersionRangeGetSupported(variant, &supported) != SECSuccess) {
      return false;
    }

    uint16_t min_allowed;
    uint16_t max_allowed;
    if (!ConvertFromWireVersion(variant, cfg_.get<int>("min-version"),
                                &min_allowed)) {
      return false;
    }
    if (!ConvertFromWireVersion(variant, cfg_.get<int>("max-version"),
                                &max_allowed)) {
      return false;
    }

    min_allowed = std::max(min_allowed, supported.min);
    max_allowed = std::min(max_allowed, supported.max);

    bool found_min = false;
    bool found_max = false;
    // Ignore -no-ssl3, because SSLv3 is never supported.
    for (size_t i = 1; i < PR_ARRAY_SIZE(kVersionDisableFlags); ++i) {
      auto version =
          static_cast<uint16_t>(SSL_LIBRARY_VERSION_TLS_1_0 + (i - 1));
      if (variant == ssl_variant_datagram) {
        // In DTLS mode, the -no-tlsN flags refer to DTLS versions,
        // but NSS wants the corresponding TLS versions.
        if (version == SSL_LIBRARY_VERSION_TLS_1_1) {
          // DTLS 1.1 doesn't exist.
          continue;
        }
        if (version == SSL_LIBRARY_VERSION_TLS_1_0) {
          version = SSL_LIBRARY_VERSION_DTLS_1_0;
        }
      }

      if (version < min_allowed) {
        continue;
      }
      if (version > max_allowed) {
        break;
      }

      const bool allowed = !cfg_.get<bool>(kVersionDisableFlags[i]);

      if (!found_min && allowed) {
        found_min = true;
        range_out->min = version;
      }
      if (found_min && !found_max) {
        if (allowed) {
          range_out->max = version;
        } else {
          found_max = true;
        }
      }
      // An allowed version after a disabled one means a hole in the range.
      if (found_max && allowed) {
        std::cerr << "Discontiguous version range.\n";
        return false;
      }
    }

    if (!found_min) {
      std::cerr << "All versions disabled.\n";
    }
    return found_min;
  }

  // Applies all socket options derived from the configuration. Returns false
  // as soon as any NSS call fails or the flags are inconsistent.
  bool SetupOptions() {
    SECStatus rv =
        SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_TLS13_COMPAT_MODE, PR_TRUE);
    if (rv != SECSuccess) return false;

    rv = SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_SESSION_TICKETS, PR_TRUE);
    if (rv != SECSuccess) return false;

    SSLVersionRange vrange;
    if (!GetVersionRange(&vrange, ssl_variant_stream)) return false;

    rv = SSL_VersionRangeSet(ssl_fd_.get(), &vrange);
    if (rv != SECSuccess) return false;

    // Read the range back to make sure NSS applied exactly what was asked.
    SSLVersionRange verify_vrange;
    rv = SSL_VersionRangeGet(ssl_fd_.get(), &verify_vrange);
    if (rv != SECSuccess) return false;
    if (vrange.min != verify_vrange.min || vrange.max != verify_vrange.max)
      return false;

    rv = SSL_OptionSet(ssl_fd_.get(), SSL_NO_CACHE, false);
    if (rv != SECSuccess) return false;

    auto alpn = cfg_.get<std::string>("advertise-alpn");
    if (!alpn.empty()) {
      assert(!cfg_.get<bool>("server"));

      rv = SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_ALPN, PR_TRUE);
      if (rv != SECSuccess) return false;

      rv = SSL_SetNextProtoNego(
          ssl_fd_.get(), reinterpret_cast<const unsigned char*>(alpn.c_str()),
          alpn.size());
      if (rv != SECSuccess) return false;
    }

    // Set supported signature schemes.
    auto sign_prefs = cfg_.get<std::vector<int>>("signing-prefs");
    auto verify_prefs = cfg_.get<std::vector<int>>("verify-prefs");
    if (sign_prefs.empty()) {
      sign_prefs = verify_prefs;
    } else if (!verify_prefs.empty()) {
      return false;  // Both shouldn't be set.
    }
    if (!sign_prefs.empty()) {
      std::vector<SSLSignatureScheme> sig_schemes;
      sig_schemes.reserve(sign_prefs.size());
      std::transform(
          sign_prefs.begin(), sign_prefs.end(), std::back_inserter(sig_schemes),
          [](int scheme) { return static_cast<SSLSignatureScheme>(scheme); });

      rv = SSL_SignatureSchemePrefSet(
          ssl_fd_.get(), sig_schemes.data(),
          static_cast<unsigned int>(sig_schemes.size()));
      if (rv != SECSuccess) return false;
    }

    if (cfg_.get<bool>("fallback-scsv")) {
      rv = SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_FALLBACK_SCSV, PR_TRUE);
      if (rv != SECSuccess) return false;
    }

    if (cfg_.get<bool>("false-start")) {
      rv = SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_FALSE_START, PR_TRUE);
      if (rv != SECSuccess) return false;
    }

    if (cfg_.get<bool>("enable-ocsp-stapling")) {
      rv = SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_OCSP_STAPLING, PR_TRUE);
      if (rv != SECSuccess) return false;
    }

    bool requireClientCert = cfg_.get<bool>("require-any-client-certificate");
    if (requireClientCert || cfg_.get<bool>("verify-peer")) {
      assert(cfg_.get<bool>("server"));

      rv = SSL_OptionSet(ssl_fd_.get(), SSL_REQUEST_CERTIFICATE, PR_TRUE);
      if (rv != SECSuccess) return false;

      rv = SSL_OptionSet(
          ssl_fd_.get(), SSL_REQUIRE_CERTIFICATE,
          requireClientCert ? SSL_REQUIRE_ALWAYS : SSL_REQUIRE_NO_ERROR);
      if (rv != SECSuccess) return false;
    }

    if (!cfg_.get<bool>("server")) {
      auto hostname = cfg_.get<std::string>("host-name");
      if (!hostname.empty()) {
        rv = SSL_SetURL(ssl_fd_.get(), hostname.c_str());
      } else {
        // Needed to make resumption work.
        rv = SSL_SetURL(ssl_fd_.get(), "server");
      }
      if (rv != SECSuccess) return false;

      // Setup ECH configs on client if provided
      auto echConfigList = cfg_.get<std::string>("ech-config-list");
      if (!echConfigList.empty()) {
        unsigned int binLen = 0;
        unsigned char* bin = ATOB_AsciiToData(echConfigList.c_str(), &binLen);
        if (!bin) {
          return false;
        }
        rv = SSLExp_SetClientEchConfigs(ssl_fd_.get(), bin, binLen);
        // Release the decoded buffer before checking |rv| so that the
        // failure path does not leak it.
        PORT_Free(bin);
        if (rv != SECSuccess) return false;
      }

      if (cfg_.get<bool>("enable-grease")) {
        rv = SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_GREASE, PR_TRUE);
        if (rv != SECSuccess) return false;
      }

      if (cfg_.get<bool>("permute-extensions")) {
        rv = SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_CH_EXTENSION_PERMUTATION,
                           PR_TRUE);
        if (rv != SECSuccess) return false;
      }
    } else {
      // GREASE - BoGo expects servers to enable GREASE by default
      rv = SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_GREASE, PR_TRUE);
      if (rv != SECSuccess) return false;
    }

    rv = SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_EXTENDED_MASTER_SECRET,
                       PR_TRUE);
    if (rv != SECSuccess) return false;

    if (cfg_.get<bool>("server")) {
      // BoGo expects servers to enable ECH (backend) by default
      rv = SSLExp_EnableTls13BackendEch(ssl_fd_.get(), true);
      if (rv != SECSuccess) return false;
    }

    if (cfg_.get<bool>("enable-ech-grease")) {
      rv = SSLExp_EnableTls13GreaseEch(ssl_fd_.get(), true);
      if (rv != SECSuccess) return false;
    }

    if (cfg_.get<bool>("enable-early-data")) {
      rv = SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_0RTT_DATA, PR_TRUE);
      if (rv != SECSuccess) return false;
    }

    if (!ConfigureGroups()) return false;

    if (!ConfigureCiphers()) return false;

    return true;
  }

  // Restricts the key-exchange groups to the "curves" list, if one is given.
  bool ConfigureGroups() {
    auto curves = cfg_.get<std::vector<int>>("curves");
    if (curves.empty()) {
      return true;
    }

    std::vector<SSLNamedGroup> groups;
    groups.reserve(curves.size());
    std::transform(
        curves.begin(), curves.end(), std::back_inserter(groups),
        [](int curve) { return static_cast<SSLNamedGroup>(curve); });

    SECStatus rv = SSL_NamedGroupConfig(
        ssl_fd_.get(), groups.data(), static_cast<unsigned int>(groups.size()));
    if (rv != SECSuccess) {
      return false;
    }

    // Xyber768 is disabled by policy by default, so if it's requested
    // we need to update the policy flags as well.
    // NOTE(review): the NSS_SetAlgorithmPolicy result is ignored, as before.
    for (auto group : groups) {
      if (group == ssl_grp_kem_xyber768d00) {
        NSS_SetAlgorithmPolicy(SEC_OID_XYBER768D00, NSS_USE_ALG_IN_SSL_KX, 0);
      }
    }

    return true;
  }

  // Enables exactly the implemented cipher suites whose names occur as a
  // substring of "nss-cipher"; when the flag is empty, enables all of them.
  bool ConfigureCiphers() {
    auto cipherList = cfg_.get<std::string>("nss-cipher");

    if (cipherList.empty()) {
      return EnableNonExportCiphers();
    }

    for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
      SSLCipherSuiteInfo csinfo;
      SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo,
                                            sizeof(csinfo));
      if (rv != SECSuccess) {
        return false;
      }

      // Check if cipherList contains the name of the Cipher Suite and
      // enable/disable accordingly.
      const bool listed =
          cipherList.find(csinfo.cipherSuiteName, 0) != std::string::npos;
      rv = SSL_CipherPrefSet(ssl_fd_.get(), SSL_ImplementedCiphers[i],
                             listed ? PR_TRUE : PR_FALSE);
      if (rv != SECSuccess) {
        return false;
      }
    }
    return true;
  }

  // Enables every implemented cipher suite for which
  // SSL_GetCipherSuiteInfo succeeds; any NSS failure aborts with false.
  bool EnableNonExportCiphers() {
    for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
      SSLCipherSuiteInfo csinfo;

      SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo,
                                            sizeof(csinfo));
      if (rv != SECSuccess) {
        return false;
      }

      rv = SSL_CipherPrefSet(ssl_fd_.get(), SSL_ImplementedCiphers[i], PR_TRUE);
      if (rv != SECSuccess) {
        return false;
      }
    }
    return true;
  }

  // Dummy auth certificate hook: accepts every peer certificate.
  static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd,
                                       PRBool checksig, PRBool isServer) {
    return SECSuccess;
  }

  // Client-auth hook: hands NSS copies of the agent's configured certificate
  // and private key. Only installed when both were loaded (see SetupKeys()).
  static SECStatus GetClientAuthDataHook(void* self, PRFileDesc* fd,
                                         CERTDistNames* caNames,
                                         CERTCertificate** cert,
                                         SECKEYPrivateKey** privKey) {
    TestAgent* a = static_cast<TestAgent*>(self);
    *cert = CERT_DupCertificate(a->cert_.get());
    *privKey = SECKEY_CopyPrivateKey(a->key_.get());
    return SECSuccess;
  }

  // Drives the handshake to completion on the SSL socket.
  SECStatus Handshake() { return SSL_ForceHandshake(ssl_fd_.get()); }

  // Implement a trivial echo client/server. Read bytes from the other side,
  // flip all the bits, and send them back. Returns SECSuccess on EOF.
  SECStatus ReadWrite() {
    for (;;) {
      uint8_t block[512];
      int32_t rv = PR_Read(ssl_fd_.get(), block, sizeof(block));
      if (rv < 0) {
        std::cerr << "Failure reading\n";
        return SECFailure;
      }
      if (rv == 0) return SECSuccess;

      int32_t len = rv;
      for (int32_t i = 0; i < len; ++i) {
        block[i] ^= 0xff;
      }

      rv = PR_Write(ssl_fd_.get(), block, len);
      if (rv != len) {
        std::cerr << "Write failure\n";
        PORT_SetError(SEC_ERROR_OUTPUT_LEN);
        return SECFailure;
      }
    }
  }

  // Write bytes to the other side then read them back and check
  // that they were correctly XORed as in ReadWrite.
  SECStatus WriteRead() {
    static const uint8_t ch = 'E';

    // We do 600-byte blocks to provide mis-alignment of the
    // reader and writer.
    uint8_t block[600];
    memset(block, ch, sizeof(block));
    int32_t rv = PR_Write(ssl_fd_.get(), block, sizeof(block));
    if (rv != sizeof(block)) {
      std::cerr << "Write failure\n";
      PORT_SetError(SEC_ERROR_OUTPUT_LEN);
      return SECFailure;
    }

    size_t left = sizeof(block);
    while (left) {
      rv = PR_Read(ssl_fd_.get(), block, left);
      if (rv < 0) {
        std::cerr << "Failure reading\n";
        return SECFailure;
      }
      if (rv == 0) {
        PORT_SetError(SEC_ERROR_INPUT_LEN);
        return SECFailure;
      }

      int32_t len = rv;
      for (int32_t i = 0; i < len; ++i) {
        if (block[i] != (ch ^ 0xff)) {
          PORT_SetError(SEC_ERROR_BAD_DATA);
          return SECFailure;
        }
      }
      left -= len;
    }
    return SECSuccess;
  }

  // Compares the negotiated ALPN protocol with |expectedALPN|.
  SECStatus CheckALPN(const std::string& expectedALPN) {
    SECStatus rv;
    SSLNextProtoState state;
    char chosen[256];
    unsigned int chosen_len;

    rv = SSL_GetNextProto(ssl_fd_.get(), &state,
                          reinterpret_cast<unsigned char*>(chosen),
                          &chosen_len, sizeof(chosen));
    if (rv != SECSuccess) {
      PRErrorCode err = PR_GetError();
      std::cerr << "SSL_GetNextProto failed with error=" << FormatError(err)
                << std::endl;
      return SECFailure;
    }

    assert(chosen_len <= sizeof(chosen));
    const std::string chosenALPN(chosen, chosen_len);
    if (chosenALPN != expectedALPN) {
      std::cerr << "Expexted ALPN (" << expectedALPN << ") != Choosen ALPN ("
                << chosenALPN << ")" << std::endl;
      return SECFailure;
    }

    return SECSuccess;
  }

  // Sets |alpn| (already in wire format) as the advertised ALPN list.
  SECStatus AdvertiseALPN(const std::string& alpn) {
    return SSL_SetNextProtoNego(
        ssl_fd_.get(), reinterpret_cast<const unsigned char*>(alpn.c_str()),
        alpn.size());
  }

  /* Certificate Encoding/Decoding Shrinking functions
   * See
   * https://boringssl.googlesource.com/boringssl/+/master/ssl/test/runner/runner.go#16168
   *
   * Encoding requires the first two bytes to be zero and drops them;
   * decoding puts them back.
   */
  static SECStatus certCompressionShrinkEncode(const SECItem* input,
                                               SECItem* output) {
    if (input == NULL || input->data == NULL || output == NULL) {
      PR_SetError(SEC_ERROR_INVALID_ARGS, 0);
      return SECFailure;
    }

    if (input->len < 2) {
      std::cerr << "Certificate is too short. " << std::endl;
      PR_SetError(SEC_ERROR_INVALID_ARGS, 0);
      return SECFailure;
    }

    /* The shrinking encoding primitive expects the first two bytes of a
     * certificate to be equal to 0. Validate before allocating so the failure
     * path leaves |output| untouched. */
    if (input->data[0] != 0 || input->data[1] != 0) {
      std::cerr << "Cannot compress certificate message." << std::endl;
      return SECFailure;
    }

    SECITEM_AllocItem(NULL, output, input->len - 2);
    if (output->data == NULL) {
      return SECFailure;
    }

    memcpy(output->data, input->data + 2, output->len);
    return SECSuccess;
  }

  static SECStatus certCompressionShrinkDecode(
      const SECItem* input, SECItem* output,
      size_t expectedLenDecodedCertificate) {
    if (input == NULL || input->data == NULL) {
      PR_SetError(SEC_ERROR_INVALID_ARGS, 0);
      return SECFailure;
    }

    // |output| is expected to be pre-allocated by the caller.
    if (output == NULL || output->data == NULL ||
        output->len != input->len + 2) {
      return SECFailure;
    }

    if (expectedLenDecodedCertificate != output->len) {
      std::cerr << "Cannot decompress certificate message." << std::endl;
      return SECFailure;
    }

    output->data[0] = 0;
    output->data[1] = 0;
    for (size_t i = 0; i < input->len; i++) {
      output->data[i + 2] = input->data[i];
    }

    return SECSuccess;
  }

  /* Certificate Encoding/Decoding Expanding functions
   * See
   * https://boringssl.googlesource.com/boringssl/+/master/ssl/test/runner/runner.go#16186
   *
   * Encoding prepends the bytes 01 02 03 04; decoding verifies and strips
   * them.
   */
  static SECStatus certCompressionExpandEncode(const SECItem* input,
                                               SECItem* output) {
    if (input == NULL || input->data == NULL || output == NULL) {
      PR_SetError(SEC_ERROR_INVALID_ARGS, 0);
      return SECFailure;
    }

    SECITEM_AllocItem(NULL, output, input->len + 4);
    if (output->data == NULL) {
      return SECFailure;
    }

    output->data[0] = 1;
    output->data[1] = 2;
    output->data[2] = 3;
    output->data[3] = 4;

    for (size_t i = 0; i < input->len; i++) {
      output->data[i + 4] = input->data[i];
    }

    return SECSuccess;
  }

  static SECStatus certCompressionExpandDecode(
      const SECItem* input, SECItem* output,
      size_t expectedLenDecodedCertificate) {
    if (input == NULL || input->data == NULL) {
      PR_SetError(SEC_ERROR_INVALID_ARGS, 0);
      return SECFailure;
    }

    if (input->len < 4) {
      PR_SetError(SEC_ERROR_INVALID_ARGS, 0);
      std::cerr << "Certificate is too short. " << std::endl;
      return SECFailure;
    }

    // |output| is expected to be pre-allocated by the caller.
    if (output == NULL || output->data == NULL ||
        output->len != input->len - 4) {
      return SECFailure;
    }

    /* See the corresponding compression function. */
    if (input->data[0] != 1 || input->data[1] != 2 || input->data[2] != 3 ||
        input->data[3] != 4) {
      std::cerr << "Cannot decompress certificate message." << std::endl;
      return SECFailure;
    }

    if (expectedLenDecodedCertificate != output->len) {
      std::cerr << "Cannot decompress certificate message." << std::endl;
      return SECFailure;
    }

    for (size_t i = 0; i < output->len; i++) {
      output->data[i] = input->data[i + 4];
    }

    return SECSuccess;
  }

  /* Certificate Encoding/Decoding Random functions
   * See
   * https://boringssl.googlesource.com/boringssl/+/master/ssl/test/runner/runner.go#16201
   *
   * Encoding prepends one random byte; decoding strips the first byte.
   */
  static SECStatus certCompressionRandomEncode(const SECItem* input,
                                               SECItem* output) {
    if (input == NULL || input->data == NULL || output == NULL) {
      PR_SetError(SEC_ERROR_INVALID_ARGS, 0);
      return SECFailure;
    }

    SECITEM_AllocItem(NULL, output, input->len + 1);
    if (output->data == NULL) {
      return SECFailure;
    }

    SECStatus rv = PK11_GenerateRandom(output->data, 1);
    if (rv != SECSuccess) {
      std::cerr << "Failed to generate randomness. " << std::endl;
      // Release the buffer allocated above; SECITEM_FreeItem also resets
      // |output| to empty.
      SECITEM_FreeItem(output, PR_FALSE);
      return SECFailure;
    }

    for (size_t i = 0; i < input->len; i++) {
      output->data[i + 1] = input->data[i];
    }

    return SECSuccess;
  }

  static SECStatus certCompressionRandomDecode(
      const SECItem* input, SECItem* output,
      size_t expectedLenDecodedCertificate) {
    if (input == NULL || input->data == NULL) {
      PR_SetError(SEC_ERROR_INVALID_ARGS, 0);
      return SECFailure;
    }

    if (input->len < 1) {
      PR_SetError(SEC_ERROR_INVALID_ARGS, 0);
      std::cerr << "Certificate is too short. " << std::endl;
      return SECFailure;
    }

    // |output| is expected to be pre-allocated by the caller.
    if (output == NULL || output->data == NULL ||
        output->len != input->len - 1) {
      return SECFailure;
    }

    if (expectedLenDecodedCertificate != output->len) {
      std::cerr << "Cannot decompress certificate message." << std::endl;
      return SECFailure;
    }

    for (size_t i = 0; i < output->len; i++) {
      output->data[i] = input->data[i + 1];
    }

    return SECSuccess;
  }

  // Runs one connection end to end: applies per-cycle options, performs the
  // handshake (optionally sending early data as a resuming client), runs the
  // echo exchange and checks all "expect-*" flags against the channel info.
  // |resuming| selects the "on-resume-*" versus "on-initial-*" flags.
  SECStatus DoExchange(bool resuming) {
    SECStatus rv;
    int earlyDataSent = 0;
    std::string str;
    sslSocket* ss = ssl_FindSocket(ssl_fd_.get());
    if (!ss) {
      return SECFailure;
    }

    if (cfg_.get<bool>("install-cert-compression-algs")) {
      SSLCertificateCompressionAlgorithm t = {
          static_cast<SSLCertificateCompressionAlgorithmID>(0xff01),
          "shrinkingCompressionAlg", certCompressionShrinkEncode,
          certCompressionShrinkDecode};

      SSLCertificateCompressionAlgorithm t1 = {
          static_cast<SSLCertificateCompressionAlgorithmID>(0xff02),
          "expandingCompressionAlg", certCompressionExpandEncode,
          certCompressionExpandDecode};

      SSLCertificateCompressionAlgorithm t2 = {
          static_cast<SSLCertificateCompressionAlgorithmID>(0xff03),
          "randomCompressionAlg", certCompressionRandomEncode,
          certCompressionRandomDecode};

      // NOTE(review): registration results are ignored, as before.
      SSLExp_SetCertificateCompressionAlgorithm(ssl_fd_.get(), t);
      SSLExp_SetCertificateCompressionAlgorithm(ssl_fd_.get(), t1);
      SSLExp_SetCertificateCompressionAlgorithm(ssl_fd_.get(), t2);
    }

    /* Apply resumption SSL options (if any). */
    if (resuming) {
      /* Client options */
      if (!cfg_.get<bool>("server")) {
        auto resumeEchConfigList =
            cfg_.get<std::string>("on-resume-ech-config-list");
        if (!resumeEchConfigList.empty()) {
          unsigned int binLen = 0;
          unsigned char* bin =
              ATOB_AsciiToData(resumeEchConfigList.c_str(), &binLen);
          rv = SSLExp_SetClientEchConfigs(ssl_fd_.get(), bin, binLen);
          if (rv != SECSuccess) {
            PRErrorCode err = PR_GetError();
            std::cerr << "Setting up resumption ECH configs failed with error="
                      << err << FormatError(err) << std::endl;
          }
          PORT_Free(bin);
        }

        str = cfg_.get<std::string>("on-resume-advertise-alpn");
        if (!str.empty()) {
          if (AdvertiseALPN(str) != SECSuccess) {
            PRErrorCode err = PR_GetError();
            std::cerr << "Setting up resumption ALPN failed with error=" << err
                      << FormatError(err) << std::endl;
          }
        }
      }

    } else { /* Explicitly not on resume (on initial) */
      /* Client options */
      if (!cfg_.get<bool>("server")) {
        str = cfg_.get<std::string>("on-initial-advertise-alpn");
        if (!str.empty()) {
          if (AdvertiseALPN(str) != SECSuccess) {
            PRErrorCode err = PR_GetError();
            std::cerr << "Setting up initial ALPN failed with error=" << err
                      << FormatError(err) << std::endl;
          }
        }
      }
    }

    /* If client send ClientHello. */
    if (!cfg_.get<bool>("server")) {
      ssl_Get1stHandshakeLock(ss);
      rv = ssl_BeginClientHandshake(ss);
      ssl_Release1stHandshakeLock(ss);
      if (rv != SECSuccess) {
        PRErrorCode err = PR_GetError();
        std::cerr << "Handshake failed with error=" << err << FormatError(err)
                  << std::endl;
        return SECFailure;
      }

      /* If the client is resuming. */
      if (ss->statelessResume) {
        SSLPreliminaryChannelInfo pinfo;
        rv = SSL_GetPreliminaryChannelInfo(ssl_fd_.get(), &pinfo,
                                           sizeof(SSLPreliminaryChannelInfo));

        if (rv != SECSuccess) {
          PRErrorCode err = PR_GetError();
          std::cerr << "SSL_GetPreliminaryChannelInfo failed with " << err
                    << std::endl;
          return SECFailure;
        }

        /* Check that the used ticket supports early data. */
        if (cfg_.get<bool>("expect-ticket-supports-early-data")) {
          if (!pinfo.ticketSupportsEarlyData) {
            std::cerr << "Expected ticket to support EarlyData" << std::endl;
            return SECFailure;
          }
        }

        /* If the client should send EarlyData. */
        if (cfg_.get<bool>("on-resume-shim-writes-first")) {
          earlyDataSent =
              ssl_SecureWrite(ss, kBogoDummyData, sizeof(kBogoDummyData));
          if (earlyDataSent < 0) {
            std::cerr << "Sending of EarlyData failed" << std::endl;
            return SECFailure;
          }
        }

        if (cfg_.get<bool>("expect-no-offer-early-data")) {
          if (earlyDataSent) {
            std::cerr << "Unexpectedly offered EarlyData" << std::endl;
            return SECFailure;
          }
        }
      }
    }

    /* As server start, as client continue handshake. */
    rv = Handshake();

    /* Retry config evaluation must be done before error handling since
     * handshake failure is intended on ech_required tests. */
    if (cfg_.get<bool>("expect-no-ech-retry-configs")) {
      if (ss->xtnData.ech && ss->xtnData.ech->retryConfigsValid) {
        std::cerr << "Unexpectedly received ECH retry configs" << std::endl;
        return SECFailure;
      }
    }

    /* If given, verify received retry configs before error handling. */
    std::string expectedRCs64 =
        cfg_.get<std::string>("expect-ech-retry-configs");
    if (!expectedRCs64.empty()) {
      SECItem receivedRCs = {siBuffer, nullptr, 0};

      /* Get received RetryConfigs. */
      if (SSLExp_GetEchRetryConfigs(ssl_fd_.get(), &receivedRCs) !=
          SECSuccess) {
        std::cerr << "Failed to get ECH retry configs." << std::endl;
        return SECFailure;
      }

      /* (Re-)Encode received configs to compare with expected ASCII string.
       * Both the retry-config copy and the encoded string belong to us, so
       * release them before any early return. */
      char* encodedRCs = BTOA_DataToAscii(receivedRCs.data, receivedRCs.len);
      SECITEM_FreeItem(&receivedRCs, PR_FALSE);
      if (!encodedRCs) {
        std::cerr << "Failed to base64-encode ECH retry configs." << std::endl;
        return SECFailure;
      }
      std::string receivedRCs64(encodedRCs);
      PORT_Free(encodedRCs);

      /* Strip line breaks added during base64 encoding so the result can be
       * compared with the single-line expected value. */
      StringRemoveNewlines(receivedRCs64);

      if (receivedRCs64 != expectedRCs64) {
        std::cerr << "Received ECH retry configs did not match expected retry "
                     "configs."
                  << std::endl;
        return SECFailure;
      }
    }

    /* Check if handshake succeeded. */
    if (rv != SECSuccess) {
      PRErrorCode err = PR_GetError();
      std::cerr << "Handshake failed with error=" << err << FormatError(err)
                << std::endl;
      return SECFailure;
    }

    /* If parts of data was sent as EarlyData make sure to send possibly
     * unsent rest. This is required to pass bogo resumption tests. */
    if (earlyDataSent && earlyDataSent < int(sizeof(kBogoDummyData))) {
      int toSend = sizeof(kBogoDummyData) - earlyDataSent;
      earlyDataSent =
          ssl_SecureWrite(ss, &kBogoDummyData[earlyDataSent], toSend);
      if (earlyDataSent != toSend) {
        std::cerr
            << "Could not send rest of EarlyData after handshake completion"
            << std::endl;
        return SECFailure;
      }
    }

    if (cfg_.get<bool>("write-then-read")) {
      rv = WriteRead();
      if (rv != SECSuccess) {
        PRErrorCode err = PR_GetError();
        std::cerr << "WriteRead failed with error=" << FormatError(err)
                  << std::endl;
        return SECFailure;
      }
    } else {
      rv = ReadWrite();
      if (rv != SECSuccess) {
        PRErrorCode err = PR_GetError();
        std::cerr << "ReadWrite failed with error=" << FormatError(err)
                  << std::endl;
        return SECFailure;
      }
    }

    SSLChannelInfo info;
    rv = SSL_GetChannelInfo(ssl_fd_.get(), &info, sizeof(info));
    if (rv != SECSuccess) {
      PRErrorCode err = PR_GetError();
      std::cerr << "SSL_GetChannelInfo failed with error=" << FormatError(err)
                << std::endl;
      return SECFailure;
    }

    auto sig_alg = cfg_.get<int>("expect-peer-signature-algorithm");
    if (sig_alg) {
      auto expected = static_cast<SSLSignatureScheme>(sig_alg);
      if (info.signatureScheme != expected) {
        std::cerr << "Unexpected signature scheme" << std::endl;
        return SECFailure;
      }
    }

    auto curve_id = cfg_.get<int>("expect-curve-id");
    if (curve_id) {
      auto expected = static_cast<SSLNamedGroup>(curve_id);
      // On resumption keaGroup may be ssl_grp_none; fall back to the group
      // of the original connection in that case.
      if (info.keaGroup != expected && !(info.keaGroup == ssl_grp_none &&
                                         info.originalKeaGroup == expected)) {
        std::cerr << "Unexpected named group" << std::endl;
        return SECFailure;
      }
    }

    if (cfg_.get<bool>("expect-ech-accept")) {
      if (!info.echAccepted) {
        std::cerr << "Expected ECH" << std::endl;
        return SECFailure;
      }
    }

    if (cfg_.get<bool>("expect-hrr")) {
      if (!ss->ssl3.hs.helloRetry) {
        std::cerr << "Expected HRR" << std::endl;
        return SECFailure;
      }
    }

    str = cfg_.get<std::string>("expect-alpn");
    if (!str.empty()) {
      if (CheckALPN(str) != SECSuccess) {
        std::cerr << "Unexpected ALPN" << std::endl;
        return SECFailure;
      }
    }

    /* if resumed */
    if (info.resumed) {
      if (cfg_.get<bool>("expect-session-miss")) {
        std::cerr << "Expected reject Resume" << std::endl;
        return SECFailure;
      }

      if (cfg_.get<bool>("on-resume-expect-ech-accept")) {
        if (!info.echAccepted) {
          std::cerr << "Expected ECH on Resume" << std::endl;
          return SECFailure;
        }
      }

      if (cfg_.get<bool>("on-resume-expect-reject-early-data")) {
        if (info.earlyDataAccepted) {
          std::cerr << "Expected reject EarlyData" << std::endl;
          return SECFailure;
        }
      }
      if (cfg_.get<bool>("on-resume-expect-accept-early-data")) {
        if (!info.earlyDataAccepted) {
          std::cerr << "Expected accept EarlyData" << std::endl;
          return SECFailure;
        }
      }

      if (info.earlyDataAccepted) {
        /* On successfully resumed connection. */
        str = cfg_.get<std::string>("on-resume-expect-alpn");
        if (!str.empty()) {
          if (CheckALPN(str) != SECSuccess) {
            std::cerr << "Unexpected ALPN on Resume" << std::endl;
            return SECFailure;
          }
        }
      } else {
        /* No real resume but new handshake on EarlyData rejection. */
        /* On Retry... */
        str = cfg_.get<std::string>("on-retry-expect-alpn");
        if (!str.empty()) {
          if (CheckALPN(str) != SECSuccess) {
            std::cerr << "Unexpected ALPN on HRR" << std::endl;
            return SECFailure;
          }
        }
      }

    } else { /* Explicitly not on resume */
      if (cfg_.get<bool>("on-initial-expect-ech-accept")) {
        if (!info.echAccepted) {
          std::cerr << "Expected ECH accept on initial connection"
                    << std::endl;
          return SECFailure;
        }
      }

      str = cfg_.get<std::string>("on-initial-expect-alpn");
      if (!str.empty()) {
        if (CheckALPN(str) != SECSuccess) {
          std::cerr << "Unexpected ALPN on Initial" << std::endl;
          return SECFailure;
        }
      }
    }

    return SECSuccess;
  }

 private:
  // Borrowed; the Config must outlive this agent.
  const Config& cfg_;
  // Raw TCP socket; released once SSL_ImportFD has wrapped it in ssl_fd_.
  ScopedPRFileDesc pr_fd_;
  ScopedPRFileDesc ssl_fd_;
  ScopedCERTCertificate cert_;
  ScopedSECKEYPrivateKey key_;
};

// Declares every flag the shim understands (with its default) and parses
// |argv|. Returns nullptr on any parse error; an unknown flag additionally
// sets exitCodeUnimplemented.
std::unique_ptr<const Config> ReadConfig(int argc, char** argv) {
  std::unique_ptr<Config> cfg(new Config());

  cfg->AddEntry<int>("port", 0);
  cfg->AddEntry<bool>("ipv6", false);
  cfg->AddEntry<int>("shim-id", 0);
  cfg->AddEntry<bool>("server", false);
  cfg->AddEntry<int>("resume-count", 0);
  cfg->AddEntry<std::string>("key-file", "");
  cfg->AddEntry<std::string>("cert-file", "");
  cfg->AddEntry<int>("min-version", 0);
  cfg->AddEntry<int>("max-version", 0xffff);
  for (auto flag : kVersionDisableFlags) {
    cfg->AddEntry<bool>(flag, false);
  }
  cfg->AddEntry<bool>("fallback-scsv", false);
  cfg->AddEntry<bool>("false-start", false);
  cfg->AddEntry<bool>("enable-ocsp-stapling", false);
  cfg->AddEntry<bool>("write-then-read", false);
  cfg->AddEntry<bool>("require-any-client-certificate", false);
  cfg->AddEntry<bool>("verify-peer", false);
  cfg->AddEntry<bool>("is-handshaker-supported", false);
  cfg->AddEntry<std::string>("handshaker-path", "");  // Ignore this
  cfg->AddEntry<std::string>("advertise-alpn", "");
  cfg->AddEntry<std::string>("on-initial-advertise-alpn", "");
  cfg->AddEntry<std::string>("on-resume-advertise-alpn", "");
  cfg->AddEntry<std::string>("expect-alpn", "");
  cfg->AddEntry<std::string>("on-initial-expect-alpn", "");
  cfg->AddEntry<std::string>("on-resume-expect-alpn", "");
  cfg->AddEntry<std::string>("on-retry-expect-alpn", "");
  cfg->AddEntry<std::vector<int>>("signing-prefs", std::vector<int>());
  cfg->AddEntry<std::vector<int>>("verify-prefs", std::vector<int>());
  cfg->AddEntry<int>("expect-peer-signature-algorithm", 0);
  cfg->AddEntry<std::string>("nss-cipher", "");
  cfg->AddEntry<std::string>("host-name", "");
  cfg->AddEntry<std::string>("ech-config-list", "");
  cfg->AddEntry<std::string>("on-resume-ech-config-list", "");
  cfg->AddEntry<bool>("expect-ech-accept", false);
  cfg->AddEntry<bool>("expect-hrr", false);
  cfg->AddEntry<bool>("enable-ech-grease", false);
  cfg->AddEntry<bool>("enable-early-data", false);
  cfg->AddEntry<bool>("enable-grease", false);
  cfg->AddEntry<bool>("permute-extensions", false);
  cfg->AddEntry<bool>("on-resume-expect-reject-early-data", false);
  cfg->AddEntry<bool>("on-resume-expect-accept-early-data", false);
  cfg->AddEntry<bool>("expect-ticket-supports-early-data", false);
  cfg->AddEntry<bool>("on-resume-shim-writes-first",
                      false);  // Always means 0Rtt write
  // Unimplemented since not required so far
  cfg->AddEntry<bool>("shim-writes-first", false);
  cfg->AddEntry<bool>("expect-session-miss", false);
  cfg->AddEntry<std::string>("expect-ech-retry-configs", "");
  cfg->AddEntry<bool>("expect-no-ech-retry-configs", false);
  cfg->AddEntry<bool>("on-initial-expect-ech-accept", false);
  cfg->AddEntry<bool>("on-resume-expect-ech-accept", false);
  cfg->AddEntry<bool>("expect-no-offer-early-data", false);
  /* NSS does not support earlydata rejection reason logging => Ignore. */
  cfg->AddEntry<std::string>("on-resume-expect-early-data-reason", "none");
  cfg->AddEntry<std::string>("on-retry-expect-early-data-reason", "none");
  cfg->AddEntry<std::vector<int>>("curves", std::vector<int>());
  cfg->AddEntry<int>("expect-curve-id", 0);
  cfg->AddEntry<bool>("install-cert-compression-algs", false);

  auto rv = cfg->ParseArgs(argc, argv);
  switch (rv) {
    case Config::kOK:
      break;
    case Config::kUnknownFlag:
      exitCodeUnimplemented = true;
      // Fall through.
    default:
      return nullptr;
  }

  // Needed to change to std::unique_ptr<const Config>
  return std::move(cfg);
}

// Runs one connection cycle; returns true only if agent setup and the
// whole exchange succeed.
bool RunCycle(std::unique_ptr<const Config>& cfg, bool resuming = false) {
  std::unique_ptr<TestAgent> agent(TestAgent::Create(*cfg));
  return agent && agent->DoExchange(resuming) == SECSuccess;
}

// Maps the run outcome to a process exit code: 89 if an unknown flag was
// seen, otherwise 0 on success and 1 on failure.
int GetExitCode(bool success) {
  if (exitCodeUnimplemented) {
    return 89;
  }

  if (success) {
    return 0;
  }

  return 1;
}

int main(int argc, char** argv) {
  std::unique_ptr<const Config> cfg = ReadConfig(argc, argv);
  if (!cfg) {
    return GetExitCode(false);
  }

  if (cfg->get<bool>("is-handshaker-supported")) {
    std::cout << "No\n";
    return 0;
  }

  if (cfg->get<bool>("server")) {
    if (SSL_ConfigServerSessionIDCache(1024, 0, 0, ".") != SECSuccess) {
      std::cerr << "Couldn't configure session cache\n";
      return 1;
    }
  }

  if (NSS_NoDB_Init(nullptr) != SECSuccess) {
    return 1;
  }

  // Run a single test cycle.
  bool success = RunCycle(cfg);

  int resume_count = cfg->get<int>("resume-count");
  while (success && resume_count-- > 0) {
    std::cout << "Resuming" << std::endl;
    success = RunCycle(cfg, true);
  }

  SSL_ClearSessionCache();

  if (cfg->get<bool>("server")) {
    SSL_ShutdownServerSessionIDCache();
  }

  if (NSS_Shutdown() != SECSuccess) {
    success = false;
  }

  return GetExitCode(success);
}