reapack

Package manager for REAPER
Log | Files | Refs | Submodules | README | LICENSE

commit 88f6da4f4b67c71ca2600c95e5f614443d1d8199
parent 88d397aa025f97d28fc900533dfbda721d486ad4
Author: cfillion <cfillion@users.noreply.github.com>
Date:   Thu,  6 Dec 2018 08:47:52 -0800

cache CNG algorithm providers

...as recommended by Microsoft

https://docs.microsoft.com/en-us/windows/desktop/api/bcrypt/nf-bcrypt-bcryptopenalgorithmprovider#remarks

Diffstat:
Msrc/hash.cpp | 92+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------
Mtest/hash.cpp | 6++++++
2 files changed, 75 insertions(+), 23 deletions(-)

diff --git a/src/hash.cpp b/src/hash.cpp @@ -4,53 +4,92 @@ #include <vector> #ifdef _WIN32 +# include <map> # include <windows.h> -class Hash::CNGContext : public Hash::Context { +class CNGAlgorithmProvider; +std::map<Hash::Algorithm, std::weak_ptr<CNGAlgorithmProvider>> s_algoCache; + +class CNGAlgorithmProvider { public: - CNGContext(const Algorithm algo) : m_algo(), m_hash(), m_hashLength() { - const wchar_t *algoName; + static std::shared_ptr<CNGAlgorithmProvider> get(const Hash::Algorithm algo) + { + auto it = s_algoCache.find(algo); + + if(it != s_algoCache.end() && !it->second.expired()) + return it->second.lock(); + + wchar_t *algoName; switch(algo) { - case SHA256: + case Hash::SHA256: algoName = BCRYPT_SHA256_ALGORITHM; break; default: - return; + return nullptr; } - BCryptOpenAlgorithmProvider(&m_algo, algoName, - MS_PRIMITIVE_PROVIDER, 0); + auto provider = std::make_shared<CNGAlgorithmProvider>(algoName); + s_algoCache[algo] = provider; + return provider; + } + + CNGAlgorithmProvider(const wchar_t *algoName) + { + BCryptOpenAlgorithmProvider(&m_algo, algoName, MS_PRIMITIVE_PROVIDER, 0); + } + + ~CNGAlgorithmProvider() + { + BCryptCloseAlgorithmProvider(m_algo, 0); + } + + operator BCRYPT_ALG_HANDLE() const + { + return m_algo; + } + +private: + BCRYPT_ALG_HANDLE m_algo; +}; +class Hash::CNGContext : public Hash::Context { +public: + CNGContext(const std::shared_ptr<CNGAlgorithmProvider> &algo) + : m_algo(algo), m_hash(), m_hashLength() + { unsigned long bytesWritten; - BCryptGetProperty(m_algo, BCRYPT_HASH_LENGTH, + BCryptGetProperty(*m_algo, BCRYPT_HASH_LENGTH, reinterpret_cast<PUCHAR>(&m_hashLength), sizeof(m_hashLength), &bytesWritten, 0); - BCryptCreateHash(m_algo, &m_hash, nullptr, 0, nullptr, 0, 0); + BCryptCreateHash(*m_algo, &m_hash, nullptr, 0, nullptr, 0, 0); } - ~CNGContext() override { - if(m_algo) - BCryptCloseAlgorithmProvider(m_algo, 0); - if(m_hash) - BCryptDestroyHash(m_hash); + ~CNGContext() override + { + BCryptDestroyHash(m_hash); } - size_t hashSize() const override { return m_hashLength; } + size_t hashSize() const override + { + return m_hashLength; + } - void addData(const char *data, const size_t len) override { + void addData(const char *data, const size_t len) override + { BCryptHashData(m_hash, reinterpret_cast<unsigned char *>(const_cast<char *>(data)), static_cast<unsigned long>(len), 0); } - void getHash(unsigned char *out) { + void getHash(unsigned char *out) + { BCryptFinishHash(m_hash, out, m_hashLength, 0); } private: - BCRYPT_ALG_HANDLE m_algo; + std::shared_ptr<CNGAlgorithmProvider> m_algo; BCRYPT_HASH_HANDLE m_hash; unsigned long m_hashLength; }; @@ -66,17 +105,23 @@ private: class Hash::SHA256Context : public Hash::Context { public: - SHA256Context() { SHA256_Init(&m_context); } + SHA256Context() + { + SHA256_Init(&m_context); + } - size_t hashSize() const override { + size_t hashSize() const override + { return SHA256_DIGEST_LENGTH; } - void addData(const char *data, const size_t len) override { + void addData(const char *data, const size_t len) override + { SHA256_Update(&m_context, data, len); } - void getHash(unsigned char *out) override { + void getHash(unsigned char *out) override + { SHA256_Final(out, &m_context); } @@ -90,7 +135,8 @@ Hash::Hash(const Algorithm algo) : m_algo(algo) { #ifdef _WIN32 - m_context = std::make_unique<CNGContext>(algo); + if(const auto &algoProvider = CNGAlgorithmProvider::get(algo)) + m_context = std::make_unique<CNGContext>(algoProvider); #else switch(algo) { case SHA256: diff --git a/test/hash.cpp b/test/hash.cpp @@ -27,3 +27,9 @@ TEST_CASE("sha256 hashes", M) { "1220dbd318c1c462aee872f41109a4dfd3048871a03dedd0fe0e757ced57dad6f2d7"); } } + +TEST_CASE("invalid algorithm", M) { + Hash hash(static_cast<Hash::Algorithm>(0)); + hash.write("foo bar", 7); + REQUIRE(hash.digest() == ""); +}