mistralrs_server_core/
util.rs

1//! ## General utilities.
2
3use image::DynamicImage;
4use mistralrs_core::AudioInput;
5use mistralrs_core::MistralRs;
6use std::error::Error;
7use std::sync::Arc;
8use tokio::{
9    fs::{self, File},
10    io::AsyncReadExt,
11};
12
13/// Parses and loads an image from a URL, file path, or data URL.
14///
15/// This function accepts various input formats and attempts to parse them in order:
16/// 1. First tries to parse as a complete URL (http/https/file/data schemes)
17/// 2. If that fails, checks if it's a local file path and converts to file URL
18/// 3. Finally falls back to treating it as a malformed URL and returns an error
19///
20/// ### Arguments
21///
22/// * `url_unparsed` - A string that can be:
23///   - An HTTP/HTTPS URL (e.g., "<https://example.com/image.png>")
24///   - A file path (e.g., "/path/to/image.jpg" or "image.png")
25///   - A data URL with base64 encoded image (e.g., "data:image/png;base64,...")
26///   - A file URL (e.g., "file:///path/to/image.jpg")
27///
28/// ### Examples
29///
30/// ```ignore
31/// use mistralrs_server_core::util::parse_image_url;
32///
33/// // Load from HTTP URL
34/// let image = parse_image_url("https://example.com/photo.jpg").await?;
35///
36/// // Load from local file path
37/// let image = parse_image_url("./assets/logo.png").await?;
38///
39/// // Load from data URL
40/// let image = parse_image_url("...").await?;
41///
42/// // Load from file URL
43/// let image = parse_image_url("file:///home/user/picture.jpg").await?;
44/// ```
45pub async fn parse_image_url(url_unparsed: &str) -> Result<DynamicImage, anyhow::Error> {
46    let url = if let Ok(url) = url::Url::parse(url_unparsed) {
47        url
48    } else if File::open(url_unparsed).await.is_ok() {
49        url::Url::from_file_path(std::path::absolute(url_unparsed)?)
50            .map_err(|_| anyhow::anyhow!("Could not parse file path: {}", url_unparsed))?
51    } else {
52        anyhow::bail!(
53            "Invalid source '{}': not a valid URL (http/https/data) and file not found on server. \
54             Use a full URL, a data URL, or an absolute file path that exists on the server.",
55            url_unparsed
56        )
57    };
58
59    let bytes = if url.scheme() == "http" || url.scheme() == "https" {
60        // Read from http
61        match reqwest::get(url.clone()).await {
62            Ok(http_resp) => http_resp.bytes().await?.to_vec(),
63            Err(e) => anyhow::bail!(e),
64        }
65    } else if url.scheme() == "file" {
66        let path = url
67            .to_file_path()
68            .map_err(|_| anyhow::anyhow!("Could not parse file path: {}", url))?;
69
70        if let Ok(mut f) = File::open(&path).await {
71            // Read from local file
72            let metadata = fs::metadata(&path).await?;
73            let mut buffer = vec![0; metadata.len() as usize];
74            f.read_exact(&mut buffer).await?;
75            buffer
76        } else {
77            anyhow::bail!("Could not open file at path: {}", url);
78        }
79    } else if url.scheme() == "data" {
80        // Decode with base64
81        let data_url = data_url::DataUrl::process(url.as_str())?;
82        data_url.decode_to_vec()?.0
83    } else {
84        anyhow::bail!("Unsupported URL scheme: {}", url.scheme());
85    };
86
87    Ok(image::load_from_memory(&bytes)?)
88}
89
90/// Parses and loads an audio file from a URL, file path, or data URL.
91pub async fn parse_audio_url(url_unparsed: &str) -> Result<AudioInput, anyhow::Error> {
92    let url = if let Ok(url) = url::Url::parse(url_unparsed) {
93        url
94    } else if File::open(url_unparsed).await.is_ok() {
95        url::Url::from_file_path(std::path::absolute(url_unparsed)?)
96            .map_err(|_| anyhow::anyhow!("Could not parse file path: {}", url_unparsed))?
97    } else {
98        anyhow::bail!(
99            "Invalid source '{}': not a valid URL (http/https/data) and file not found on server. \
100             Use a full URL, a data URL, or an absolute file path that exists on the server.",
101            url_unparsed
102        )
103    };
104
105    let bytes = if url.scheme() == "http" || url.scheme() == "https" {
106        match reqwest::get(url.clone()).await {
107            Ok(http_resp) => http_resp.bytes().await?.to_vec(),
108            Err(e) => anyhow::bail!(e),
109        }
110    } else if url.scheme() == "file" {
111        let path = url
112            .to_file_path()
113            .map_err(|_| anyhow::anyhow!("Could not parse file path: {}", url))?;
114
115        if let Ok(mut f) = File::open(&path).await {
116            let metadata = fs::metadata(&path).await?;
117            let mut buffer = vec![0; metadata.len() as usize];
118            f.read_exact(&mut buffer).await?;
119            buffer
120        } else {
121            anyhow::bail!("Could not open file at path: {}", url);
122        }
123    } else if url.scheme() == "data" {
124        let data_url = data_url::DataUrl::process(url.as_str())?;
125        data_url.decode_to_vec()?.0
126    } else {
127        anyhow::bail!("Unsupported URL scheme: {}", url.scheme());
128    };
129
130    AudioInput::from_bytes(&bytes)
131}
132
133/// Validates that the requested model matches one of the loaded models.
134///
135/// This function checks if the model parameter from an OpenAI API request
136/// matches one of the models that are currently loaded by the server.
137///
138/// The special model name "default" can be used to bypass this validation,
139/// which is useful for clients that require a model parameter but want
140/// to use the default model.
141///
142/// ### Arguments
143///
144/// * `requested_model` - The model name from the API request
145/// * `state` - The MistralRs state containing the loaded models info
146///
147/// ### Returns
148///
149/// Returns `Ok(())` if the model is available or if "default" is specified, otherwise returns an error.
150pub fn validate_model_name(
151    requested_model: &str,
152    state: Arc<MistralRs>,
153) -> Result<(), anyhow::Error> {
154    // Allow "default" as a special case to bypass validation
155    if requested_model == "default" {
156        return Ok(());
157    }
158
159    let available_models = state
160        .list_models()
161        .map_err(|e| anyhow::anyhow!("Failed to get available models: {}", e))?;
162
163    if available_models.is_empty() {
164        anyhow::bail!("No models are currently loaded.");
165    }
166
167    if !available_models.contains(&requested_model.to_string()) {
168        anyhow::bail!(
169            "Requested model '{}' is not available. Available models: {}. Use 'default' to use the default model.",
170            requested_model,
171            available_models.join(", ")
172        );
173    }
174    Ok(())
175}
176
/// Reduces an error to the message of its root cause, for use in API responses.
///
/// The `source()` chain is followed from `error` down to the innermost error. Outer
/// wrapping layers, which may carry context text or backtraces, are skipped. Only the
/// `Display` text of the last error in the chain is returned. Text inside the root
/// cause's own message, including any backtrace it embeds, is returned unchanged.
///
/// ### Arguments
///
/// * `error` - The error to reduce
///
/// ### Returns
///
/// The `Display` output of the root-cause error. If `error` has no source, this is
/// `error`'s own message.
///
/// ### Examples
///
/// ```ignore
/// use mistralrs_server_core::util::sanitize_error_message;
///
/// // An error without a source yields its own message.
/// let error = std::io::Error::new(std::io::ErrorKind::NotFound, "File not found");
/// assert_eq!(sanitize_error_message(&error), "File not found");
///
/// // For a chain, the root cause wins over the outer context.
/// let root = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "Access denied");
/// let wrapped = anyhow::Error::new(root).context("Failed to read file");
/// assert_eq!(sanitize_error_message(&*wrapped), "Access denied");
/// ```
pub fn sanitize_error_message(error: &(dyn Error + 'static)) -> String {
    // Lazily visit `error`, then its source, then that source's source, and so on;
    // the final element is the root cause. The fallback is never taken because the
    // sequence always contains at least `error` itself.
    let root_cause = std::iter::successors(Some(error), |&err| err.source())
        .last()
        .unwrap_or(error);
    root_cause.to_string()
}
218
#[cfg(test)]
mod tests {
    use std::fmt;

    use image::GenericImageView;

    use super::*;

    /// 32x32 PNG fixture shared by the file-path, file-URL and data-URL image tests.
    const LOGO_PATH: &str = "resources/rust-logo-32x32.png";

    /// Minimal RFC 4648 base64 encoder (standard alphabet, `=` padding).
    ///
    /// It lets the data-URL test be built from the fixture file, so it exercises
    /// the same bytes as the file-path test without an extra dev-dependency.
    fn base64_encode(bytes: &[u8]) -> String {
        const ALPHABET: &[u8; 64] =
            b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
        let mut encoded = String::with_capacity(bytes.len().div_ceil(3) * 4);
        for chunk in bytes.chunks(3) {
            let b0 = u32::from(chunk[0]);
            let b1 = u32::from(chunk.get(1).copied().unwrap_or(0));
            let b2 = u32::from(chunk.get(2).copied().unwrap_or(0));
            let group = (b0 << 16) | (b1 << 8) | b2;
            // A chunk of n bytes yields n + 1 significant sextets; the rest are padding.
            for i in 0..4 {
                if i <= chunk.len() {
                    let sextet = (group >> (18 - 6 * i)) & 0x3F;
                    encoded.push(char::from(ALPHABET[sextet as usize]));
                } else {
                    encoded.push('=');
                }
            }
        }
        encoded
    }

    /// Error with a fixed message and an optional underlying cause, used to build
    /// error chains for the `sanitize_error_message` tests.
    #[derive(Debug)]
    struct TestError {
        message: String,
        cause: Option<Box<dyn std::error::Error + 'static>>,
    }

    impl TestError {
        /// An error with no `source()`.
        fn leaf(message: &str) -> Self {
            Self { message: message.to_string(), cause: None }
        }

        /// An error whose `source()` is `cause`.
        fn wrapping(message: &str, cause: TestError) -> Self {
            Self { message: message.to_string(), cause: Some(Box::new(cause)) }
        }
    }

    impl fmt::Display for TestError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str(&self.message)
        }
    }

    impl std::error::Error for TestError {
        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
            self.cause.as_deref()
        }
    }

    #[test]
    fn test_base64_encode_helper() {
        // RFC 4648 section 10 test vectors.
        assert_eq!(base64_encode(b""), "");
        assert_eq!(base64_encode(b"f"), "Zg==");
        assert_eq!(base64_encode(b"fo"), "Zm8=");
        assert_eq!(base64_encode(b"foo"), "Zm9v");
        assert_eq!(base64_encode(b"foobar"), "Zm9vYmFy");
    }

    #[tokio::test]
    async fn test_parse_image_url_http() {
        // Requires network access to www.rust-lang.org; kept separate so an offline
        // run does not mask the local-source tests.
        for url in [
            "https://www.rust-lang.org/logos/rust-logo-32x32.png",
            "http://www.rust-lang.org/logos/rust-logo-32x32.png",
        ] {
            let image = parse_image_url(url).await.unwrap();
            assert_eq!(image.dimensions(), (32, 32), "{url}");
        }
    }

    #[tokio::test]
    async fn test_parse_image_url_local_file() {
        // Bare (relative) file path.
        let image = parse_image_url(LOGO_PATH).await.unwrap();
        assert_eq!(image.dimensions(), (32, 32));

        // file:// URL, built from an absolute path so it is also valid on Windows.
        let absolute_path = std::path::absolute(LOGO_PATH).unwrap();
        let file_url = url::Url::from_file_path(&absolute_path).unwrap();
        let image = parse_image_url(file_url.as_str()).await.unwrap();
        assert_eq!(image.dimensions(), (32, 32));
    }

    #[tokio::test]
    async fn test_parse_image_url_base64_data_url() {
        let png = std::fs::read(LOGO_PATH).unwrap();
        let url = format!("data:image/png;base64,{}", base64_encode(&png));
        let image = parse_image_url(&url).await.unwrap();
        assert_eq!(image.dimensions(), (32, 32));
    }

    #[tokio::test]
    async fn test_parse_image_url_rejects_bad_sources() {
        let err = parse_image_url("this/file/does/not/exist.png")
            .await
            .unwrap_err();
        assert!(err.to_string().contains("Invalid source"), "{err}");

        let err = parse_image_url("ftp://example.com/image.png")
            .await
            .unwrap_err();
        assert!(
            err.to_string().contains("Unsupported URL scheme: ftp"),
            "{err}"
        );
    }

    #[tokio::test]
    async fn test_parse_audio_url() {
        // An 8 kHz WAV containing a single sample.
        let audio_b64 = "UklGRiYAAABXQVZFZm10IBAAAAABAAEAQB8AAIA+AAACABAAZGF0YQIAAAAAAA==";
        let url = format!("data:audio/wav;base64,{audio_b64}");
        let audio = parse_audio_url(&url).await.unwrap();
        assert_eq!(audio.sample_rate, 8000);
        assert_eq!(audio.samples.len(), 1);
    }

    #[test]
    fn test_sanitize_error_message_with_backtrace() {
        // A backtrace embedded in the root error's own message is not stripped: with
        // no `source()` chain, the message is returned unchanged.
        let error_with_backtrace = "Failed to parse Forge Provider response: \
             A weight is negative, too large or not a valid number\n  \
             0: candle_core::error::Error::bt\n  \
             1: mistralrs_core::sampler::Sampler::sample_multinomial\n  \
             2: mistralrs_core::sampler::Sampler::sample_top_kp_min_p\n  \
             3: mistralrs_core::sampler::Sampler::sample\n  \
             4: mistralrs_core::pipeline::sampling::sample_sequence::{{closure}}";

        let error = TestError::leaf(error_with_backtrace);
        assert_eq!(sanitize_error_message(&error), error_with_backtrace);
    }

    #[test]
    fn test_sanitize_error_message_without_backtrace() {
        let simple_error = "Simple error message without backtrace";
        let error = TestError::leaf(simple_error);
        assert_eq!(sanitize_error_message(&error), simple_error);
    }

    #[test]
    fn test_sanitize_error_message_with_chain() {
        // The root cause is extracted; outer layers (including one carrying a
        // backtrace) are discarded.
        let root = TestError::leaf("Root cause: Database connection failed");
        let middle = TestError::wrapping("Middle error: Service unavailable", root);
        let top = TestError::wrapping(
            "Top error: Request failed with backtrace\n  0: some::module::function",
            middle,
        );

        let sanitized = sanitize_error_message(&top);

        assert_eq!(sanitized, "Root cause: Database connection failed");
        assert!(!sanitized.contains("backtrace"));
        assert!(!sanitized.contains("Request failed"));
    }
}