commit 88f6da4f4b67c71ca2600c95e5f614443d1d8199
parent 88d397aa025f97d28fc900533dfbda721d486ad4
Author: cfillion <cfillion@users.noreply.github.com>
Date: Thu, 6 Dec 2018 08:47:52 -0800
cache CNG algorithm providers
...as recommended by Microsoft
https://docs.microsoft.com/en-us/windows/desktop/api/bcrypt/nf-bcrypt-bcryptopenalgorithmprovider#remarks
Diffstat:
2 files changed, 75 insertions(+), 23 deletions(-)
diff --git a/src/hash.cpp b/src/hash.cpp
@@ -4,53 +4,92 @@
#include <vector>
#ifdef _WIN32
+# include <map>
# include <windows.h>
-class Hash::CNGContext : public Hash::Context {
+class CNGAlgorithmProvider;
+std::map<Hash::Algorithm, std::weak_ptr<CNGAlgorithmProvider>> s_algoCache;
+
+class CNGAlgorithmProvider {
public:
- CNGContext(const Algorithm algo) : m_algo(), m_hash(), m_hashLength() {
- const wchar_t *algoName;
+ static std::shared_ptr<CNGAlgorithmProvider> get(const Hash::Algorithm algo)
+ {
+ auto it = s_algoCache.find(algo);
+
+ if(it != s_algoCache.end() && !it->second.expired())
+ return it->second.lock();
+
+ wchar_t *algoName;
switch(algo) {
- case SHA256:
+ case Hash::SHA256:
algoName = BCRYPT_SHA256_ALGORITHM;
break;
default:
- return;
+ return nullptr;
}
- BCryptOpenAlgorithmProvider(&m_algo, algoName,
- MS_PRIMITIVE_PROVIDER, 0);
+ auto provider = std::make_shared<CNGAlgorithmProvider>(algoName);
+ s_algoCache[algo] = provider;
+ return provider;
+ }
+
+ CNGAlgorithmProvider(const wchar_t *algoName)
+ {
+ BCryptOpenAlgorithmProvider(&m_algo, algoName, MS_PRIMITIVE_PROVIDER, 0);
+ }
+
+ ~CNGAlgorithmProvider()
+ {
+ BCryptCloseAlgorithmProvider(m_algo, 0);
+ }
+
+ operator BCRYPT_ALG_HANDLE() const
+ {
+ return m_algo;
+ }
+
+private:
+ BCRYPT_ALG_HANDLE m_algo;
+};
+class Hash::CNGContext : public Hash::Context {
+public:
+ CNGContext(const std::shared_ptr<CNGAlgorithmProvider> &algo)
+ : m_algo(algo), m_hash(), m_hashLength()
+ {
unsigned long bytesWritten;
- BCryptGetProperty(m_algo, BCRYPT_HASH_LENGTH,
+ BCryptGetProperty(*m_algo, BCRYPT_HASH_LENGTH,
reinterpret_cast<PUCHAR>(&m_hashLength), sizeof(m_hashLength),
&bytesWritten, 0);
- BCryptCreateHash(m_algo, &m_hash, nullptr, 0, nullptr, 0, 0);
+ BCryptCreateHash(*m_algo, &m_hash, nullptr, 0, nullptr, 0, 0);
}
- ~CNGContext() override {
- if(m_algo)
- BCryptCloseAlgorithmProvider(m_algo, 0);
- if(m_hash)
- BCryptDestroyHash(m_hash);
+ ~CNGContext() override
+ {
+ BCryptDestroyHash(m_hash);
}
- size_t hashSize() const override { return m_hashLength; }
+ size_t hashSize() const override
+ {
+ return m_hashLength;
+ }
- void addData(const char *data, const size_t len) override {
+ void addData(const char *data, const size_t len) override
+ {
BCryptHashData(m_hash,
reinterpret_cast<unsigned char *>(const_cast<char *>(data)),
static_cast<unsigned long>(len), 0);
}
- void getHash(unsigned char *out) {
+ void getHash(unsigned char *out)
+ {
BCryptFinishHash(m_hash, out, m_hashLength, 0);
}
private:
- BCRYPT_ALG_HANDLE m_algo;
+ std::shared_ptr<CNGAlgorithmProvider> m_algo;
BCRYPT_HASH_HANDLE m_hash;
unsigned long m_hashLength;
};
@@ -66,17 +105,23 @@ private:
class Hash::SHA256Context : public Hash::Context {
public:
- SHA256Context() { SHA256_Init(&m_context); }
+ SHA256Context()
+ {
+ SHA256_Init(&m_context);
+ }
- size_t hashSize() const override {
+ size_t hashSize() const override
+ {
return SHA256_DIGEST_LENGTH;
}
- void addData(const char *data, const size_t len) override {
+ void addData(const char *data, const size_t len) override
+ {
SHA256_Update(&m_context, data, len);
}
- void getHash(unsigned char *out) override {
+ void getHash(unsigned char *out) override
+ {
SHA256_Final(out, &m_context);
}
@@ -90,7 +135,8 @@ Hash::Hash(const Algorithm algo)
: m_algo(algo)
{
#ifdef _WIN32
- m_context = std::make_unique<CNGContext>(algo);
+ if(const auto &algoProvider = CNGAlgorithmProvider::get(algo))
+ m_context = std::make_unique<CNGContext>(algoProvider);
#else
switch(algo) {
case SHA256:
diff --git a/test/hash.cpp b/test/hash.cpp
@@ -27,3 +27,9 @@ TEST_CASE("sha256 hashes", M) {
"1220dbd318c1c462aee872f41109a4dfd3048871a03dedd0fe0e757ced57dad6f2d7");
}
}
+
+TEST_CASE("invalid algorithm", M) {
+ Hash hash(static_cast<Hash::Algorithm>(0));
+ hash.write("foo bar", 7);
+ REQUIRE(hash.digest() == "");
+}