Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 54c4118

Browse filesBrowse files
Merge pull request theseus-rs#86 from theseus-rs/add-configurable-hashers
feat!: add configurable hashers
2 parents 9630a68 + a2c891d commit 54c4118
Copy full SHA for 54c4118

File tree

Expand file treeCollapse file tree

7 files changed

+212
-19
lines changed
Filter options
Expand file treeCollapse file tree

7 files changed

+212
-19
lines changed

‎postgresql_archive/src/hasher/mod.rs

Copy file name to clipboard
+2Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pub mod registry;
2+
pub mod sha2_256;
+134Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
use crate::hasher::sha2_256;
2+
use crate::Result;
3+
use lazy_static::lazy_static;
4+
use std::collections::HashMap;
5+
use std::sync::{Arc, Mutex, RwLock};
6+
7+
lazy_static! {
8+
static ref REGISTRY: Arc<Mutex<HasherRegistry>> =
9+
Arc::new(Mutex::new(HasherRegistry::default()));
10+
}
11+
12+
pub type HasherFn = fn(&Vec<u8>) -> Result<String>;
13+
14+
/// Singleton struct to store hashers
15+
struct HasherRegistry {
16+
hashers: HashMap<String, Arc<RwLock<HasherFn>>>,
17+
}
18+
19+
impl HasherRegistry {
20+
/// Creates a new hasher registry.
21+
///
22+
/// # Returns
23+
/// * The hasher registry.
24+
fn new() -> Self {
25+
Self {
26+
hashers: HashMap::new(),
27+
}
28+
}
29+
30+
/// Registers a hasher for an extension. Newly registered hashers with the same extension will
31+
/// override existing ones.
32+
///
33+
/// # Arguments
34+
/// * `extension` - The extension to register the hasher for.
35+
/// * `hasher_fn` - The hasher function to register.
36+
fn register<S: AsRef<str>>(&mut self, extension: S, hasher_fn: HasherFn) {
37+
let extension = extension.as_ref().to_string();
38+
self.hashers
39+
.insert(extension, Arc::new(RwLock::new(hasher_fn)));
40+
}
41+
42+
/// Get a hasher for the specified extension.
43+
///
44+
/// # Arguments
45+
/// * `extension` - The extension to locate a hasher for.
46+
///
47+
/// # Returns
48+
/// * The hasher for the extension or [None] if not found.
49+
fn get<S: AsRef<str>>(&self, extension: S) -> Option<HasherFn> {
50+
let extension = extension.as_ref().to_string();
51+
if let Some(hasher) = self.hashers.get(&extension) {
52+
return Some(*hasher.read().unwrap());
53+
}
54+
55+
None
56+
}
57+
}
58+
59+
impl Default for HasherRegistry {
60+
fn default() -> Self {
61+
let mut registry = Self::new();
62+
registry.register("sha256", sha2_256::hash);
63+
registry
64+
}
65+
}
66+
67+
/// Registers a hasher for an extension. Newly registered hashers with the same extension will
68+
/// override existing ones.
69+
///
70+
/// # Arguments
71+
/// * `extension` - The extension to register the hasher for.
72+
/// * `hasher_fn` - The hasher function to register.
73+
///
74+
/// # Panics
75+
/// * If the registry is poisoned.
76+
#[allow(dead_code)]
77+
pub fn register<S: AsRef<str>>(extension: S, hasher_fn: HasherFn) {
78+
let mut registry = REGISTRY.lock().unwrap();
79+
registry.register(extension, hasher_fn);
80+
}
81+
82+
/// Get a hasher for the specified extension.
83+
///
84+
/// # Arguments
85+
/// * `extension` - The extension to locate a hasher for.
86+
///
87+
/// # Returns
88+
/// * The hasher for the extension or [None] if not found.
89+
///
90+
/// # Panics
91+
/// * If the registry is poisoned.
92+
pub fn get<S: AsRef<str>>(extension: S) -> Option<HasherFn> {
93+
let registry = REGISTRY.lock().unwrap();
94+
registry.get(extension)
95+
}
96+
97+
#[cfg(test)]
98+
mod tests {
99+
use super::*;
100+
101+
#[test]
102+
fn test_register() -> Result<()> {
103+
let extension = "sha256";
104+
let hashers = REGISTRY.lock().unwrap().hashers.len();
105+
assert!(!REGISTRY.lock().unwrap().hashers.is_empty());
106+
REGISTRY.lock().unwrap().hashers.remove(extension);
107+
assert_ne!(hashers, REGISTRY.lock().unwrap().hashers.len());
108+
register(extension, sha2_256::hash);
109+
assert_eq!(hashers, REGISTRY.lock().unwrap().hashers.len());
110+
111+
let hasher = get(extension).unwrap();
112+
let data = vec![1, 2, 3];
113+
let hash = hasher(&data)?;
114+
115+
assert_eq!(
116+
"039058c6f2c0cb492c533b0a4d14ef77cc0f78abccced5287d84a1a2011cfb81",
117+
hash
118+
);
119+
Ok(())
120+
}
121+
122+
#[test]
123+
fn test_sha2_256() -> Result<()> {
124+
let hasher = get("sha256").unwrap();
125+
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0];
126+
let hash = hasher(&data)?;
127+
128+
assert_eq!(
129+
"9a89c68c4c5e28b8c4a5567673d462fff515db46116f9900624d09c474f593fb",
130+
hash
131+
);
132+
Ok(())
133+
}
134+
}
+35Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
use crate::Result;
2+
use sha2::{Digest, Sha256};
3+
4+
/// Hashes the data using SHA2-256.
5+
///
6+
/// # Arguments
7+
/// * `data` - The data to hash.
8+
///
9+
/// # Returns
10+
/// * The hash of the data.
11+
///
12+
/// # Errors
13+
/// * If the data cannot be hashed.
14+
pub fn hash(data: &Vec<u8>) -> Result<String> {
15+
let mut hasher = Sha256::new();
16+
hasher.update(data);
17+
let hash = hex::encode(hasher.finalize());
18+
Ok(hash)
19+
}
20+
21+
#[cfg(test)]
22+
mod tests {
23+
use super::*;
24+
25+
#[test]
26+
fn test_hash() -> Result<()> {
27+
let data = vec![4, 2];
28+
let hash = hash(&data)?;
29+
assert_eq!(
30+
"b7586d310e5efb1b7d10a917ba5af403adbf54f4f77fe7fdcb4880a95dac7e7e",
31+
hash
32+
);
33+
Ok(())
34+
}
35+
}

