mistralrs_server_core/
background_tasks.rs

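//! ## Background task management for the Responses API.
//!
//! This module tracks the lifecycle (queued, in progress, completed, failed, cancelled) of
//! responses that are produced in the background when `background: true` is set. The work
//! itself is done elsewhere; that code reports progress through [`BackgroundTaskManager`].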
4
5use std::{
6    collections::HashMap,
7    sync::{Arc, RwLock},
8    time::{SystemTime, UNIX_EPOCH},
9};
10
11use uuid::Uuid;
12
13use crate::responses_types::{ResponseError, ResponseResource, ResponseStatus};
14
15/// State of a background task
16#[derive(Debug, Clone)]
17#[allow(clippy::large_enum_variant)]
18pub enum BackgroundTaskState {
19    /// Task is queued
20    Queued,
21    /// Task is in progress
22    InProgress,
23    /// Task completed successfully
24    Completed(ResponseResource),
25    /// Task failed
26    Failed(ResponseError),
27    /// Task was cancelled
28    Cancelled,
29}
30
31/// A background task for processing responses
32#[derive(Debug)]
33pub struct BackgroundTask {
34    /// Task ID (same as response ID)
35    pub id: String,
36    /// Current state
37    pub state: BackgroundTaskState,
38    /// Created timestamp
39    pub created_at: u64,
40    /// Model name
41    pub model: String,
42    /// Cancellation flag
43    pub cancel_requested: bool,
44}
45
46impl BackgroundTask {
47    /// Create a new background task
48    pub fn new(id: String, model: String) -> Self {
49        let created_at = SystemTime::now()
50            .duration_since(UNIX_EPOCH)
51            .unwrap()
52            .as_secs();
53
54        Self {
55            id,
56            state: BackgroundTaskState::Queued,
57            created_at,
58            model,
59            cancel_requested: false,
60        }
61    }
62
63    /// Convert the current task state to a ResponseResource
64    pub fn to_response_resource(&self) -> ResponseResource {
65        let mut resource =
66            ResponseResource::new(self.id.clone(), self.model.clone(), self.created_at);
67
68        match &self.state {
69            BackgroundTaskState::Queued => {
70                resource.status = ResponseStatus::Queued;
71            }
72            BackgroundTaskState::InProgress => {
73                resource.status = ResponseStatus::InProgress;
74            }
75            BackgroundTaskState::Completed(resp) => {
76                return resp.clone();
77            }
78            BackgroundTaskState::Failed(error) => {
79                resource.status = ResponseStatus::Failed;
80                resource.error = Some(error.clone());
81            }
82            BackgroundTaskState::Cancelled => {
83                resource.status = ResponseStatus::Cancelled;
84            }
85        }
86
87        resource
88    }
89}
90
91/// Manager for background tasks
92#[derive(Debug, Default)]
93pub struct BackgroundTaskManager {
94    /// Map of task ID to task
95    tasks: Arc<RwLock<HashMap<String, BackgroundTask>>>,
96}
97
98impl BackgroundTaskManager {
99    /// Create a new background task manager
100    pub fn new() -> Self {
101        Self {
102            tasks: Arc::new(RwLock::new(HashMap::new())),
103        }
104    }
105
106    /// Create a new background task and return its ID
107    pub fn create_task(&self, model: String) -> String {
108        let id = format!("resp_{}", Uuid::new_v4());
109        let task = BackgroundTask::new(id.clone(), model);
110
111        let mut tasks = self.tasks.write().unwrap();
112        tasks.insert(id.clone(), task);
113
114        id
115    }
116
117    /// Get the current state of a task
118    pub fn get_task(&self, id: &str) -> Option<BackgroundTask> {
119        let tasks = self.tasks.read().unwrap();
120        tasks.get(id).cloned()
121    }
122
123    /// Get the response resource for a task
124    pub fn get_response(&self, id: &str) -> Option<ResponseResource> {
125        let tasks = self.tasks.read().unwrap();
126        tasks.get(id).map(|t| t.to_response_resource())
127    }
128
129    /// Update task to in_progress state
130    pub fn mark_in_progress(&self, id: &str) -> bool {
131        let mut tasks = self.tasks.write().unwrap();
132        if let Some(task) = tasks.get_mut(id) {
133            task.state = BackgroundTaskState::InProgress;
134            true
135        } else {
136            false
137        }
138    }
139
140    /// Update task to completed state
141    pub fn mark_completed(&self, id: &str, response: ResponseResource) -> bool {
142        let mut tasks = self.tasks.write().unwrap();
143        if let Some(task) = tasks.get_mut(id) {
144            task.state = BackgroundTaskState::Completed(response);
145            true
146        } else {
147            false
148        }
149    }
150
151    /// Update task to failed state
152    pub fn mark_failed(&self, id: &str, error: ResponseError) -> bool {
153        let mut tasks = self.tasks.write().unwrap();
154        if let Some(task) = tasks.get_mut(id) {
155            task.state = BackgroundTaskState::Failed(error);
156            true
157        } else {
158            false
159        }
160    }
161
162    /// Request cancellation of a task
163    pub fn request_cancel(&self, id: &str) -> bool {
164        let mut tasks = self.tasks.write().unwrap();
165        if let Some(task) = tasks.get_mut(id) {
166            if matches!(
167                task.state,
168                BackgroundTaskState::Queued | BackgroundTaskState::InProgress
169            ) {
170                task.cancel_requested = true;
171                return true;
172            }
173        }
174        false
175    }
176
177    /// Check if cancellation was requested for a task
178    pub fn is_cancel_requested(&self, id: &str) -> bool {
179        let tasks = self.tasks.read().unwrap();
180        tasks.get(id).map(|t| t.cancel_requested).unwrap_or(false)
181    }
182
183    /// Mark task as cancelled
184    pub fn mark_cancelled(&self, id: &str) -> bool {
185        let mut tasks = self.tasks.write().unwrap();
186        if let Some(task) = tasks.get_mut(id) {
187            task.state = BackgroundTaskState::Cancelled;
188            true
189        } else {
190            false
191        }
192    }
193
194    /// Delete a task
195    pub fn delete_task(&self, id: &str) -> bool {
196        let mut tasks = self.tasks.write().unwrap();
197        tasks.remove(id).is_some()
198    }
199
200    /// List all task IDs
201    pub fn list_tasks(&self) -> Vec<String> {
202        let tasks = self.tasks.read().unwrap();
203        tasks.keys().cloned().collect()
204    }
205
206    /// Clean up old completed/failed tasks older than the given duration (in seconds)
207    pub fn cleanup_old_tasks(&self, max_age_secs: u64) {
208        let now = SystemTime::now()
209            .duration_since(UNIX_EPOCH)
210            .unwrap()
211            .as_secs();
212
213        let mut tasks = self.tasks.write().unwrap();
214        tasks.retain(|_, task| {
215            // Keep queued and in-progress tasks
216            if matches!(
217                task.state,
218                BackgroundTaskState::Queued | BackgroundTaskState::InProgress
219            ) {
220                return true;
221            }
222
223            // Remove old completed/failed/cancelled tasks
224            now - task.created_at < max_age_secs
225        });
226    }
227}
228
229impl Clone for BackgroundTask {
230    fn clone(&self) -> Self {
231        Self {
232            id: self.id.clone(),
233            state: self.state.clone(),
234            created_at: self.created_at,
235            model: self.model.clone(),
236            cancel_requested: self.cancel_requested,
237        }
238    }
239}
240
241/// Global background task manager
242static BACKGROUND_TASK_MANAGER: std::sync::LazyLock<BackgroundTaskManager> =
243    std::sync::LazyLock::new(BackgroundTaskManager::new);
244
245/// Get the global background task manager
246pub fn get_background_task_manager() -> &'static BackgroundTaskManager {
247    &BACKGROUND_TASK_MANAGER
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_and_get_task() {
        let manager = BackgroundTaskManager::new();
        let id = manager.create_task("test-model".to_string());

        let task = manager.get_task(&id).unwrap();
        assert_eq!(task.id, id);
        assert!(matches!(task.state, BackgroundTaskState::Queued));
    }

    #[test]
    fn test_task_state_transitions() {
        let manager = BackgroundTaskManager::new();
        let id = manager.create_task("test-model".to_string());

        // Move to in_progress
        assert!(manager.mark_in_progress(&id));
        let task = manager.get_task(&id).unwrap();
        assert!(matches!(task.state, BackgroundTaskState::InProgress));

        // Mark completed
        let response = ResponseResource::new(id.clone(), "test-model".to_string(), 0);
        assert!(manager.mark_completed(&id, response));
        let task = manager.get_task(&id).unwrap();
        assert!(matches!(task.state, BackgroundTaskState::Completed(_)));
    }

    #[test]
    fn test_cancel_task() {
        let manager = BackgroundTaskManager::new();
        let id = manager.create_task("test-model".to_string());

        // Request cancellation
        assert!(manager.request_cancel(&id));
        assert!(manager.is_cancel_requested(&id));

        // Mark as cancelled
        assert!(manager.mark_cancelled(&id));
        let task = manager.get_task(&id).unwrap();
        assert!(matches!(task.state, BackgroundTaskState::Cancelled));
    }

    /// Cancellation can only be requested for tasks that exist and are still active.
    #[test]
    fn test_request_cancel_rejects_finished_and_unknown_tasks() {
        let manager = BackgroundTaskManager::new();
        assert!(!manager.request_cancel("resp_missing"));
        assert!(!manager.is_cancel_requested("resp_missing"));

        let id = manager.create_task("test-model".to_string());
        let response = ResponseResource::new(id.clone(), "test-model".to_string(), 0);
        assert!(manager.mark_completed(&id, response));
        assert!(!manager.request_cancel(&id));
        assert!(!manager.is_cancel_requested(&id));
    }

    /// Every lookup or mutation of an unknown ID reports "not found" without side effects.
    #[test]
    fn test_unknown_task_operations() {
        let manager = BackgroundTaskManager::new();
        assert!(manager.get_task("resp_missing").is_none());
        assert!(manager.get_response("resp_missing").is_none());
        assert!(!manager.mark_in_progress("resp_missing"));
        assert!(!manager.mark_cancelled("resp_missing"));
        assert!(!manager.delete_task("resp_missing"));
        assert!(manager.list_tasks().is_empty());
    }

    /// `get_response` maps each non-terminal and cancelled state to the matching status.
    #[test]
    fn test_get_response_reflects_state() {
        let manager = BackgroundTaskManager::new();
        let id = manager.create_task("test-model".to_string());

        let response = manager.get_response(&id).unwrap();
        assert!(matches!(response.status, ResponseStatus::Queued));

        assert!(manager.mark_in_progress(&id));
        let response = manager.get_response(&id).unwrap();
        assert!(matches!(response.status, ResponseStatus::InProgress));

        assert!(manager.mark_cancelled(&id));
        let response = manager.get_response(&id).unwrap();
        assert!(matches!(response.status, ResponseStatus::Cancelled));
    }

    #[test]
    fn test_delete_and_list_tasks() {
        let manager = BackgroundTaskManager::new();
        let first = manager.create_task("test-model".to_string());
        let second = manager.create_task("test-model".to_string());

        let mut listed = manager.list_tasks();
        listed.sort();
        let mut expected = vec![first.clone(), second.clone()];
        expected.sort();
        assert_eq!(listed, expected);

        assert!(manager.delete_task(&first));
        assert!(!manager.delete_task(&first));
        assert_eq!(manager.list_tasks(), vec![second]);
    }

    /// Cleanup never removes active tasks and only removes finished tasks past the age limit.
    #[test]
    fn test_cleanup_old_tasks_keeps_active_tasks() {
        let manager = BackgroundTaskManager::new();
        let queued = manager.create_task("test-model".to_string());
        let running = manager.create_task("test-model".to_string());
        assert!(manager.mark_in_progress(&running));
        let finished = manager.create_task("test-model".to_string());
        assert!(manager.mark_cancelled(&finished));

        // Generous limit: nothing has been around long enough to be removed.
        manager.cleanup_old_tasks(3600);
        assert_eq!(manager.list_tasks().len(), 3);

        // Zero limit: every finished task is expired, active ones survive.
        manager.cleanup_old_tasks(0);
        assert!(manager.get_task(&queued).is_some());
        assert!(manager.get_task(&running).is_some());
        assert!(manager.get_task(&finished).is_none());
    }
}