commit b0985c227a185df26996c3a0a7d091da33b56f3e
parent b98367fd0bb2891ce1fe68d44cef144cc97f375f
Author: cfillion <cfillion@users.noreply.github.com>
Date: Sat, 8 Dec 2018 11:06:50 -0500
Merge branch 'checksum'
Diffstat:
12 files changed, 380 insertions(+), 11 deletions(-)
diff --git a/linux.tup b/linux.tup
@@ -20,7 +20,7 @@ SWELL := $(WDL)/swell
WDLSOURCE += $(SWELL)/swell-modstub-generic.cpp
export CURLSO
-LDFLAGS := -lstdc++ -lpthread -ldl -l${CURLSO:-curl} -lsqlite3 -lz
+LDFLAGS := -lstdc++ -lpthread -ldl -lcrypto -l${CURLSO:-curl} -lsqlite3 -lz
LDFLAGS += -Wl,--gc-sections
SOFLAGS := -shared
diff --git a/src/download.cpp b/src/download.cpp
@@ -18,6 +18,7 @@
#include "download.hpp"
#include "filesystem.hpp"
+#include "hash.hpp"
#include "reapack.hpp"
#include <cassert>
@@ -92,7 +93,7 @@ size_t Download::WriteData(char *data, size_t rawsize, size_t nmemb, void *ptr)
{
const size_t size = rawsize * nmemb;
- static_cast<ostream *>(ptr)->write(data, size);
+ static_cast<WriteContext *>(ptr)->write(data, size);
return size;
}
@@ -115,8 +116,21 @@ void Download::setName(const string &name)
bool Download::run()
{
- ostream *stream = openStream();
- if(!stream)
+ WriteContext write;
+
+ Hash::Algorithm algo;
+ if(!m_expectedChecksum.empty()) {
+ if(Hash::getAlgorithm(m_expectedChecksum, &algo))
+ write.hash = make_unique<Hash>(algo);
+ else {
+ const string &error = String::format(
+ "Unsupported checksum: %s", m_expectedChecksum.c_str());
+ setError({error, m_url});
+ return false;
+ }
+ }
+
+ if(!(write.stream = openStream()))
return false;
thread_local DownloadContext ctx;
@@ -132,7 +146,7 @@ bool Download::run()
curl_easy_setopt(ctx, CURLOPT_PROGRESSDATA, this);
curl_easy_setopt(ctx, CURLOPT_WRITEFUNCTION, WriteData);
- curl_easy_setopt(ctx, CURLOPT_WRITEDATA, stream);
+ curl_easy_setopt(ctx, CURLOPT_WRITEDATA, &write);
curl_slist *headers = nullptr;
if(has(Download::NoCacheFlag))
@@ -152,10 +166,26 @@ bool Download::run()
setError({err, m_url});
return false;
}
+ else if(write.hash && write.hash->digest() != m_expectedChecksum) {
+ const string &err = String::format(
+ "Checksum mismatch.\nExpected: %s\nActual: %s",
+ m_expectedChecksum.c_str(), write.hash->digest().c_str()
+ );
+ setError({err, m_url});
+ return false;
+ }
return true;
}
+void Download::WriteContext::write(const char *data, const size_t len)
+{
+ stream->write(data, len);
+
+ if(hash)
+ hash->addData(data, len);
+}
+
MemoryDownload::MemoryDownload(const string &url, const NetworkOpts &opts, int flags)
: Download(url, opts, flags)
{
diff --git a/src/download.hpp b/src/download.hpp
@@ -24,8 +24,11 @@
#include <curl/curl.h>
#include <fstream>
+#include <memory>
#include <sstream>
+class Hash;
+
class DownloadContext {
public:
static void GlobalInit();
@@ -49,6 +52,9 @@ public:
Download(const std::string &url, const NetworkOpts &, int flags = 0);
void setName(const std::string &);
+ void setExpectedChecksum(const std::string &checksum) {
+ m_expectedChecksum = checksum;
+ }
const std::string &url() const { return m_url; }
bool concurrent() const override { return true; }
@@ -59,11 +65,20 @@ protected:
virtual void closeStream() {}
private:
+ struct WriteContext {
+ std::ostream *stream;
+ std::unique_ptr<Hash> hash;
+
+ void write(const char *data, size_t len);
+ bool checkChecksum(const std::string &expected) const;
+ };
+
bool has(Flag f) const { return (m_flags & f) != 0; }
static size_t WriteData(char *, size_t, size_t, void *);
static int UpdateProgress(void *, double, double, double, double);
std::string m_url;
+ std::string m_expectedChecksum;
NetworkOpts m_opts;
int m_flags;
};
diff --git a/src/hash.cpp b/src/hash.cpp
@@ -0,0 +1,193 @@
+#include "hash.hpp"
+
+#include <cstdio>
+#include <vector>
+
+#ifdef _WIN32
+# include <map>
+# include <windows.h>
+
+class CNGAlgorithmProvider;
+std::map<Hash::Algorithm, std::weak_ptr<CNGAlgorithmProvider>> s_algoCache;
+
+class CNGAlgorithmProvider {
+public:
+ static std::shared_ptr<CNGAlgorithmProvider> get(const Hash::Algorithm algo)
+ {
+ auto it = s_algoCache.find(algo);
+
+ if(it != s_algoCache.end() && !it->second.expired())
+ return it->second.lock();
+
+ wchar_t *algoName;
+
+ switch(algo) {
+ case Hash::SHA256:
+ algoName = BCRYPT_SHA256_ALGORITHM;
+ break;
+ default:
+ return nullptr;
+ }
+
+ auto provider = std::make_shared<CNGAlgorithmProvider>(algoName);
+ s_algoCache[algo] = provider;
+ return provider;
+ }
+
+ CNGAlgorithmProvider(const wchar_t *algoName)
+ {
+ BCryptOpenAlgorithmProvider(&m_algo, algoName, MS_PRIMITIVE_PROVIDER, 0);
+ }
+
+ ~CNGAlgorithmProvider()
+ {
+ BCryptCloseAlgorithmProvider(m_algo, 0);
+ }
+
+ operator BCRYPT_ALG_HANDLE() const
+ {
+ return m_algo;
+ }
+
+private:
+ BCRYPT_ALG_HANDLE m_algo;
+};
+
+class Hash::CNGContext : public Hash::Context {
+public:
+ CNGContext(const std::shared_ptr<CNGAlgorithmProvider> &algo)
+ : m_algo(algo), m_hash(), m_hashLength()
+ {
+ unsigned long bytesWritten;
+ BCryptGetProperty(*m_algo, BCRYPT_HASH_LENGTH,
+ reinterpret_cast<PUCHAR>(&m_hashLength), sizeof(m_hashLength),
+ &bytesWritten, 0);
+
+ BCryptCreateHash(*m_algo, &m_hash, nullptr, 0, nullptr, 0, 0);
+ }
+
+ ~CNGContext() override
+ {
+ BCryptDestroyHash(m_hash);
+ }
+
+ size_t hashSize() const override
+ {
+ return m_hashLength;
+ }
+
+ void addData(const char *data, const size_t len) override
+ {
+ BCryptHashData(m_hash,
+ reinterpret_cast<unsigned char *>(const_cast<char *>(data)),
+ static_cast<unsigned long>(len), 0);
+ }
+
+ void getHash(unsigned char *out)
+ {
+ BCryptFinishHash(m_hash, out, m_hashLength, 0);
+ }
+
+private:
+ std::shared_ptr<CNGAlgorithmProvider> m_algo;
+ BCRYPT_HASH_HANDLE m_hash;
+ unsigned long m_hashLength;
+};
+
+#else // Unix systems
+
+# ifdef __APPLE__
+# define COMMON_DIGEST_FOR_OPENSSL
+# include <CommonCrypto/CommonDigest.h>
+# else
+# include <openssl/sha.h>
+# endif
+
+class Hash::SHA256Context : public Hash::Context {
+public:
+ SHA256Context()
+ {
+ SHA256_Init(&m_context);
+ }
+
+ size_t hashSize() const override
+ {
+ return SHA256_DIGEST_LENGTH;
+ }
+
+ void addData(const char *data, const size_t len) override
+ {
+ SHA256_Update(&m_context, data, len);
+ }
+
+ void getHash(unsigned char *out) override
+ {
+ SHA256_Final(out, &m_context);
+ }
+
+private:
+ SHA256_CTX m_context;
+};
+
+#endif
+
+Hash::Hash(const Algorithm algo)
+ : m_algo(algo)
+{
+#ifdef _WIN32
+ if(const auto &algoProvider = CNGAlgorithmProvider::get(algo))
+ m_context = std::make_unique<CNGContext>(algoProvider);
+#else
+ switch(algo) {
+ case SHA256:
+ m_context = std::make_unique<SHA256Context>();
+ break;
+ }
+#endif
+}
+
+void Hash::addData(const char *data, const size_t len)
+{
+ if(m_context)
+ m_context->addData(data, len);
+}
+
+const std::string &Hash::digest()
+{
+ if(!m_context || !m_value.empty())
+ return m_value;
+
+ // Assuming m_algo and hashSize can fit in one byte. We'll need to implement
+ // multihash's varint if we need larger values in the future.
+ const size_t hashSize = m_context->hashSize();
+
+ std::vector<unsigned char> multihash(2 + hashSize);
+ multihash[0] = m_algo;
+ multihash[1] = static_cast<unsigned char>(hashSize);
+ m_context->getHash(&multihash[2]);
+
+ m_value.resize(multihash.size() * 2);
+
+ for(size_t i = 0; i < multihash.size(); ++i)
+ sprintf(&m_value[i * 2], "%02x", multihash[i]);
+
+ return m_value;
+}
+
+bool Hash::getAlgorithm(const std::string &hash, Algorithm *out)
+{
+ unsigned int algo, size;
+ if(sscanf(hash.c_str(), "%2x%2x", &algo, &size) != 2)
+ return false;
+
+ if(hash.size() != (size * 2) + 4)
+ return false;
+
+ switch(algo) {
+ case SHA256:
+ *out = static_cast<Algorithm>(algo);
+ return true;
+ default:
+ return false;
+ };
+}
diff --git a/src/hash.hpp b/src/hash.hpp
@@ -0,0 +1,38 @@
+#ifndef REAPACK_HASH_HPP
+#define REAPACK_HASH_HPP
+
+#include <memory>
+#include <string>
+
+class Hash {
+public:
+ enum Algorithm {
+ SHA256 = 0x12,
+ };
+
+ static bool getAlgorithm(const std::string &hash, Algorithm *out);
+
+ Hash(Algorithm);
+ Hash(const Hash &) = delete;
+
+ void addData(const char *data, size_t len);
+ const std::string &digest();
+
+private:
+ class Context {
+ public:
+ virtual ~Context() = default;
+ virtual size_t hashSize() const = 0;
+ virtual void addData(const char *data, size_t len) = 0;
+ virtual void getHash(unsigned char *out) = 0;
+ };
+
+ class CNGContext;
+ class SHA256Context;
+
+ Algorithm m_algo;
+ std::string m_value;
+ std::unique_ptr<Context> m_context;
+};
+
+#endif
diff --git a/src/index_v1.cpp b/src/index_v1.cpp
@@ -176,6 +176,9 @@ void LoadSourceV1(TiXmlElement *node, Version *ver)
const char *file = node->Attribute("file");
if(!file) file = "";
+ const char *checksum = node->Attribute("checksum");
+ if(!checksum) checksum = "";
+
const char *main = node->Attribute("main");
if(!main) main = "";
@@ -185,6 +188,7 @@ void LoadSourceV1(TiXmlElement *node, Version *ver)
Source *src = new Source(file, url, ver);
unique_ptr<Source> ptr(src);
+ src->setChecksum(checksum);
src->setPlatform(platform);
src->setTypeOverride(Package::getType(type));
diff --git a/src/install.cpp b/src/install.cpp
@@ -74,6 +74,7 @@ bool InstallTask::start()
else {
const NetworkOpts &opts = g_reapack->config()->network;
FileDownload *dl = new FileDownload(targetPath, src->url(), opts);
+ dl->setExpectedChecksum(src->checksum());
push(dl, dl->path());
}
}
diff --git a/src/source.hpp b/src/source.hpp
@@ -42,25 +42,31 @@ public:
static Section detectSection(const Path &category);
Source(const std::string &file, const std::string &url, const Version *);
+
const Version *version() const { return m_version; }
+ Package::Type type() const;
+ const std::string &file() const;
+ const std::string &url() const { return m_url; }
+ Path targetPath() const;
+
+ void setChecksum(const std::string &checksum) { m_checksum = checksum; }
+ const std::string &checksum() const { return m_checksum; }
void setPlatform(Platform p) { m_platform = p; }
Platform platform() const { return m_platform; }
+
void setTypeOverride(Package::Type t) { m_type = t; }
Package::Type typeOverride() const { return m_type; }
- Package::Type type() const;
- const std::string &file() const;
- const std::string &url() const { return m_url; }
+
void setSections(int);
int sections() const { return m_sections; }
- Path targetPath() const;
-
private:
Platform m_platform;
Package::Type m_type;
std::string m_file;
std::string m_url;
+ std::string m_checksum;
int m_sections;
Path m_targetPath;
const Version *m_version;
diff --git a/test/hash.cpp b/test/hash.cpp
@@ -0,0 +1,53 @@
+#include "helper.hpp"
+
+#include <hash.hpp>
+
+static const char *M = "[hash]";
+
+TEST_CASE("sha256 hashes", M) {
+ Hash hash(Hash::SHA256);
+
+ SECTION("empty") {
+ REQUIRE(hash.digest() ==
+ "1220e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855");
+ }
+
+ SECTION("single chunk") {
+ hash.addData("hello world", 11);
+
+ REQUIRE(hash.digest() ==
+ "1220b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9");
+ }
+
+ SECTION("split chunks") {
+ hash.addData("foo bar", 7);
+ hash.addData(" bazqux", 4);
+
+ REQUIRE(hash.digest() ==
+ "1220dbd318c1c462aee872f41109a4dfd3048871a03dedd0fe0e757ced57dad6f2d7");
+ }
+}
+
+TEST_CASE("invalid algorithm", M) {
+ Hash hash(static_cast<Hash::Algorithm>(0));
+ hash.addData("foo bar", 7);
+ REQUIRE(hash.digest() == "");
+}
+
+TEST_CASE("get hash algorithm", M) {
+ Hash::Algorithm algo;
+
+ SECTION("empty string")
+ REQUIRE_FALSE(Hash::getAlgorithm("", &algo));
+
+ SECTION("only sha-256 ID")
+ REQUIRE_FALSE(Hash::getAlgorithm("12", &algo));
+
+ SECTION("unexpected size")
+ REQUIRE_FALSE(Hash::getAlgorithm("1202ab", &algo));
+
+ SECTION("seemingly good (but not actually) sha-256") {
+ REQUIRE(Hash::getAlgorithm("1202abcd", &algo));
+ REQUIRE(algo == Hash::SHA256);
+ }
+}
diff --git a/test/index_v1.cpp b/test/index_v1.cpp
@@ -302,3 +302,22 @@ TEST_CASE("read multiple sections", M) {
REQUIRE(ri->category(0)->package(0)->version(0)->source(0)->sections()
== (Source::MainSection | Source::MIDIEditorSection));
}
+
+TEST_CASE("read sha256 checksum", M) {
+ IndexPtr ri = Index::load({}, R"(
+<index version="1">
+ <category name="catname">
+ <reapack name="packname" type="script">
+ <version name="1.0" author="John Doe">
+ <source checksum="12206037d8b51b33934348a2b26e04f0eb7227315b87bb5688ceb6dccb0468b14cce">https://google.com/</source>
+ </version>
+ </reapack>
+ </category>
+</index>
+ )");
+
+ REQUIRE(ri->packages().size() == 1);
+
+ REQUIRE(ri->category(0)->package(0)->version(0)->source(0)->checksum()
+ == "12206037d8b51b33934348a2b26e04f0eb7227315b87bb5688ceb6dccb0468b14cce");
+}
diff --git a/test/source.cpp b/test/source.cpp
@@ -204,3 +204,13 @@ TEST_CASE("directory traversal in category name", M) {
REQUIRE(src.targetPath() == expected);
}
+
+TEST_CASE("source checksum", M) {
+ MAKE_VERSION;
+
+ Source src({}, "url", &ver);
+ REQUIRE(src.checksum().empty());
+
+ src.setChecksum("hello world");
+ REQUIRE(src.checksum() == "hello world");
+}
diff --git a/win32.tup b/win32.tup
@@ -15,7 +15,7 @@ CXXFLAGS += /DWDL_NO_DEFINE_MINMAX /DCURL_STATICLIB /DUNICODE /DNDEBUG
CXXFLAGS += /DREAPACK_FILE#\"$(REAPACK_FILE).dll\"
LD := $(WRAP) link
-LDFLAGS := /nologo User32.lib Shell32.lib Gdi32.lib Comdlg32.lib Comctl32.lib
+LDFLAGS := /nologo User32.lib Shell32.lib Gdi32.lib Comdlg32.lib Comctl32.lib Bcrypt.lib
LDFLAGS += $(VCPKG)/lib/libcurl.lib Ws2_32.lib Crypt32.lib Advapi32.lib
LDFLAGS += $(VCPKG)/lib/sqlite3.lib $(VCPKG)/lib/zlib.lib
LDFLAGS += $(TUP_VARIANTDIR)/src/resource.res