‎postgresql_archive/src/lib.rs

Copy file name to clipboardExpand all lines: postgresql_archive/src/lib.rs
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ mod archive;
113113
#[cfg(feature = "blocking")]
114114
pub mod blocking;
115115
mod error;
116+
pub mod hasher;
116117
pub mod matcher;
117118
pub mod repository;
118119
mod version;

‎postgresql_archive/src/matcher/mod.rs

Copy file name to clipboard
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
mod default;
2-
mod postgresql_binaries;
3-
pub(crate) mod registry;
1+
pub mod default;
2+
pub mod postgresql_binaries;
3+
pub mod registry;

‎postgresql_archive/src/matcher/registry.rs

Copy file name to clipboardExpand all lines: postgresql_archive/src/matcher/registry.rs
+11-5Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ lazy_static! {
1010
Arc::new(Mutex::new(MatchersRegistry::default()));
1111
}
1212

13-
type MatcherFn = fn(&str, &Version) -> Result<bool>;
13+
pub type MatcherFn = fn(&str, &Version) -> Result<bool>;
1414

1515
/// Singleton struct to store matchers
1616
struct MatchersRegistry {
@@ -75,6 +75,9 @@ impl Default for MatchersRegistry {
7575
/// # Arguments
7676
/// * `url` - The URL to register the matcher for; [None] to register the default.
7777
/// * `matcher_fn` - The matcher function to register.
78+
///
79+
/// # Panics
80+
/// * If the registry is poisoned.
7881
#[allow(dead_code)]
7982
pub fn register<S: AsRef<str>>(url: Option<S>, matcher_fn: MatcherFn) {
8083
let mut registry = REGISTRY.lock().unwrap();
@@ -89,6 +92,9 @@ pub fn register<S: AsRef<str>>(url: Option<S>, matcher_fn: MatcherFn) {
8992
///
9093
/// # Returns
9194
/// * The matcher for the URL, or the default matcher.
95+
///
96+
/// # Panics
97+
/// * If the registry is poisoned.
9298
pub fn get<S: AsRef<str>>(url: S) -> MatcherFn {
9399
let registry = REGISTRY.lock().unwrap();
94100
registry.get(url)
@@ -99,8 +105,8 @@ mod tests {
99105
use super::*;
100106
use std::env;
101107

102-
#[tokio::test]
103-
async fn test_register() -> Result<()> {
108+
#[test]
109+
fn test_register() -> Result<()> {
104110
let matchers = REGISTRY.lock().unwrap().matchers.len();
105111
assert!(!REGISTRY.lock().unwrap().matchers.is_empty());
106112
REGISTRY.lock().unwrap().matchers.remove(&None::<String>);
@@ -117,8 +123,8 @@ mod tests {
117123
Ok(())
118124
}
119125

120-
#[tokio::test]
121-
async fn test_default_matcher() -> Result<()> {
126+
#[test]
127+
fn test_default_matcher() -> Result<()> {
122128
let matcher = get("https://foo.com");
123129
let version = Version::new(16, 3, 0);
124130
let os = env::consts::OS;

‎postgresql_archive/src/repository/github/repository.rs

Copy file name to clipboardExpand all lines: postgresql_archive/src/repository/github/repository.rs
+26-11Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
use crate::hasher::registry::HasherFn;
12
use crate::repository::github::models::{Asset, Release};
23
use crate::repository::model::Repository;
34
use crate::repository::Archive;
45
use crate::Error::{
56
ArchiveHashMismatch, AssetHashNotFound, AssetNotFound, RepositoryFailure, VersionNotFound,
67
};
7-
use crate::{matcher, Result};
8+
use crate::{hasher, matcher, Result};
89
use async_trait::async_trait;
910
use bytes::Bytes;
1011
use http::{header, Extensions};
@@ -16,7 +17,6 @@ use reqwest_retry::policies::ExponentialBackoff;
1617
use reqwest_retry::RetryTransientMiddleware;
1718
use reqwest_tracing::TracingMiddleware;
1819
use semver::{Version, VersionReq};
19-
use sha2::{Digest, Sha256};
2020
use std::env;
2121
use std::str::FromStr;
2222
use tracing::{debug, instrument, warn};
@@ -26,7 +26,7 @@ const GITHUB_API_VERSION_HEADER: &str = "X-GitHub-Api-Version";
2626
const GITHUB_API_VERSION: &str = "2022-11-28";
2727

2828
lazy_static! {
29-
static ref GITHUB_TOKEN: Option<String> = match std::env::var("GITHUB_TOKEN") {
29+
static ref GITHUB_TOKEN: Option<String> = match env::var("GITHUB_TOKEN") {
3030
Ok(token) => {
3131
debug!("GITHUB_TOKEN environment variable found");
3232
Some(token)
@@ -200,7 +200,11 @@ impl GitHub {
200200
/// # Errors
201201
/// * If the asset is not found.
202202
#[instrument(level = "debug", skip(version, release))]
203-
fn get_asset(&self, version: &Version, release: &Release) -> Result<(Asset, Option<Asset>)> {
203+
fn get_asset(
204+
&self,
205+
version: &Version,
206+
release: &Release,
207+
) -> Result<(Asset, Option<Asset>, Option<HasherFn>)> {
204208
let matcher = matcher::registry::get(&self.url);
205209
let mut release_asset: Option<Asset> = None;
206210
for asset in &release.assets {
@@ -214,16 +218,26 @@ impl GitHub {
214218
return Err(AssetNotFound);
215219
};
216220

221+
// Attempt to find the asset hash for the asset.
217222
let mut asset_hash: Option<Asset> = None;
218-
let hash_name = format!("{}.sha256", asset.name);
223+
let mut asset_hasher_fn: Option<HasherFn> = None;
219224
for release_asset in &release.assets {
220-
if release_asset.name == hash_name {
225+
let release_asset_name = release_asset.name.as_str();
226+
if !release_asset_name.starts_with(&asset.name) {
227+
continue;
228+
}
229+
let extension = release_asset_name
230+
.strip_prefix(format!("{}.", asset.name.as_str()).as_str())
231+
.unwrap_or_default();
232+
233+
if let Some(hasher_fn) = hasher::registry::get(extension) {
221234
asset_hash = Some(release_asset.clone());
235+
asset_hasher_fn = Some(hasher_fn);
222236
break;
223237
}
224238
}
225239

226-
Ok((asset, asset_hash))
240+
Ok((asset, asset_hash, asset_hasher_fn))
227241
}
228242
}
229243

@@ -246,7 +260,7 @@ impl Repository for GitHub {
246260
async fn get_archive(&self, version_req: &VersionReq) -> Result<Archive> {
247261
let release = self.get_release(version_req).await?;
248262
let version = Self::get_version_from_tag_name(release.tag_name.as_str())?;
249-
let (asset, asset_hash) = self.get_asset(&version, &release)?;
263+
let (asset, asset_hash, asset_hasher_fn) = self.get_asset(&version, &release)?;
250264
let name = asset.name.clone();
251265

252266
let client = reqwest_client();
@@ -280,9 +294,10 @@ impl Repository for GitHub {
280294
human_bytes(text.len() as f64)
281295
);
282296

283-
let mut hasher = Sha256::new();
284-
hasher.update(&archive);
285-
let archive_hash = hex::encode(hasher.finalize());
297+
let archive_hash = match asset_hasher_fn {
298+
Some(hasher_fn) => hasher_fn(&bytes)?,
299+
None => String::new(),
300+
};
286301

287302
if archive_hash != hash {
288303
return Err(ArchiveHashMismatch { archive_hash, hash });

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.