Skip to content

Commit

Permalink
add custom API and concurrency support for OCI downloads
Browse files Browse the repository at this point in the history
  • Loading branch information
QaidVoid committed Jan 22, 2025
1 parent 75f2551 commit 3cd28c3
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 12 deletions.
8 changes: 8 additions & 0 deletions src/bin/soar-dl/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,12 @@ pub struct Args {
/// Output file path
#[arg(required = false, short, long)]
pub output: Option<String>,

/// GHCR concurrency
#[arg(required = false, short, long)]
pub concurrency: Option<u64>,

/// GHCR API to use
#[arg(required = false, long)]
pub ghcr_api: Option<String>,
}
14 changes: 10 additions & 4 deletions src/bin/soar-dl/download_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use indicatif::HumanBytes;
use regex::Regex;
use serde::Deserialize;
use soar_dl::{
downloader::{DownloadOptions, DownloadState, Downloader},
downloader::{DownloadOptions, DownloadState, Downloader, OciDownloadOptions},
error::{DownloadError, PlatformError},
github::{Github, GithubAsset, GithubRelease},
gitlab::{Gitlab, GitlabAsset, GitlabRelease},
Expand Down Expand Up @@ -81,6 +81,8 @@ impl DownloadManager {
let assets = handler.filter_releases(&releases, &options).await?;

let selected_asset = self.select_asset(&assets)?;

println!("Downloading asset from {}", selected_asset.download_url());
handler.download(&selected_asset, options.clone()).await?;
Ok(())
}
Expand Down Expand Up @@ -130,10 +132,12 @@ impl DownloadManager {
for reference in &self.args.ghcr {
println!("Downloading using OCI reference: {}", reference);

let options = DownloadOptions {
let options = OciDownloadOptions {
url: reference.clone(),
concurrency: self.args.concurrency.clone(),
output_path: self.args.output.clone(),
progress_callback: Some(self.progress_callback.clone()),
api: self.args.ghcr_api.clone(),
};
let _ = downloader
.download_oci(options)
Expand Down Expand Up @@ -187,10 +191,12 @@ impl DownloadManager {
Ok(PlatformUrl::Oci(url)) => {
println!("Downloading using OCI reference: {}", url);

let options = DownloadOptions {
url: link.clone(),
let options = OciDownloadOptions {
url: url.clone(),
concurrency: self.args.concurrency.clone(),
output_path: self.args.output.clone(),
progress_callback: Some(self.progress_callback.clone()),
api: self.args.ghcr_api.clone(),
};
let _ = downloader
.download_oci(options)
Expand Down
19 changes: 16 additions & 3 deletions src/downloader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use reqwest::header::USER_AGENT;
use tokio::{
fs::{self, OpenOptions},
io::AsyncWriteExt,
sync::Semaphore,
task,
};
use url::Url;
Expand Down Expand Up @@ -39,6 +40,15 @@ pub struct Downloader {
client: reqwest::Client,
}

#[derive(Clone)]
pub struct OciDownloadOptions {
pub url: String,
pub concurrency: Option<u64>,
pub output_path: Option<String>,
pub progress_callback: Option<Arc<dyn Fn(DownloadState) + Send + Sync + 'static>>,
pub api: Option<String>,
}

impl Downloader {
pub async fn download(&self, options: DownloadOptions) -> Result<String, DownloadError> {
let url = Url::parse(&options.url).map_err(|err| DownloadError::InvalidUrl {
Expand Down Expand Up @@ -124,7 +134,7 @@ impl Downloader {
pub async fn download_blob(
&self,
client: OciClient,
options: DownloadOptions,
options: OciDownloadOptions,
) -> Result<(), DownloadError> {
let reference = client.reference.clone();
let digest = reference.tag;
Expand Down Expand Up @@ -170,10 +180,10 @@ impl Downloader {
Ok(())
}

pub async fn download_oci(&self, options: DownloadOptions) -> Result<(), DownloadError> {
pub async fn download_oci(&self, options: OciDownloadOptions) -> Result<(), DownloadError> {
let url = options.url.clone();
let reference: Reference = url.into();
let oci_client = OciClient::new(&reference);
let oci_client = OciClient::new(&reference, options.api.clone());

if reference.tag.starts_with("sha256:") {
return self.download_blob(oci_client, options).await;
Expand All @@ -188,6 +198,7 @@ impl Downloader {
callback(DownloadState::Preparing(total_bytes));
}

let semaphore = Arc::new(Semaphore::new(options.concurrency.unwrap_or(1) as usize));
let downloaded_bytes = Arc::new(Mutex::new(0u64));
let outdir = options.output_path;
let base_path = if let Some(dir) = outdir {
Expand All @@ -198,6 +209,7 @@ impl Downloader {
};

for layer in manifest.layers {
let permit = semaphore.clone().acquire_owned().await.unwrap();
let client_clone = oci_client.clone();
let cb_clone = options.progress_callback.clone();
let downloaded_bytes = downloaded_bytes.clone();
Expand All @@ -219,6 +231,7 @@ impl Downloader {

Ok::<(), DownloadError>(())
});
drop(permit);
tasks.push(task);
}

Expand Down
22 changes: 17 additions & 5 deletions src/oci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pub struct OciManifest {
pub struct OciClient {
client: reqwest::Client,
pub reference: Reference,
pub api: Option<String>,
}

#[derive(Clone)]
Expand Down Expand Up @@ -86,11 +87,12 @@ impl From<String> for Reference {
}

impl OciClient {
pub fn new(reference: &Reference) -> Self {
pub fn new(reference: &Reference, api: Option<String>) -> Self {
let client = reqwest::Client::new();
Self {
client,
reference: reference.clone(),
api,
}
}

Expand All @@ -103,8 +105,13 @@ impl OciClient {

pub async fn manifest(&self) -> Result<OciManifest, DownloadError> {
let manifest_url = format!(
"https://ghcr.io/v2/{}/manifests/{}",
self.reference.package, self.reference.tag
"{}/{}/manifests/{}",
self.api
.clone()
.unwrap_or("https://ghcr.io/v2".to_string())
.trim_end_matches('/'),
self.reference.package,
self.reference.tag
);
let resp = self
.client
Expand Down Expand Up @@ -139,8 +146,13 @@ impl OciClient {
F: Fn(u64, u64) + Send + 'static,
{
let blob_url = format!(
"https://ghcr.io/v2/{}/blobs/{}",
self.reference.package, layer.digest
"{}/{}/blobs/{}",
self.api
.clone()
.unwrap_or("https://ghcr.io/v2".to_string())
.trim_end_matches('/'),
self.reference.package,
layer.digest
);
let resp = self
.client
Expand Down

0 comments on commit 3cd28c3

Please sign in to comment.