mistralrs_server_core/
util.rs1use image::DynamicImage;
4use mistralrs_core::AudioInput;
5use mistralrs_core::MistralRs;
6use std::error::Error;
7use std::sync::Arc;
8use tokio::{
9 fs::{self, File},
10 io::AsyncReadExt,
11};
12
13pub async fn parse_image_url(url_unparsed: &str) -> Result<DynamicImage, anyhow::Error> {
46 let url = if let Ok(url) = url::Url::parse(url_unparsed) {
47 url
48 } else if File::open(url_unparsed).await.is_ok() {
49 url::Url::from_file_path(std::path::absolute(url_unparsed)?)
50 .map_err(|_| anyhow::anyhow!("Could not parse file path: {}", url_unparsed))?
51 } else {
52 anyhow::bail!(
53 "Invalid source '{}': not a valid URL (http/https/data) and file not found on server. \
54 Use a full URL, a data URL, or an absolute file path that exists on the server.",
55 url_unparsed
56 )
57 };
58
59 let bytes = if url.scheme() == "http" || url.scheme() == "https" {
60 match reqwest::get(url.clone()).await {
62 Ok(http_resp) => http_resp.bytes().await?.to_vec(),
63 Err(e) => anyhow::bail!(e),
64 }
65 } else if url.scheme() == "file" {
66 let path = url
67 .to_file_path()
68 .map_err(|_| anyhow::anyhow!("Could not parse file path: {}", url))?;
69
70 if let Ok(mut f) = File::open(&path).await {
71 let metadata = fs::metadata(&path).await?;
73 let mut buffer = vec![0; metadata.len() as usize];
74 f.read_exact(&mut buffer).await?;
75 buffer
76 } else {
77 anyhow::bail!("Could not open file at path: {}", url);
78 }
79 } else if url.scheme() == "data" {
80 let data_url = data_url::DataUrl::process(url.as_str())?;
82 data_url.decode_to_vec()?.0
83 } else {
84 anyhow::bail!("Unsupported URL scheme: {}", url.scheme());
85 };
86
87 Ok(image::load_from_memory(&bytes)?)
88}
89
90pub async fn parse_audio_url(url_unparsed: &str) -> Result<AudioInput, anyhow::Error> {
92 let url = if let Ok(url) = url::Url::parse(url_unparsed) {
93 url
94 } else if File::open(url_unparsed).await.is_ok() {
95 url::Url::from_file_path(std::path::absolute(url_unparsed)?)
96 .map_err(|_| anyhow::anyhow!("Could not parse file path: {}", url_unparsed))?
97 } else {
98 anyhow::bail!(
99 "Invalid source '{}': not a valid URL (http/https/data) and file not found on server. \
100 Use a full URL, a data URL, or an absolute file path that exists on the server.",
101 url_unparsed
102 )
103 };
104
105 let bytes = if url.scheme() == "http" || url.scheme() == "https" {
106 match reqwest::get(url.clone()).await {
107 Ok(http_resp) => http_resp.bytes().await?.to_vec(),
108 Err(e) => anyhow::bail!(e),
109 }
110 } else if url.scheme() == "file" {
111 let path = url
112 .to_file_path()
113 .map_err(|_| anyhow::anyhow!("Could not parse file path: {}", url))?;
114
115 if let Ok(mut f) = File::open(&path).await {
116 let metadata = fs::metadata(&path).await?;
117 let mut buffer = vec![0; metadata.len() as usize];
118 f.read_exact(&mut buffer).await?;
119 buffer
120 } else {
121 anyhow::bail!("Could not open file at path: {}", url);
122 }
123 } else if url.scheme() == "data" {
124 let data_url = data_url::DataUrl::process(url.as_str())?;
125 data_url.decode_to_vec()?.0
126 } else {
127 anyhow::bail!("Unsupported URL scheme: {}", url.scheme());
128 };
129
130 AudioInput::from_bytes(&bytes)
131}
132
133pub fn validate_model_name(
151 requested_model: &str,
152 state: Arc<MistralRs>,
153) -> Result<(), anyhow::Error> {
154 if requested_model == "default" {
156 return Ok(());
157 }
158
159 let available_models = state
160 .list_models()
161 .map_err(|e| anyhow::anyhow!("Failed to get available models: {}", e))?;
162
163 if available_models.is_empty() {
164 anyhow::bail!("No models are currently loaded.");
165 }
166
167 if !available_models.contains(&requested_model.to_string()) {
168 anyhow::bail!(
169 "Requested model '{}' is not available. Available models: {}. Use 'default' to use the default model.",
170 requested_model,
171 available_models.join(", ")
172 );
173 }
174 Ok(())
175}
176
/// Returns the `Display` text of the innermost error in `error`'s `source()` chain.
///
/// Only the root cause's message is kept; messages of the wrapping layers are discarded. If
/// `error` has no source, its own message is returned.
pub fn sanitize_error_message(error: &(dyn Error + 'static)) -> String {
    // Yield `error`, then each successive `source()`, stopping when the chain ends.
    let root_cause = std::iter::successors(Some(error), |&cause| cause.source())
        .last()
        .unwrap_or(error);
    root_cause.to_string()
}
218
#[cfg(test)]
mod tests {
    use image::GenericImageView;

    use super::*;

    /// A base64-encoded 32x32 PNG, used to exercise `data:` URLs without filesystem or network.
    const RUST_LOGO_PNG_BASE64: &str = "\
        iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAYAAABzenr0AAAHhElEQVR4AZXVA5Aky9bA8f/JzKrq\
        npleX9u2bdu2bdu2bdv29z1e2zZm7k5PT3dXZeZ56I6J3o03sbG/iDLOSQuT6fptF11OREYDj4uR\
        i9UmAAdHa9ZH6QP+n8kg1+26HJMU44rG+4OjL3YqCv+693HOwcHiTJeYY2NUch/PLI3sOdZY82lU\
        Xbynp3yzEXMH8CCTINfuujzDEXQlVN9sju8/uFHPTy2KWLVpWsl9ZGQCvY2AF0ulu0RTBRHIi1AV\
        iZU0sSd0dWWXZKVsUeAVhiFX7roCwzGDA9rXV6uaqH/YcmnmPEQGg4IYIoLAYRHRABcaQIGuNMVa\
        IS98tZnnjOxJK4AwDDlzs0XoNGUmlWDsPr/98ucLIerrPVlCI8KAWMAQAYWXo8rKipyuMDewuaAv\
        g6wMgEa6M0dX6ugdqOPQxSs96WqlcukqoEoHuWiHZelki3yF/vHVV0OhdCUJfzZyQlYiiPlR4RxV\
        bgKqAbNthDto2Q64U6ACbAicKzCtAON6Uqr1HAk5XYlZEXiNDnLaBgvQxqiSzPdLX70PNT9U/pN9\
        0xNdSjT2UoXjJ84+x6ygwMQ/bSdyOnCgamSqSpmBepOY53OliXHAh7TJsesuCMBMU/XM/+dvve/9\
        PhgYl2X8Xi8IWZkobAg8xuQjx24L3KEamaY7oX/8IDZ6ukZkCwDvA8gpGy1EG9Vq44fRpXTa3oZv\
        BVeIQERQQBFUQQGE4frWj+3hdyxQtei2oHe4UDB1KvxWL34EpqPNLjzdWKYZXVqpr3fgdDV2QSJZ\
        A4M3loC0gqu0ggsgrXMQhlEBlgR2Au6OyF+AWby4hbvU4xVtRF2x7OQ7a+QbOWKN+Rjp4lF/NOLZ\
        o0sZvw96MIJPM6IYVEFFAFrnTEhF6CSqdHgaWEeEzQXuc9EzlYv8VPdkwtHAOS4P8Fsw52A40Mc4\
        rRp5ICKzR2WhCC8hsrgqFaWlXRPfCfJtRIxCVGQWYFoAERCU9rY2AKqXAO/7qHFA7YIi4ccczgFw\
        U490G/7WV7/KZdm0/YVHxBwcka2jyEKI7K3KdQorarvqI4aIXAWcRQcv5ixBjgZFVBEUg2KJiyFM\
        i+p2EpWB3L+UJHbamPsfMo37uFoj/8RjKqkRmiqAfKcioPKjwoCCjWKIGALtBERmi5glFHGAgswC\
        ur4AoEO1YDReAvITaNVIfInEVWOzoJwaBqGSwCeuVucTMebUIsTzBCHCD9G63dH4tCh4m5qIELAE\
        ERRDRHbT/24AAhMch5rgqSDm4AIARpS1uZ3k4SqLEsUCnFoX84nLupPLaoN+/8xZ6kUBYr82wT9N\
        W7AlghgChnbwdlMICCgCgMLQmaiAcDAdqtJ1R0+ie4VQrNCFosh5TpjJSe6fVGWnqFrx/2Nwe7G0\
        233oqNIxL9D5iSIIiCLwenu8VwEy9ad4cTNL9AQMXqlaa/7uxqt7SiQcVCvijQGDUR1Nh4jRdvt3\
        BJdjgLPbISsKVwHbgSBA66gVgY2A252GSlQ90eRNECEP4MUd5AO3u5Els2Gtrjd2p5a81iSYZB7v\
        ssUKl6UBm0dkRECI0k4Ag8LsiiyhkAHvANsDGwIVBQRoH/cW+CiKORBtt70oYi1i/I2p8LsL8BLW\
        3FHLwwZZYkcMBlDM68rQsEMZCk4EFNkN2A0EhTOB44D3gWWYgC4HvB4RsI5gLJlEGj72q5pHrY1f\
        mmpuqiD7RB+lK7UYUWIMRCxDHU4mCA5IR/tT0PIcbQoTvBMR1BfEEEjTlIZXKaVm32DcByYYR0Y4\
        0aNWvIIRxWiEULSDT9zZBGUChpZrAIZLQsWAS4gKZXzFKCcaBWMUSNypwNq1PL7dlbZe0qgovK7I\
        i4q8oPCywl//szG08TrwZscquF773kvACwovgbwOhugjXaljoFl8Z4ysHdFTI4qLKOO9q46o8Jei\
        VswWFMoOGj6iToyKfKKwIMgTAmcJyvB48j9bRM4DlqaVwPr4BiUTEAzdJowyxvwlQhXARQSAclc2\
        a+H101rD10ZWulbsq4GE5qKOdOoc+5kaORM4i0lQpAIcDnyIcjC+USEGxpSgVq9+4JKskZg4a3v0\
        IPussxSdTGNw2jrJD7Zcpmg2iaKr9LrRqMpLWLsE8DrDQ1UXA15HZDppDq5p0Zum6TbUBmq4LJsO\
        +JEOro6j0xQjyrMzWNCs1/4Y210a21+El0blvdU+NwZCeFGNuQGRe4APgCotFWA+YtwK1d3QiAb/\
        j9EujBmXKL9U69UscRUncfaJE5Dd112G4RT1hmZpQuYM3+UJRYBoE8QYMA40gmrrKIKGCNaSaMHU\
        iQfvaeYBQBiG7LTKEgxndI/przbii001lR7HqnVXphmU3EdCFHLjEI1ojKQWyonQI4pGT72Rv2OE\
        r7qtrAaMYBiy1xpLMClSTk/Nm/6EZtAHUms3y1BKzvCHZESEMVqnXkRyHwjIAxbdLEnsqcBJTILz\
        xjApU46pnBe8fwF4pb+/8Ywv/DK9zbCKsfWXUBhf+A1dOX00S+xfgc3L3dmKWSn7iklDjthxbSaH\
        c7YCVIAfi6JYn5bHjTHTGmurQJXJ8C/um928G9zK4gAAAABJRU5ErkJggg==";

    /// A minimal RIFF/WAVE file: PCM, mono, 8000 Hz, 16-bit, holding a single zero sample.
    const ONE_SAMPLE_WAV_BASE64: &str =
        "UklGRiYAAABXQVZFZm10IBAAAAABAAEAQB8AAIA+AAACABAAZGF0YQIAAAAAAA==";

    /// Error whose `Display` and `Debug` both print the wrapped message verbatim and which has
    /// no `source()`.
    struct MessageError(String);

    impl std::fmt::Display for MessageError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    impl std::fmt::Debug for MessageError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    impl std::error::Error for MessageError {}

    #[tokio::test]
    async fn test_parse_image_url_remote() {
        // Needs outbound network access to www.rust-lang.org. It is kept separate so that a
        // network failure does not hide the results of the offline cases below.
        for url in [
            "https://www.rust-lang.org/logos/rust-logo-32x32.png",
            "http://www.rust-lang.org/logos/rust-logo-32x32.png",
        ] {
            let image = parse_image_url(url).await.unwrap();
            assert_eq!(image.dimensions(), (32, 32), "source: {url}");
        }
    }

    #[tokio::test]
    async fn test_parse_image_url_local_file() {
        let relative_path = "resources/rust-logo-32x32.png";
        let image = parse_image_url(relative_path).await.unwrap();
        assert_eq!(image.dimensions(), (32, 32));

        let absolute_path = std::path::absolute(relative_path).unwrap();
        let file_url = format!("file://{}", absolute_path.as_os_str().to_str().unwrap());
        let image = parse_image_url(&file_url).await.unwrap();
        assert_eq!(image.dimensions(), (32, 32));
    }

    #[tokio::test]
    async fn test_parse_image_url_data_url() {
        let url = format!("data:image/png;base64,{RUST_LOGO_PNG_BASE64}");
        let image = parse_image_url(&url).await.unwrap();
        assert_eq!(image.dimensions(), (32, 32));
    }

    #[tokio::test]
    async fn test_parse_image_url_rejects_unusable_sources() {
        let err = parse_image_url("ftp://example.com/rust-logo-32x32.png")
            .await
            .err()
            .expect("ftp:// sources must be rejected");
        assert!(err.to_string().contains("Unsupported URL scheme"), "{err}");

        let err = parse_image_url("resources/does-not-exist.png")
            .await
            .err()
            .expect("missing files must be rejected");
        assert!(err.to_string().contains("Invalid source"), "{err}");
    }

    #[tokio::test]
    async fn test_parse_audio_url_data_url() {
        let url = format!("data:audio/wav;base64,{ONE_SAMPLE_WAV_BASE64}");
        let audio = parse_audio_url(&url).await.unwrap();
        assert_eq!(audio.sample_rate, 8000);
        assert_eq!(audio.samples.len(), 1);
    }

    #[test]
    fn test_sanitize_error_message_with_backtrace() {
        // A single error (no source chain) is returned verbatim, even if its text embeds a
        // backtrace.
        let error_with_backtrace = concat!(
            "Failed to parse Forge Provider response: A weight is negative, too large or not a valid number\n",
            "   0: candle_core::error::Error::bt\n",
            "   1: mistralrs_core::sampler::Sampler::sample_multinomial\n",
            "   2: mistralrs_core::sampler::Sampler::sample_top_kp_min_p\n",
            "   3: mistralrs_core::sampler::Sampler::sample\n",
            "   4: mistralrs_core::pipeline::sampling::sample_sequence::{{closure}}",
        );

        let error = MessageError(error_with_backtrace.to_string());
        let sanitized = sanitize_error_message(&error);

        assert_eq!(sanitized, error_with_backtrace);
    }

    #[test]
    fn test_sanitize_error_message_without_backtrace() {
        let simple_error = "Simple error message without backtrace";

        let error = MessageError(simple_error.to_string());
        let sanitized = sanitize_error_message(&error);

        assert_eq!(sanitized, simple_error);
    }

    #[test]
    fn test_sanitize_error_message_with_chain() {
        use std::fmt;

        #[derive(Debug)]
        struct RootError;
        impl fmt::Display for RootError {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(f, "Root cause: Database connection failed")
            }
        }
        impl std::error::Error for RootError {}

        #[derive(Debug)]
        struct MiddleError(Box<dyn std::error::Error>);
        impl fmt::Display for MiddleError {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(f, "Middle error: Service unavailable")
            }
        }
        impl std::error::Error for MiddleError {
            fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
                Some(&*self.0)
            }
        }

        #[derive(Debug)]
        struct TopError(Box<dyn std::error::Error>);
        impl fmt::Display for TopError {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(
                    f,
                    "Top error: Request failed with backtrace\n 0: some::module::function"
                )
            }
        }
        impl std::error::Error for TopError {
            fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
                Some(&*self.0)
            }
        }

        let root = RootError;
        let middle = MiddleError(Box::new(root));
        let top = TopError(Box::new(middle));

        let sanitized = sanitize_error_message(&top);

        // Only the innermost cause survives; the outer layers' text (including the backtrace)
        // is dropped.
        assert_eq!(sanitized, "Root cause: Database connection failed");
        assert!(!sanitized.contains("backtrace"));
        assert!(!sanitized.contains("Request failed"));
    }
}