mistralrs_server_core/
util.rs

1//! ## General utilities.
2
3use image::DynamicImage;
4use mistralrs_core::AudioInput;
5use mistralrs_core::MistralRs;
6use std::sync::Arc;
7use tokio::{
8    fs::{self, File},
9    io::AsyncReadExt,
10};
11
12/// Parses and loads an image from a URL, file path, or data URL.
13///
14/// This function accepts various input formats and attempts to parse them in order:
15/// 1. First tries to parse as a complete URL (http/https/file/data schemes)
16/// 2. If that fails, checks if it's a local file path and converts to file URL
17/// 3. Finally falls back to treating it as a malformed URL and returns an error
18///
19/// ### Arguments
20///
21/// * `url_unparsed` - A string that can be:
22///   - An HTTP/HTTPS URL (e.g., "<https://example.com/image.png>")
23///   - A file path (e.g., "/path/to/image.jpg" or "image.png")
24///   - A data URL with base64 encoded image (e.g., "data:image/png;base64,...")
25///   - A file URL (e.g., "file:///path/to/image.jpg")
26///
27/// ### Examples
28///
29/// ```ignore
30/// use mistralrs_server_core::util::parse_image_url;
31///
32/// // Load from HTTP URL
33/// let image = parse_image_url("https://example.com/photo.jpg").await?;
34///
35/// // Load from local file path
36/// let image = parse_image_url("./assets/logo.png").await?;
37///
38/// // Load from data URL
39/// let image = parse_image_url("...").await?;
40///
41/// // Load from file URL
42/// let image = parse_image_url("file:///home/user/picture.jpg").await?;
43/// ```
44pub async fn parse_image_url(url_unparsed: &str) -> Result<DynamicImage, anyhow::Error> {
45    let url = if let Ok(url) = url::Url::parse(url_unparsed) {
46        url
47    } else if File::open(url_unparsed).await.is_ok() {
48        url::Url::from_file_path(std::path::absolute(url_unparsed)?)
49            .map_err(|_| anyhow::anyhow!("Could not parse file path: {}", url_unparsed))?
50    } else {
51        url::Url::parse(url_unparsed)
52            .map_err(|_| anyhow::anyhow!("Could not parse as base64 data: {}", url_unparsed))?
53    };
54
55    let bytes = if url.scheme() == "http" || url.scheme() == "https" {
56        // Read from http
57        match reqwest::get(url.clone()).await {
58            Ok(http_resp) => http_resp.bytes().await?.to_vec(),
59            Err(e) => anyhow::bail!(e),
60        }
61    } else if url.scheme() == "file" {
62        let path = url
63            .to_file_path()
64            .map_err(|_| anyhow::anyhow!("Could not parse file path: {}", url))?;
65
66        if let Ok(mut f) = File::open(&path).await {
67            // Read from local file
68            let metadata = fs::metadata(&path).await?;
69            let mut buffer = vec![0; metadata.len() as usize];
70            f.read_exact(&mut buffer).await?;
71            buffer
72        } else {
73            anyhow::bail!("Could not open file at path: {}", url);
74        }
75    } else if url.scheme() == "data" {
76        // Decode with base64
77        let data_url = data_url::DataUrl::process(url.as_str())?;
78        data_url.decode_to_vec()?.0
79    } else {
80        anyhow::bail!("Unsupported URL scheme: {}", url.scheme());
81    };
82
83    Ok(image::load_from_memory(&bytes)?)
84}
85
86/// Parses and loads an audio file from a URL, file path, or data URL.
87pub async fn parse_audio_url(url_unparsed: &str) -> Result<AudioInput, anyhow::Error> {
88    let url = if let Ok(url) = url::Url::parse(url_unparsed) {
89        url
90    } else if File::open(url_unparsed).await.is_ok() {
91        url::Url::from_file_path(std::path::absolute(url_unparsed)?)
92            .map_err(|_| anyhow::anyhow!("Could not parse file path: {}", url_unparsed))?
93    } else {
94        url::Url::parse(url_unparsed)
95            .map_err(|_| anyhow::anyhow!("Could not parse as base64 data: {}", url_unparsed))?
96    };
97
98    let bytes = if url.scheme() == "http" || url.scheme() == "https" {
99        match reqwest::get(url.clone()).await {
100            Ok(http_resp) => http_resp.bytes().await?.to_vec(),
101            Err(e) => anyhow::bail!(e),
102        }
103    } else if url.scheme() == "file" {
104        let path = url
105            .to_file_path()
106            .map_err(|_| anyhow::anyhow!("Could not parse file path: {}", url))?;
107
108        if let Ok(mut f) = File::open(&path).await {
109            let metadata = fs::metadata(&path).await?;
110            let mut buffer = vec![0; metadata.len() as usize];
111            f.read_exact(&mut buffer).await?;
112            buffer
113        } else {
114            anyhow::bail!("Could not open file at path: {}", url);
115        }
116    } else if url.scheme() == "data" {
117        let data_url = data_url::DataUrl::process(url.as_str())?;
118        data_url.decode_to_vec()?.0
119    } else {
120        anyhow::bail!("Unsupported URL scheme: {}", url.scheme());
121    };
122
123    AudioInput::from_bytes(&bytes)
124}
125
126/// Validates that the requested model matches one of the loaded models.
127///
128/// This function checks if the model parameter from an OpenAI API request
129/// matches one of the models that are currently loaded by the server.
130///
131/// The special model name "default" can be used to bypass this validation,
132/// which is useful for clients that require a model parameter but want
133/// to use the default model.
134///
135/// ### Arguments
136///
137/// * `requested_model` - The model name from the API request
138/// * `state` - The MistralRs state containing the loaded models info
139///
140/// ### Returns
141///
142/// Returns `Ok(())` if the model is available or if "default" is specified, otherwise returns an error.
143pub fn validate_model_name(
144    requested_model: &str,
145    state: Arc<MistralRs>,
146) -> Result<(), anyhow::Error> {
147    // Allow "default" as a special case to bypass validation
148    if requested_model == "default" {
149        return Ok(());
150    }
151
152    let available_models = state
153        .list_models()
154        .map_err(|e| anyhow::anyhow!("Failed to get available models: {}", e))?;
155
156    if available_models.is_empty() {
157        anyhow::bail!("No models are currently loaded.");
158    }
159
160    if !available_models.contains(&requested_model.to_string()) {
161        anyhow::bail!(
162            "Requested model '{}' is not available. Available models: {}. Use 'default' to use the default model.",
163            requested_model,
164            available_models.join(", ")
165        );
166    }
167    Ok(())
168}
169
#[cfg(test)]
mod tests {
    use image::GenericImageView;

