From a2c891d1432db52371aa633ea0e1bd3dbea7bb97 Mon Sep 17 00:00:00 2001 From: brianheineman Date: Fri, 28 Jun 2024 22:12:17 -0600 Subject: [PATCH] feat!: add configurable hashers --- postgresql_archive/src/hasher/mod.rs | 2 + postgresql_archive/src/hasher/registry.rs | 134 ++++++++++++++++++ postgresql_archive/src/hasher/sha2_256.rs | 35 +++++ postgresql_archive/src/lib.rs | 1 + postgresql_archive/src/matcher/mod.rs | 6 +- postgresql_archive/src/matcher/registry.rs | 16 ++- .../src/repository/github/repository.rs | 37 +++-- 7 files changed, 212 insertions(+), 19 deletions(-) create mode 100644 postgresql_archive/src/hasher/mod.rs create mode 100644 postgresql_archive/src/hasher/registry.rs create mode 100644 postgresql_archive/src/hasher/sha2_256.rs diff --git a/postgresql_archive/src/hasher/mod.rs b/postgresql_archive/src/hasher/mod.rs new file mode 100644 index 0000000..7c6a995 --- /dev/null +++ b/postgresql_archive/src/hasher/mod.rs @@ -0,0 +1,2 @@ +pub mod registry; +pub mod sha2_256; diff --git a/postgresql_archive/src/hasher/registry.rs b/postgresql_archive/src/hasher/registry.rs new file mode 100644 index 0000000..eefe4d9 --- /dev/null +++ b/postgresql_archive/src/hasher/registry.rs @@ -0,0 +1,134 @@ +use crate::hasher::sha2_256; +use crate::Result; +use lazy_static::lazy_static; +use std::collections::HashMap; +use std::sync::{Arc, Mutex, RwLock}; + +lazy_static! { + static ref REGISTRY: Arc> = + Arc::new(Mutex::new(HasherRegistry::default())); +} + +pub type HasherFn = fn(&Vec) -> Result; + +/// Singleton struct to store hashers +struct HasherRegistry { + hashers: HashMap>>, +} + +impl HasherRegistry { + /// Creates a new hasher registry. + /// + /// # Returns + /// * The hasher registry. + fn new() -> Self { + Self { + hashers: HashMap::new(), + } + } + + /// Registers a hasher for an extension. Newly registered hashers with the same extension will + /// override existing ones. + /// + /// # Arguments + /// * `extension` - The extension to register the hasher for. + /// * `hasher_fn` - The hasher function to register. + fn register>(&mut self, extension: S, hasher_fn: HasherFn) { + let extension = extension.as_ref().to_string(); + self.hashers + .insert(extension, Arc::new(RwLock::new(hasher_fn))); + } + + /// Get a hasher for the specified extension. + /// + /// # Arguments + /// * `extension` - The extension to locate a hasher for. + /// + /// # Returns + /// * The hasher for the extension or [None] if not found. + fn get>(&self, extension: S) -> Option { + let extension = extension.as_ref().to_string(); + if let Some(hasher) = self.hashers.get(&extension) { + return Some(*hasher.read().unwrap()); + } + + None + } +} + +impl Default for HasherRegistry { + fn default() -> Self { + let mut registry = Self::new(); + registry.register("sha256", sha2_256::hash); + registry + } +} + +/// Registers a hasher for an extension. Newly registered hashers with the same extension will +/// override existing ones. +/// +/// # Arguments +/// * `extension` - The extension to register the hasher for. +/// * `hasher_fn` - The hasher function to register. +/// +/// # Panics +/// * If the registry is poisoned. +#[allow(dead_code)] +pub fn register>(extension: S, hasher_fn: HasherFn) { + let mut registry = REGISTRY.lock().unwrap(); + registry.register(extension, hasher_fn); +} + +/// Get a hasher for the specified extension. +/// +/// # Arguments +/// * `extension` - The extension to locate a hasher for. +/// +/// # Returns +/// * The hasher for the extension or [None] if not found. +/// +/// # Panics +/// * If the registry is poisoned. +pub fn get>(extension: S) -> Option { + let registry = REGISTRY.lock().unwrap(); + registry.get(extension) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_register() -> Result<()> { + let extension = "sha256"; + let hashers = REGISTRY.lock().unwrap().hashers.len(); + assert!(!REGISTRY.lock().unwrap().hashers.is_empty()); + REGISTRY.lock().unwrap().hashers.remove(extension); + assert_ne!(hashers, REGISTRY.lock().unwrap().hashers.len()); + register(extension, sha2_256::hash); + assert_eq!(hashers, REGISTRY.lock().unwrap().hashers.len()); + + let hasher = get(extension).unwrap(); + let data = vec![1, 2, 3]; + let hash = hasher(&data)?; + + assert_eq!( + "039058c6f2c0cb492c533b0a4d14ef77cc0f78abccced5287d84a1a2011cfb81", + hash + ); + Ok(()) + } + + #[test] + fn test_sha2_256() -> Result<()> { + let hasher = get("sha256").unwrap(); + let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]; + let hash = hasher(&data)?; + + assert_eq!( + "9a89c68c4c5e28b8c4a5567673d462fff515db46116f9900624d09c474f593fb", + hash + ); + Ok(()) + } +} diff --git a/postgresql_archive/src/hasher/sha2_256.rs b/postgresql_archive/src/hasher/sha2_256.rs new file mode 100644 index 0000000..f44a08c --- /dev/null +++ b/postgresql_archive/src/hasher/sha2_256.rs @@ -0,0 +1,35 @@ +use crate::Result; +use sha2::{Digest, Sha256}; + +/// Hashes the data using SHA2-256. +/// +/// # Arguments +/// * `data` - The data to hash. +/// +/// # Returns +/// * The hash of the data. +/// +/// # Errors +/// * If the data cannot be hashed. +pub fn hash(data: &Vec) -> Result { + let mut hasher = Sha256::new(); + hasher.update(data); + let hash = hex::encode(hasher.finalize()); + Ok(hash) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_hash() -> Result<()> { + let data = vec![4, 2]; + let hash = hash(&data)?; + assert_eq!( + "b7586d310e5efb1b7d10a917ba5af403adbf54f4f77fe7fdcb4880a95dac7e7e", + hash + ); + Ok(()) + } +} diff --git a/postgresql_archive/src/lib.rs b/postgresql_archive/src/lib.rs index 839b221..207c8e9 100644 --- a/postgresql_archive/src/lib.rs +++ b/postgresql_archive/src/lib.rs @@ -113,6 +113,7 @@ mod archive; #[cfg(feature = "blocking")] pub mod blocking; mod error; +pub mod hasher; pub mod matcher; pub mod repository; mod version; diff --git a/postgresql_archive/src/matcher/mod.rs b/postgresql_archive/src/matcher/mod.rs index 1688bd3..8fb17cd 100644 --- a/postgresql_archive/src/matcher/mod.rs +++ b/postgresql_archive/src/matcher/mod.rs @@ -1,3 +1,3 @@ -mod default; -mod postgresql_binaries; -pub(crate) mod registry; +pub mod default; +pub mod postgresql_binaries; +pub mod registry; diff --git a/postgresql_archive/src/matcher/registry.rs b/postgresql_archive/src/matcher/registry.rs index 27319c0..bac99b2 100644 --- a/postgresql_archive/src/matcher/registry.rs +++ b/postgresql_archive/src/matcher/registry.rs @@ -10,7 +10,7 @@ lazy_static! { Arc::new(Mutex::new(MatchersRegistry::default())); } -type MatcherFn = fn(&str, &Version) -> Result; +pub type MatcherFn = fn(&str, &Version) -> Result; /// Singleton struct to store matchers struct MatchersRegistry { @@ -75,6 +75,9 @@ impl Default for MatchersRegistry { /// # Arguments /// * `url` - The URL to register the matcher for; [None] to register the default. /// * `matcher_fn` - The matcher function to register. +/// +/// # Panics +/// * If the registry is poisoned. #[allow(dead_code)] pub fn register>(url: Option, matcher_fn: MatcherFn) { let mut registry = REGISTRY.lock().unwrap(); @@ -89,6 +92,9 @@ pub fn register>(url: Option, matcher_fn: MatcherFn) { /// /// # Returns /// * The matcher for the URL, or the default matcher. +/// +/// # Panics +/// * If the registry is poisoned. pub fn get>(url: S) -> MatcherFn { let registry = REGISTRY.lock().unwrap(); registry.get(url) @@ -99,8 +105,8 @@ mod tests { use super::*; use std::env; - #[tokio::test] - async fn test_register() -> Result<()> { + #[test] + fn test_register() -> Result<()> { let matchers = REGISTRY.lock().unwrap().matchers.len(); assert!(!REGISTRY.lock().unwrap().matchers.is_empty()); REGISTRY.lock().unwrap().matchers.remove(&None::); @@ -117,8 +123,8 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_default_matcher() -> Result<()> { + #[test] + fn test_default_matcher() -> Result<()> { let matcher = get("https://foo.com"); let version = Version::new(16, 3, 0); let os = env::consts::OS; diff --git a/postgresql_archive/src/repository/github/repository.rs b/postgresql_archive/src/repository/github/repository.rs index b22ca75..973f0c1 100644 --- a/postgresql_archive/src/repository/github/repository.rs +++ b/postgresql_archive/src/repository/github/repository.rs @@ -1,10 +1,11 @@ +use crate::hasher::registry::HasherFn; use crate::repository::github::models::{Asset, Release}; use crate::repository::model::Repository; use crate::repository::Archive; use crate::Error::{ ArchiveHashMismatch, AssetHashNotFound, AssetNotFound, RepositoryFailure, VersionNotFound, }; -use crate::{matcher, Result}; +use crate::{hasher, matcher, Result}; use async_trait::async_trait; use bytes::Bytes; use http::{header, Extensions}; @@ -16,7 +17,6 @@ use reqwest_retry::policies::ExponentialBackoff; use reqwest_retry::RetryTransientMiddleware; use reqwest_tracing::TracingMiddleware; use semver::{Version, VersionReq}; -use sha2::{Digest, Sha256}; use std::env; use std::str::FromStr; use tracing::{debug, instrument, warn}; @@ -26,7 +26,7 @@ const GITHUB_API_VERSION_HEADER: &str = "X-GitHub-Api-Version"; const GITHUB_API_VERSION: &str = "2022-11-28"; lazy_static! { - static ref GITHUB_TOKEN: Option = match std::env::var("GITHUB_TOKEN") { + static ref GITHUB_TOKEN: Option = match env::var("GITHUB_TOKEN") { Ok(token) => { debug!("GITHUB_TOKEN environment variable found"); Some(token) @@ -200,7 +200,11 @@ impl GitHub { /// # Errors /// * If the asset is not found. #[instrument(level = "debug", skip(version, release))] - fn get_asset(&self, version: &Version, release: &Release) -> Result<(Asset, Option)> { + fn get_asset( + &self, + version: &Version, + release: &Release, + ) -> Result<(Asset, Option, Option)> { let matcher = matcher::registry::get(&self.url); let mut release_asset: Option = None; for asset in &release.assets { @@ -214,16 +218,26 @@ impl GitHub { return Err(AssetNotFound); }; + // Attempt to find the asset hash for the asset. let mut asset_hash: Option = None; - let hash_name = format!("{}.sha256", asset.name); + let mut asset_hasher_fn: Option = None; for release_asset in &release.assets { - if release_asset.name == hash_name { + let release_asset_name = release_asset.name.as_str(); + if !release_asset_name.starts_with(&asset.name) { + continue; + } + let extension = release_asset_name + .strip_prefix(format!("{}.", asset.name.as_str()).as_str()) + .unwrap_or_default(); + + if let Some(hasher_fn) = hasher::registry::get(extension) { asset_hash = Some(release_asset.clone()); + asset_hasher_fn = Some(hasher_fn); break; } } - Ok((asset, asset_hash)) + Ok((asset, asset_hash, asset_hasher_fn)) } } @@ -246,7 +260,7 @@ impl Repository for GitHub { async fn get_archive(&self, version_req: &VersionReq) -> Result { let release = self.get_release(version_req).await?; let version = Self::get_version_from_tag_name(release.tag_name.as_str())?; - let (asset, asset_hash) = self.get_asset(&version, &release)?; + let (asset, asset_hash, asset_hasher_fn) = self.get_asset(&version, &release)?; let name = asset.name.clone(); let client = reqwest_client(); @@ -280,9 +294,10 @@ impl Repository for GitHub { human_bytes(text.len() as f64) ); - let mut hasher = Sha256::new(); - hasher.update(&archive); - let archive_hash = hex::encode(hasher.finalize()); + let archive_hash = match asset_hasher_fn { + Some(hasher_fn) => hasher_fn(&bytes)?, + None => String::new(), + }; if archive_hash != hash { return Err(ArchiveHashMismatch { archive_hash, hash });