mistralrs_server_core/
cached_responses.rs1use anyhow::Result;
4use std::collections::HashMap;
5use std::sync::LazyLock;
6use std::sync::{Arc, RwLock};
7
8use crate::openai::{Message, ResponsesChunk, ResponsesObject};
9
/// Storage interface for cached response data, keyed by a string id.
///
/// Three kinds of data can be kept per id: a full `ResponsesObject`, a list of
/// `ResponsesChunk`s, and a conversation history (`Vec<Message>`).
/// The `Send + Sync` supertraits allow an implementation to be shared between
/// threads, e.g. behind an `Arc<dyn ResponseCache>`.
pub trait ResponseCache: Send + Sync {
    /// Stores `response` under `id`.
    fn store_response(&self, id: String, response: ResponsesObject) -> Result<()>;

    /// Looks up the response stored under `id`; `Ok(None)` means no entry was found.
    fn get_response(&self, id: &str) -> Result<Option<ResponsesObject>>;

    /// Deletes the data stored under `id`; the returned `bool` reports whether
    /// anything was actually removed.
    ///
    /// NOTE(review): the in-memory implementation in this file also removes the
    /// chunks and conversation history for `id` — confirm other implementors
    /// are expected to do the same.
    fn delete_response(&self, id: &str) -> Result<bool>;

    /// Stores the list of `chunks` under `id`.
    fn store_chunks(&self, id: String, chunks: Vec<ResponsesChunk>) -> Result<()>;

    /// Looks up the chunks stored under `id`; `Ok(None)` means no entry was found.
    fn get_chunks(&self, id: &str) -> Result<Option<Vec<ResponsesChunk>>>;

    /// Stores the conversation `messages` under `id`.
    fn store_conversation_history(&self, id: String, messages: Vec<Message>) -> Result<()>;

    /// Looks up the conversation history stored under `id`; `Ok(None)` means no
    /// entry was found.
    fn get_conversation_history(&self, id: &str) -> Result<Option<Vec<Message>>>;
}
33
34pub struct InMemoryResponseCache {
36 responses: Arc<RwLock<HashMap<String, ResponsesObject>>>,
37 chunks: Arc<RwLock<HashMap<String, Vec<ResponsesChunk>>>>,
38 conversation_histories: Arc<RwLock<HashMap<String, Vec<Message>>>>,
39}
40
41impl InMemoryResponseCache {
42 pub fn new() -> Self {
44 Self {
45 responses: Arc::new(RwLock::new(HashMap::new())),
46 chunks: Arc::new(RwLock::new(HashMap::new())),
47 conversation_histories: Arc::new(RwLock::new(HashMap::new())),
48 }
49 }
50}
51
52impl Default for InMemoryResponseCache {
53 fn default() -> Self {
54 Self::new()
55 }
56}
57
58impl ResponseCache for InMemoryResponseCache {
59 fn store_response(&self, id: String, response: ResponsesObject) -> Result<()> {
60 let mut responses = self.responses.write().unwrap();
61 responses.insert(id, response);
62 Ok(())
63 }
64
65 fn get_response(&self, id: &str) -> Result<Option<ResponsesObject>> {
66 let responses = self.responses.read().unwrap();
67 Ok(responses.get(id).cloned())
68 }
69
70 fn delete_response(&self, id: &str) -> Result<bool> {
71 let mut responses = self.responses.write().unwrap();
72 let mut chunks = self.chunks.write().unwrap();
73 let mut histories = self.conversation_histories.write().unwrap();
74
75 let response_removed = responses.remove(id).is_some();
76 let chunks_removed = chunks.remove(id).is_some();
77 let history_removed = histories.remove(id).is_some();
78
79 Ok(response_removed || chunks_removed || history_removed)
80 }
81
82 fn store_chunks(&self, id: String, chunks: Vec<ResponsesChunk>) -> Result<()> {
83 let mut chunk_storage = self.chunks.write().unwrap();
84 chunk_storage.insert(id, chunks);
85 Ok(())
86 }
87
88 fn get_chunks(&self, id: &str) -> Result<Option<Vec<ResponsesChunk>>> {
89 let chunks = self.chunks.read().unwrap();
90 Ok(chunks.get(id).cloned())
91 }
92
93 fn store_conversation_history(&self, id: String, messages: Vec<Message>) -> Result<()> {
94 let mut histories = self.conversation_histories.write().unwrap();
95 histories.insert(id, messages);
96 Ok(())
97 }
98
99 fn get_conversation_history(&self, id: &str) -> Result<Option<Vec<Message>>> {
100 let histories = self.conversation_histories.read().unwrap();
101 Ok(histories.get(id).cloned())
102 }
103}
104
105pub static RESPONSE_CACHE: LazyLock<Arc<dyn ResponseCache>> =
107 LazyLock::new(|| Arc::new(InMemoryResponseCache::new()));
108
109pub fn get_response_cache() -> Arc<dyn ResponseCache> {
111 RESPONSE_CACHE.clone()
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal "completed" response whose `id` field is `id`.
    fn sample_response(id: &str) -> ResponsesObject {
        ResponsesObject {
            id: id.to_string(),
            object: "response",
            created_at: 1234567890.0,
            model: "test-model".to_string(),
            status: "completed".to_string(),
            output: vec![],
            output_text: None,
            usage: None,
            error: None,
            metadata: None,
            instructions: None,
            incomplete_details: None,
        }
    }

    /// A stored response can be read back and is gone after deletion.
    #[test]
    fn test_in_memory_cache() {
        let cache = InMemoryResponseCache::new();

        cache
            .store_response("test-id".to_string(), sample_response("test-id"))
            .unwrap();
        let retrieved = cache.get_response("test-id").unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, "test-id");

        let deleted = cache.delete_response("test-id").unwrap();
        assert!(deleted);
        let retrieved = cache.get_response("test-id").unwrap();
        assert!(retrieved.is_none());
    }

    /// Deleting an id that was never stored reports `false`.
    #[test]
    fn test_delete_unknown_id_returns_false() {
        let cache = InMemoryResponseCache::new();

        assert!(!cache.delete_response("missing").unwrap());
        assert!(cache.get_response("missing").unwrap().is_none());
        assert!(cache.get_chunks("missing").unwrap().is_none());
        assert!(cache.get_conversation_history("missing").unwrap().is_none());
    }

    /// `delete_response` also clears chunks and conversation history, and
    /// reports `true` even when no response object was stored for the id.
    #[test]
    fn test_delete_clears_chunks_and_history() {
        let cache = InMemoryResponseCache::new();

        // Empty lists are enough to create entries in both maps.
        cache
            .store_chunks("test-id".to_string(), Vec::new())
            .unwrap();
        cache
            .store_conversation_history("test-id".to_string(), Vec::new())
            .unwrap();
        assert!(cache.get_chunks("test-id").unwrap().is_some());
        assert!(cache
            .get_conversation_history("test-id")
            .unwrap()
            .is_some());

        assert!(cache.delete_response("test-id").unwrap());
        assert!(cache.get_chunks("test-id").unwrap().is_none());
        assert!(cache
            .get_conversation_history("test-id")
            .unwrap()
            .is_none());

        // A second delete finds nothing left to remove.
        assert!(!cache.delete_response("test-id").unwrap());
    }
}