mirror of
https://github.com/aaif-goose/goose.git
synced 2026-06-02 06:14:27 +02:00
Skip registering cancelled local downloads
Signed-off-by: jh-block <jhugo@block.xyz>
This commit is contained in:
@@ -523,6 +523,12 @@ fn mark_download_failed(model_id: &str, error: impl std::fmt::Display) {
|
||||
});
|
||||
}
|
||||
|
||||
fn model_download_completed(model_id: &str) -> bool {
|
||||
get_download_manager()
|
||||
.get_progress(&format!("{}-model", model_id))
|
||||
.is_some_and(|progress| progress.status == DownloadStatus::Completed)
|
||||
}
|
||||
|
||||
fn register_pending_download_model(
|
||||
model_id: &str,
|
||||
req: &DownloadModelRequest,
|
||||
@@ -634,6 +640,9 @@ pub async fn download_hf_model(
|
||||
};
|
||||
match resolved {
|
||||
Ok(resolved) => {
|
||||
if !model_download_completed(&model_id_for_task) {
|
||||
return;
|
||||
}
|
||||
if let Err(error) = register_resolved_model(resolved, &spec) {
|
||||
mark_download_failed(&model_id_for_task, error);
|
||||
}
|
||||
@@ -822,3 +831,44 @@ pub fn routes(state: Arc<AppState>) -> Router {
|
||||
)
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn progress_for(model_id: &str, status: DownloadStatus) -> DownloadProgress {
|
||||
DownloadProgress {
|
||||
model_id: format!("{}-model", model_id),
|
||||
status,
|
||||
bytes_downloaded: 0,
|
||||
total_bytes: 0,
|
||||
progress_percent: 0.0,
|
||||
speed_bps: None,
|
||||
eta_seconds: None,
|
||||
error: None,
|
||||
task_exited: true,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn model_download_completed_requires_completed_progress() {
|
||||
let model_id = "test-completed-registration-gate";
|
||||
let manager = get_download_manager();
|
||||
manager.set_progress(progress_for(model_id, DownloadStatus::Completed));
|
||||
|
||||
assert!(model_download_completed(model_id));
|
||||
|
||||
manager.clear_completed(&format!("{}-model", model_id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn model_download_completed_rejects_cancelled_progress() {
|
||||
let model_id = "test-cancelled-registration-gate";
|
||||
let manager = get_download_manager();
|
||||
manager.set_progress(progress_for(model_id, DownloadStatus::Cancelled));
|
||||
|
||||
assert!(!model_download_completed(model_id));
|
||||
|
||||
manager.clear_completed(&format!("{}-model", model_id));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1073,6 +1073,38 @@ mod tests {
|
||||
assert_eq!(model_stem_from_repo("someone/SomeModel"), "somemodel");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hf_download_progress_init_preserves_cancelled_reservation() {
|
||||
let model_id = "test-cancelled-hf-progress-init";
|
||||
let download_id = format!("{}-model", model_id);
|
||||
let manager = crate::download_manager::get_download_manager();
|
||||
manager.set_progress(crate::download_manager::DownloadProgress {
|
||||
model_id: download_id.clone(),
|
||||
status: crate::download_manager::DownloadStatus::Cancelled,
|
||||
bytes_downloaded: 0,
|
||||
total_bytes: 0,
|
||||
progress_percent: 0.0,
|
||||
speed_bps: None,
|
||||
eta_seconds: None,
|
||||
error: None,
|
||||
task_exited: false,
|
||||
});
|
||||
|
||||
HfDownloadProgress::new(model_id.to_string(), 42).init();
|
||||
|
||||
let progress = manager.get_progress(&download_id).expect("progress");
|
||||
assert_eq!(
|
||||
progress.status,
|
||||
crate::download_manager::DownloadStatus::Cancelled
|
||||
);
|
||||
assert!(!progress.task_exited);
|
||||
|
||||
manager.update_progress(&download_id, |progress| {
|
||||
progress.task_exited = true;
|
||||
});
|
||||
manager.clear_completed(&download_id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_model_file() {
|
||||
let stem = "gemma-3-27b-it";
|
||||
@@ -1839,9 +1871,24 @@ impl HfDownloadProgress {
|
||||
}
|
||||
|
||||
fn init(&self) {
|
||||
crate::download_manager::get_download_manager().set_progress(
|
||||
crate::download_manager::DownloadProgress {
|
||||
model_id: format!("{}-model", self.model_id),
|
||||
let manager = crate::download_manager::get_download_manager();
|
||||
let download_id = format!("{}-model", self.model_id);
|
||||
if manager.get_progress(&download_id).is_some() {
|
||||
manager.update_progress(&download_id, |progress| {
|
||||
if progress.status != crate::download_manager::DownloadStatus::Cancelled {
|
||||
progress.status = crate::download_manager::DownloadStatus::Downloading;
|
||||
progress.bytes_downloaded = 0;
|
||||
progress.total_bytes = self.total_bytes;
|
||||
progress.progress_percent = 0.0;
|
||||
progress.speed_bps = None;
|
||||
progress.eta_seconds = None;
|
||||
progress.error = None;
|
||||
progress.task_exited = false;
|
||||
}
|
||||
});
|
||||
} else {
|
||||
manager.set_progress(crate::download_manager::DownloadProgress {
|
||||
model_id: download_id,
|
||||
status: crate::download_manager::DownloadStatus::Downloading,
|
||||
bytes_downloaded: 0,
|
||||
total_bytes: self.total_bytes,
|
||||
@@ -1850,8 +1897,8 @@ impl HfDownloadProgress {
|
||||
eta_seconds: None,
|
||||
error: None,
|
||||
task_exited: false,
|
||||
},
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn is_cancelled(&self) -> bool {
|
||||
|
||||
Reference in New Issue
Block a user