diff --git a/gtests/freebl_gtest/mpi_unittest.cc b/gtests/freebl_gtest/mpi_unittest.cc index 4fed1a40e0..2ccb8c351a 100644 --- a/gtests/freebl_gtest/mpi_unittest.cc +++ b/gtests/freebl_gtest/mpi_unittest.cc @@ -15,7 +15,7 @@ #include "mpi.h" namespace nss_test { -void gettime(struct timespec *tp) { +void gettime(struct timespec* tp) { #ifdef __MACH__ clock_serv_t cclock; mach_timespec_t mts; @@ -69,6 +69,39 @@ class MPITest : public ::testing::Test { mp_clear(&b); mp_clear(&c); } + + void dump(const std::string& prefix, const uint8_t* buf, size_t len) { + auto flags = std::cerr.flags(); + std::cerr << prefix << ": [" << std::dec << len << "] "; + for (size_t i = 0; i < len; ++i) { + std::cerr << std::hex << std::setw(2) << std::setfill('0') + << static_cast(buf[i]); + } + std::cerr << std::endl << std::resetiosflags(flags); + } + + void TestToFixedOctets(const std::vector& ref, size_t len) { + mp_int a; + ASSERT_EQ(MP_OKAY, mp_init(&a)); + ASSERT_EQ(MP_OKAY, mp_read_unsigned_octets(&a, ref.data(), ref.size())); + uint8_t buf[len]; + ASSERT_EQ(MP_OKAY, mp_to_fixlen_octets(&a, buf, len)); + size_t compare; + if (len > ref.size()) { + for (size_t i = 0; i < len - ref.size(); ++i) { + ASSERT_EQ(0U, buf[i]) << "index " << i << " should be zero"; + } + compare = ref.size(); + } else { + compare = len; + } + dump("value", ref.data(), ref.size()); + dump("output", buf, len); + ASSERT_EQ(0, memcmp(buf + len - compare, ref.data() + ref.size() - compare, + compare)) + << "comparing " << compare << " octets"; + mp_clear(&a); + } }; TEST_F(MPITest, MpiCmp01Test) { TestCmp("0", "1", -1); } @@ -113,6 +146,47 @@ TEST_F(MPITest, MpiCmpUnalignedTest) { } #endif +TEST_F(MPITest, MpiFixlenOctetsZero) { + std::vector zero = {0}; + TestToFixedOctets(zero, 1); + TestToFixedOctets(zero, 2); + TestToFixedOctets(zero, sizeof(mp_digit)); + TestToFixedOctets(zero, sizeof(mp_digit) + 1); +} + +TEST_F(MPITest, MpiFixlenOctetsVarlen) { + std::vector packed; + for (size_t i = 0; i < sizeof(mp_digit) * 2; ++i) { + packed.push_back(0xa4); // Any non-zero value will do. + TestToFixedOctets(packed, packed.size()); + TestToFixedOctets(packed, packed.size() + 1); + TestToFixedOctets(packed, packed.size() + sizeof(mp_digit)); + } +} + +TEST_F(MPITest, MpiFixlenOctetsTooSmall) { + uint8_t buf[sizeof(mp_digit) * 3]; + std::vector ref; + for (size_t i = 0; i < sizeof(mp_digit) * 2; i++) { + ref.push_back(3); // Any non-zero value will do. + dump("ref", ref.data(), ref.size()); + + mp_int a; + ASSERT_EQ(MP_OKAY, mp_init(&a)); + ASSERT_EQ(MP_OKAY, mp_read_unsigned_octets(&a, ref.data(), ref.size())); +#ifdef DEBUG + // ARGCHK maps to assert() in a debug build. + EXPECT_DEATH(mp_to_fixlen_octets(&a, buf, ref.size() - 1), ""); +#else + EXPECT_EQ(MP_BADARG, mp_to_fixlen_octets(&a, buf, ref.size() - 1)); +#endif + ASSERT_EQ(MP_OKAY, mp_to_fixlen_octets(&a, buf, ref.size())); + ASSERT_EQ(0, memcmp(buf, ref.data(), ref.size())); + + mp_clear(&a); + } +} + // This test is slow. Disable it by default so we can run these tests on CI. class DISABLED_MPITest : public ::testing::Test {}; @@ -127,17 +201,17 @@ TEST_F(DISABLED_MPITest, MpiCmpConstTest) { mp_read_radix( &a, - const_cast( + const_cast( "FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551"), 16); mp_read_radix( &b, - const_cast( + const_cast( "FF0FFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551"), 16); mp_read_radix( &c, - const_cast( + const_cast( "FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632550"), 16); diff --git a/lib/freebl/mpi/mpi.c b/lib/freebl/mpi/mpi.c index 8c893fb5fa..401eac51db 100644 --- a/lib/freebl/mpi/mpi.c +++ b/lib/freebl/mpi/mpi.c @@ -4775,38 +4775,61 @@ mp_to_signed_octets(const mp_int *mp, unsigned char *str, mp_size maxlen) /* }}} */ /* {{{ mp_to_fixlen_octets(mp, str) */ -/* output a buffer of big endian octets exactly as long as requested. */ +/* output a buffer of big endian octets exactly as long as requested. + constant time on the value of mp. */ mp_err mp_to_fixlen_octets(const mp_int *mp, unsigned char *str, mp_size length) { - int ix, pos = 0; + int ix, jx; unsigned int bytes; - ARGCHK(mp != NULL && str != NULL && !SIGN(mp), MP_BADARG); - - bytes = mp_unsigned_octet_size(mp); - ARGCHK(bytes <= length, MP_BADARG); + ARGCHK(mp != NULL, MP_BADARG); + ARGCHK(str != NULL, MP_BADARG); + ARGCHK(!SIGN(mp), MP_BADARG); + ARGCHK(length > 0, MP_BADARG); + + /* Constant time on the value of mp. Don't use mp_unsigned_octet_size. */ + bytes = USED(mp) * MP_DIGIT_SIZE; + + /* If the output is shorter than the native size of mp, then check that any + * bytes not written have zero values. This check isn't constant time on + * the assumption that timing-sensitive callers can guarantee that mp fits + * in the allocated space. */ + ix = USED(mp) - 1; + if (bytes > length) { + unsigned int zeros = bytes - length; + + while (zeros >= MP_DIGIT_SIZE) { + ARGCHK(DIGIT(mp, ix) == 0, MP_BADARG); + zeros -= MP_DIGIT_SIZE; + ix--; + } - /* place any needed leading zeros */ - for (; length > bytes; --length) { - *str++ = 0; + if (zeros > 0) { + mp_digit d = DIGIT(mp, ix); + mp_digit m = ~0ULL << ((MP_DIGIT_SIZE - zeros) * CHAR_BIT); + ARGCHK((d & m) == 0, MP_BADARG); + for (jx = MP_DIGIT_SIZE - zeros - 1; jx >= 0; jx--) { + *str++ = d >> (jx * CHAR_BIT); + } + ix--; + } + } else if (bytes < length) { + /* Place any needed leading zeros. */ + unsigned int zeros = length - bytes; + memset(str, 0, zeros); + str += zeros; } - /* Iterate over each digit... */ - for (ix = USED(mp) - 1; ix >= 0; ix--) { + /* Iterate over each whole digit... */ + for (; ix >= 0; ix--) { mp_digit d = DIGIT(mp, ix); - int jx; /* Unpack digit bytes, high order first */ - for (jx = sizeof(mp_digit) - 1; jx >= 0; jx--) { - unsigned char x = (unsigned char)(d >> (jx * CHAR_BIT)); - if (!pos && !x) /* suppress leading zeros */ - continue; - str[pos++] = x; + for (jx = MP_DIGIT_SIZE - 1; jx >= 0; jx--) { + *str++ = d >> (jx * CHAR_BIT); } } - if (!pos) - str[pos++] = 0; return MP_OKAY; } /* end mp_to_fixlen_octets() */ /* }}} */ diff --git a/lib/freebl/mpi/mpi.h b/lib/freebl/mpi/mpi.h index 97af0f069b..d5aef46d7c 100644 --- a/lib/freebl/mpi/mpi.h +++ b/lib/freebl/mpi/mpi.h @@ -128,7 +128,8 @@ typedef int mp_sword; #define MP_WORD_MAX UINT_MAX #endif -#define MP_DIGIT_BIT (CHAR_BIT * sizeof(mp_digit)) +#define MP_DIGIT_SIZE sizeof(mp_digit) +#define MP_DIGIT_BIT (CHAR_BIT * MP_DIGIT_SIZE) #define MP_WORD_BIT (CHAR_BIT * sizeof(mp_word)) #define MP_RADIX (1 + (mp_word)MP_DIGIT_MAX)