mistralrs_server_core/
background_tasks.rs1use std::{
6 collections::HashMap,
7 sync::{Arc, RwLock},
8 time::{SystemTime, UNIX_EPOCH},
9};
10
11use uuid::Uuid;
12
13use crate::responses_types::{ResponseError, ResponseResource, ResponseStatus};
14
15#[derive(Debug, Clone)]
17#[allow(clippy::large_enum_variant)]
18pub enum BackgroundTaskState {
19 Queued,
21 InProgress,
23 Completed(ResponseResource),
25 Failed(ResponseError),
27 Cancelled,
29}
30
31#[derive(Debug)]
33pub struct BackgroundTask {
34 pub id: String,
36 pub state: BackgroundTaskState,
38 pub created_at: u64,
40 pub model: String,
42 pub cancel_requested: bool,
44}
45
46impl BackgroundTask {
47 pub fn new(id: String, model: String) -> Self {
49 let created_at = SystemTime::now()
50 .duration_since(UNIX_EPOCH)
51 .unwrap()
52 .as_secs();
53
54 Self {
55 id,
56 state: BackgroundTaskState::Queued,
57 created_at,
58 model,
59 cancel_requested: false,
60 }
61 }
62
63 pub fn to_response_resource(&self) -> ResponseResource {
65 let mut resource =
66 ResponseResource::new(self.id.clone(), self.model.clone(), self.created_at);
67
68 match &self.state {
69 BackgroundTaskState::Queued => {
70 resource.status = ResponseStatus::Queued;
71 }
72 BackgroundTaskState::InProgress => {
73 resource.status = ResponseStatus::InProgress;
74 }
75 BackgroundTaskState::Completed(resp) => {
76 return resp.clone();
77 }
78 BackgroundTaskState::Failed(error) => {
79 resource.status = ResponseStatus::Failed;
80 resource.error = Some(error.clone());
81 }
82 BackgroundTaskState::Cancelled => {
83 resource.status = ResponseStatus::Cancelled;
84 }
85 }
86
87 resource
88 }
89}
90
91#[derive(Debug, Default)]
93pub struct BackgroundTaskManager {
94 tasks: Arc<RwLock<HashMap<String, BackgroundTask>>>,
96}
97
98impl BackgroundTaskManager {
99 pub fn new() -> Self {
101 Self {
102 tasks: Arc::new(RwLock::new(HashMap::new())),
103 }
104 }
105
106 pub fn create_task(&self, model: String) -> String {
108 let id = format!("resp_{}", Uuid::new_v4());
109 let task = BackgroundTask::new(id.clone(), model);
110
111 let mut tasks = self.tasks.write().unwrap();
112 tasks.insert(id.clone(), task);
113
114 id
115 }
116
117 pub fn get_task(&self, id: &str) -> Option<BackgroundTask> {
119 let tasks = self.tasks.read().unwrap();
120 tasks.get(id).cloned()
121 }
122
123 pub fn get_response(&self, id: &str) -> Option<ResponseResource> {
125 let tasks = self.tasks.read().unwrap();
126 tasks.get(id).map(|t| t.to_response_resource())
127 }
128
129 pub fn mark_in_progress(&self, id: &str) -> bool {
131 let mut tasks = self.tasks.write().unwrap();
132 if let Some(task) = tasks.get_mut(id) {
133 task.state = BackgroundTaskState::InProgress;
134 true
135 } else {
136 false
137 }
138 }
139
140 pub fn mark_completed(&self, id: &str, response: ResponseResource) -> bool {
142 let mut tasks = self.tasks.write().unwrap();
143 if let Some(task) = tasks.get_mut(id) {
144 task.state = BackgroundTaskState::Completed(response);
145 true
146 } else {
147 false
148 }
149 }
150
151 pub fn mark_failed(&self, id: &str, error: ResponseError) -> bool {
153 let mut tasks = self.tasks.write().unwrap();
154 if let Some(task) = tasks.get_mut(id) {
155 task.state = BackgroundTaskState::Failed(error);
156 true
157 } else {
158 false
159 }
160 }
161
162 pub fn request_cancel(&self, id: &str) -> bool {
164 let mut tasks = self.tasks.write().unwrap();
165 if let Some(task) = tasks.get_mut(id) {
166 if matches!(
167 task.state,
168 BackgroundTaskState::Queued | BackgroundTaskState::InProgress
169 ) {
170 task.cancel_requested = true;
171 return true;
172 }
173 }
174 false
175 }
176
177 pub fn is_cancel_requested(&self, id: &str) -> bool {
179 let tasks = self.tasks.read().unwrap();
180 tasks.get(id).map(|t| t.cancel_requested).unwrap_or(false)
181 }
182
183 pub fn mark_cancelled(&self, id: &str) -> bool {
185 let mut tasks = self.tasks.write().unwrap();
186 if let Some(task) = tasks.get_mut(id) {
187 task.state = BackgroundTaskState::Cancelled;
188 true
189 } else {
190 false
191 }
192 }
193
194 pub fn delete_task(&self, id: &str) -> bool {
196 let mut tasks = self.tasks.write().unwrap();
197 tasks.remove(id).is_some()
198 }
199
200 pub fn list_tasks(&self) -> Vec<String> {
202 let tasks = self.tasks.read().unwrap();
203 tasks.keys().cloned().collect()
204 }
205
206 pub fn cleanup_old_tasks(&self, max_age_secs: u64) {
208 let now = SystemTime::now()
209 .duration_since(UNIX_EPOCH)
210 .unwrap()
211 .as_secs();
212
213 let mut tasks = self.tasks.write().unwrap();
214 tasks.retain(|_, task| {
215 if matches!(
217 task.state,
218 BackgroundTaskState::Queued | BackgroundTaskState::InProgress
219 ) {
220 return true;
221 }
222
223 now - task.created_at < max_age_secs
225 });
226 }
227}
228
229impl Clone for BackgroundTask {
230 fn clone(&self) -> Self {
231 Self {
232 id: self.id.clone(),
233 state: self.state.clone(),
234 created_at: self.created_at,
235 model: self.model.clone(),
236 cancel_requested: self.cancel_requested,
237 }
238 }
239}
240
241static BACKGROUND_TASK_MANAGER: std::sync::LazyLock<BackgroundTaskManager> =
243 std::sync::LazyLock::new(BackgroundTaskManager::new);
244
245pub fn get_background_task_manager() -> &'static BackgroundTaskManager {
247 &BACKGROUND_TASK_MANAGER
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_and_get_task() {
        let manager = BackgroundTaskManager::new();
        let id = manager.create_task("test-model".to_string());

        let task = manager.get_task(&id).unwrap();
        assert_eq!(task.id, id);
        assert_eq!(task.model, "test-model");
        assert!(task.id.starts_with("resp_"));
        assert!(!task.cancel_requested);
        assert!(matches!(task.state, BackgroundTaskState::Queued));
    }

    #[test]
    fn test_task_state_transitions() {
        let manager = BackgroundTaskManager::new();
        let id = manager.create_task("test-model".to_string());

        assert!(manager.mark_in_progress(&id));
        let task = manager.get_task(&id).unwrap();
        assert!(matches!(task.state, BackgroundTaskState::InProgress));

        let response = ResponseResource::new(id.clone(), "test-model".to_string(), 0);
        assert!(manager.mark_completed(&id, response));
        let task = manager.get_task(&id).unwrap();
        assert!(matches!(task.state, BackgroundTaskState::Completed(_)));
    }

    #[test]
    fn test_cancel_task() {
        let manager = BackgroundTaskManager::new();
        let id = manager.create_task("test-model".to_string());

        assert!(manager.request_cancel(&id));
        assert!(manager.is_cancel_requested(&id));

        assert!(manager.mark_cancelled(&id));
        let task = manager.get_task(&id).unwrap();
        assert!(matches!(task.state, BackgroundTaskState::Cancelled));
    }

    #[test]
    fn test_cancel_rejected_for_finished_task() {
        let manager = BackgroundTaskManager::new();
        let id = manager.create_task("test-model".to_string());
        let response = ResponseResource::new(id.clone(), "test-model".to_string(), 0);
        assert!(manager.mark_completed(&id, response));

        assert!(!manager.request_cancel(&id));
        assert!(!manager.is_cancel_requested(&id));
    }

    #[test]
    fn test_unknown_id() {
        let manager = BackgroundTaskManager::new();
        let missing = "resp_missing";

        assert!(manager.get_task(missing).is_none());
        assert!(manager.get_response(missing).is_none());
        assert!(!manager.mark_in_progress(missing));
        assert!(!manager.mark_cancelled(missing));
        assert!(!manager.request_cancel(missing));
        assert!(!manager.is_cancel_requested(missing));
        assert!(!manager.delete_task(missing));
    }

    #[test]
    fn test_get_response_reflects_state() {
        let manager = BackgroundTaskManager::new();
        let id = manager.create_task("test-model".to_string());

        let resp = manager.get_response(&id).unwrap();
        assert!(matches!(resp.status, ResponseStatus::Queued));

        assert!(manager.mark_cancelled(&id));
        let resp = manager.get_response(&id).unwrap();
        assert!(matches!(resp.status, ResponseStatus::Cancelled));
    }

    #[test]
    fn test_delete_and_list_tasks() {
        let manager = BackgroundTaskManager::new();
        let first = manager.create_task("test-model".to_string());
        let second = manager.create_task("test-model".to_string());

        let mut listed = manager.list_tasks();
        listed.sort();
        let mut expected = vec![first.clone(), second.clone()];
        expected.sort();
        assert_eq!(listed, expected);

        assert!(manager.delete_task(&first));
        assert!(!manager.delete_task(&first));
        assert_eq!(manager.list_tasks(), vec![second]);
    }

    #[test]
    fn test_cleanup_keeps_pending_tasks() {
        let manager = BackgroundTaskManager::new();
        let queued = manager.create_task("test-model".to_string());
        let running = manager.create_task("test-model".to_string());
        let cancelled = manager.create_task("test-model".to_string());
        assert!(manager.mark_in_progress(&running));
        assert!(manager.mark_cancelled(&cancelled));

        // A generous maximum age keeps freshly finished tasks.
        manager.cleanup_old_tasks(u64::MAX);
        assert!(manager.get_task(&cancelled).is_some());

        // A maximum age of zero evicts every finished task but no pending one.
        manager.cleanup_old_tasks(0);
        assert!(manager.get_task(&queued).is_some());
        assert!(manager.get_task(&running).is_some());
        assert!(manager.get_task(&cancelled).is_none());
    }
}