    use super::*;

    #[tokio::test]
    async fn test_parse_image_url() {
        // from URL (these two cases require network access)
        let url = "https://www.rust-lang.org/logos/rust-logo-32x32.png";
        let image = parse_image_url(url).await.unwrap();
        assert_eq!(image.dimensions(), (32, 32));

        let url = "http://www.rust-lang.org/logos/rust-logo-32x32.png";
        let image = parse_image_url(url).await.unwrap();
        assert_eq!(image.dimensions(), (32, 32));

        // from file path
        let path = "resources/rust-logo-32x32.png";
        let image = parse_image_url(path).await.unwrap();
        assert_eq!(image.dimensions(), (32, 32));

        // from a file:// URL, which must hold an absolute path. Building it with
        // `Url::from_file_path` (rather than `format!("file://{}")`) also produces a
        // valid URL for Windows paths such as `C:\...`.
        let absolute_path = std::path::absolute(path).unwrap();
        let file_url = url::Url::from_file_path(&absolute_path).unwrap();
        let image = parse_image_url(file_url.as_str()).await.unwrap();
        assert_eq!(image.dimensions(), (32, 32));

        // from base64 encoded image (rust-logo-32x32.png)
        // NOTE(review): one line of this payload is 50 characters while the others are
        // 76, so the total length is not a multiple of 4 — confirm it still matches the
        // base64 encoding of resources/rust-logo-32x32.png.
        let png_b64 = "
        iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAYAAABzenr0AAAHhElEQVR4AZXVA5Aky9bA8f/JzKrq
        npleX9u2bdu2bdu2bdv29z1e2zZm7k5PT3dXZeZ56I6J3o03sbG/iDLOSQuT6fptF11OREYDj4uR
        i9UmAAdHa9ZH6QP+n8kg1+26HJMU44rG+4OjL3YqCv+693HOwcHiTJeYY2NUch/PLI3sOdZY82lU
        Xbynp3yzEXMH8CCTINfuujzDEXQlVN9sju8/uFHPTy2KWLVpWsl9ZGQCvY2AF0ulu0RTBRHIi1AV
        iZU0sSd0dWWXZKVsUeAVhiFX7roCwzGDA9rXV6uaqH/YcmnmPEQGg4IYIoLAYRHRABcaQIGuNMVa
        IS98tZnnjOxJK4AwDDlzs0XoNGUmlWDsPr/98ucLIerrPVlCI8KAWMAQAYWXo8rKipyuMDewuaAv
        g6wMgEa6M0dX6ugdqOPQxSs96WqlcukqoEoHuWiHZelki3yF/vHVV0OhdCUJfzZyQlYiiPlR4RxV
        bgKqAbNthDto2Q64U6ACbAicKzCtAON6Uqr1HAk5XYlZEXiNDnLaBgvQxqiSzPdLX70PNT9U/pN9
        0xNdSjT2UoXjJ84+x6ygwMQ/bSdyOnCgamSqSpmBepOY53OliXHAh7TJsesuCMBMU/XM/+dvve/9
        PhgYl2X8Xi8IWZkobAg8xuQjx24L3KEamaY7oX/8IDZ6ukZkCwDvA8gpGy1EG9Vq44fRpXTa3oZv
        BVeIQERQQBFUQQGE4frWj+3hdyxQtei2oHe4UDB1KvxWL34EpqPNLjzdWKYZXVqpr3fgdDV2QSJZ
        A4M3loC0gqu0ggsgrXMQhlEBlgR2Au6OyF+AWby4hbvU4xVtRF2x7OQ7a+QbOWKN+Rjp4lF/NOLZ
        o0sZvw96MIJPM6IYVEFFAFrnTEhF6CSqdHgaWEeEzQXuc9EzlYv8VPdkwtHAOS4P8Fsw52A40Mc4
        rRp5ICKzR2WhCC8hsrgqFaWlXRPfCfJtRIxCVGQWYFoAERCU9rY2AKqXAO/7qHFA7YIi4ccczgFw
        U490G/7WV7/KZdm0/YVHxBwcka2jyEKI7K3KdQorarvqI4aIXAWcRQcv5ixBjgZFVBEUg2KJiyFM
        i+p2EpWB3L+UJHbamPsfMo37uFoj/8RjKqkRmiqAfKcioPKjwoCCjWKIGALtBERmi5glFHGAgswC
        ur4AoEO1YDReAvITaNVIfInEVWOzoJwaBqGSwCeuVucTMebUIsTzBCHCD9G63dH4tCh4m5qIELAE
        ERRDRHbT/24AAhMch5rgqSDm4AIARpS1uZ3k4SqLEsUCnFoX84nLupPLaoN+/8xZ6kUBYr82wT9N
        W7AlghgChnbwdlMICCgCgMLQmaiAcDAdqtJ1R0+ie4VQrNCFosh5TpjJSe6fVGWnqFrx/2Nwe7G0
        233oqNIxL9D5iSIIiCLwenu8VwEy9ad4cTNL9AQMXqlaa/7uxqt7SiQcVCvijQGDUR1Nh4jRdvt3
        BJdjgLPbISsKVwHbgSBA66gVgY2A252GSlQ90eRNECEP4MUd5AO3u5Els2Gtrjd2p5a81iSYZB7v
        ssUKl6UBm0dkRECI0k4Ag8LsiiyhkAHvANsDGwIVBQRoH/cW+CiKORBtt70oYi1i/I2p8LsL8BLW
        3FHLwwZZYkcMBlDM68rQsEMZCk4EFNkN2A0EhTOB44D3gWWYgC4HvB4RsI5gLJlEGj72q5pHrY1f
        mmpuqiD7RB+lK7UYUWIMRCxDHU4mCA5IR/tT0PIcbQoTvBMR1BfEEEjTlIZXKaVm32DcByYYR0Y4
        0ahWvIIRxWiEULSDT9zZBGUChpZrAIZLQsWAS4gKZXzFKCcaBWMUSNypwNq1PL7dlbZe0qgovK7I
        i4q8oPCywl//szG08TrwZscquF773kvACwovgbwOhugjXaljoFl8Z4ysHdFTI4qLKOO9q46o8Jei
        VswWFMoOGj6iToyKfKKwIMgTAmcJyvB48j9bRM4DlqaVwPr4BiUTEAzdJowyxvwlQhXARQSAclc2
        a+H101rD10ZWulbsq4GE5qKOdOoc+5kaORM4i0lQpAIcDnyIcjC+USEGxpSgVq9+4JKskZg4a3v0
        IPussxSdTGNw2jrJD7Zcpmg2iaKr9LrRqMpLWLsE8DrDQ1UXA15HZDppDq5p0Zum6TbUBmq4LJsO
        +JEOro6j0xQjyrMzWNCs1/4Y210a21+El0blvdU+NwZCeFGNuQGRe4APgCotFWA+YtwK1d3QiAb/
        j9EujBmXKL9U69UscRUncfaJE5Dd112G4RT1hmZpQuYM3+UJRYBoE8QYMA40gmrrKIKGCNaSaMHU
        iii001lR7HqnVXphmU3EdCFHLjEI1ojKQWyonQI4pGT72Rv2OE
        r7qtrAaMYBiy1xpLMClSTk/Nm/6EZtAHUms3y1BKzvCHZESEMVqnXkRyHwjIAxbdLEnsqcBJTILz
        xjApU46pnBe8fwF4pb+/8Ywv/DK9zbCKsfWXUBhf+A1dOX00S+xfgc3L3dmKWSn7iklDjthxbSaH
        c7YCVIAfi6JYn5bHjTHTGmurQJXJ8C/um928G9zK4gAAAABJRU5ErkJggg==
        ";

        let url = format!("data:image/png;base64,{png_b64}");
        let image = parse_image_url(&url).await.unwrap();
        assert_eq!(image.dimensions(), (32, 32));
    }

    #[tokio::test]
    async fn test_parse_audio_url() {
        // from base64: a minimal 8 kHz WAV holding a single sample
        let audio_b64 = "UklGRiYAAABXQVZFZm10IBAAAAABAAEAQB8AAIA+AAACABAAZGF0YQIAAAAAAA==";
        let url = format!("data:audio/wav;base64,{audio_b64}");
        let audio = parse_audio_url(&url).await.unwrap();
        assert_eq!(audio.sample_rate, 8000);
        assert_eq!(audio.samples.len(), 1);
    }

    #[tokio::test]
    async fn test_parse_url_rejects_unsupported_inputs() {
        // a scheme that is neither http(s), file, nor data
        let err = parse_image_url("ftp://example.com/image.png")
            .await
            .err()
            .expect("ftp URLs must be rejected");
        assert!(err.to_string().contains("Unsupported URL scheme: ftp"));
        assert!(parse_audio_url("ftp://example.com/audio.wav").await.is_err());

        // neither a URL nor an existing local file
        assert!(parse_image_url("no/such/dir/image.png").await.is_err());
    }
}