mistralrs_server_core/
mistralrs_server_router_builder.rs

1//! ## mistral.rs server router builder.
2
3use anyhow::Result;
4use axum::{
5    extract::DefaultBodyLimit,
6    http::{self, Method},
7    routing::{get, post},
8    Router,
9};
10use tower_http::cors::{AllowOrigin, CorsLayer};
11use utoipa_swagger_ui::SwaggerUi;
12
13use crate::{
14    chat_completion::chatcompletions,
15    completions::completions,
16    embeddings::embeddings,
17    handlers::{health, models, re_isq},
18    image_generation::image_generation,
19    openapi_doc::get_openapi_doc,
20    responses::{create_response, delete_response, get_response},
21    speech_generation::speech_generation,
22    types::SharedMistralRsState,
23};
24
/// Default request body size limit, expressed in megabytes.
// NOTE(EricLBuehler): Accept up to 50mb input
const N_INPUT_SIZE: usize = 50;
/// Number of bytes in one megabyte (2^20 == 1024 * 1024).
const MB_TO_B: usize = 1 << 20;

/// Default axum request body limit for the router, in bytes.
///
/// Equals `N_INPUT_SIZE` megabytes (50 * 1024 * 1024 bytes).
pub const DEFAULT_MAX_BODY_LIMIT: usize = MB_TO_B * N_INPUT_SIZE;
31
32/// A builder for creating a mistral.rs server router with configurable options.
33///
34/// ### Examples
35///
36/// Basic usage:
37/// ```ignore
38/// use mistralrs_server_core::mistralrs_server_router_builder::MistralRsServerRouterBuilder;
39///
40/// let router = MistralRsServerRouterBuilder::new()
41///     .with_mistralrs(mistralrs_instance)
42///     .build()
43///     .await?;
44/// ```
45///
46/// With custom configuration:
47/// ```ignore
48/// use mistralrs_server_core::mistralrs_server_router_builder::MistralRsServerRouterBuilder;
49///
50/// let router = MistralRsServerRouterBuilder::new()
51///     .with_mistralrs(mistralrs_instance)
52///     .with_include_swagger_routes(false)
53///     .with_base_path("/api/mistral")
54///     .build()
55///     .await?;
56/// ```
57pub struct MistralRsServerRouterBuilder {
58    /// The shared mistral.rs instance
59    mistralrs: Option<SharedMistralRsState>,
60    /// Whether to include Swagger/OpenAPI documentation routes
61    include_swagger_routes: bool,
62    /// Optional base path prefix for all routes
63    base_path: Option<String>,
64    /// Optional CORS allowed origins
65    allowed_origins: Option<Vec<String>>,
66    /// Optional axum default request body limit
67    max_body_limit: Option<usize>,
68}
69
70impl Default for MistralRsServerRouterBuilder {
71    /// Creates a new builder with default configuration.
72    fn default() -> Self {
73        Self {
74            mistralrs: None,
75            include_swagger_routes: true,
76            base_path: None,
77            allowed_origins: None,
78            max_body_limit: None,
79        }
80    }
81}
82
83impl MistralRsServerRouterBuilder {
84    /// Creates a new `MistralRsServerRouterBuilder` with default settings.
85    ///
86    /// This is equivalent to calling `Default::default()`.
87    ///
88    /// ### Examples
89    ///
90    /// ```ignore
91    /// use mistralrs_server_core::mistralrs_server_router_builder::MistralRsServerRouterBuilder;
92    ///
93    /// let builder = MistralRsServerRouterBuilder::new();
94    /// ```
95    pub fn new() -> Self {
96        Default::default()
97    }
98
99    /// Sets the shared mistral.rs instance
100    pub fn with_mistralrs(mut self, mistralrs: SharedMistralRsState) -> Self {
101        self.mistralrs = Some(mistralrs);
102        self
103    }
104
105    /// Configures whether to include OpenAPI doc routes.
106    ///
107    /// When enabled (default), the router will include routes for Swagger UI
108    /// at `/docs` and the OpenAPI specification at `/api-doc/openapi.json`.
109    /// These routes respect the configured base path if one is set.
110    pub fn with_include_swagger_routes(mut self, include_swagger_routes: bool) -> Self {
111        self.include_swagger_routes = include_swagger_routes;
112        self
113    }
114
115    /// Sets a base path prefix for all routes.
116    ///
117    /// When set, all routes will be prefixed with the given path. This is
118    /// useful when including the mistral.rs server instance in another axum project.
119    pub fn with_base_path(mut self, base_path: &str) -> Self {
120        self.base_path = Some(base_path.to_owned());
121        self
122    }
123
124    /// Sets the CORS allowed origins.
125    pub fn with_allowed_origins(mut self, origins: Vec<String>) -> Self {
126        self.allowed_origins = Some(origins);
127        self
128    }
129
130    /// Sets the axum default request body limit.
131    pub fn with_max_body_limit(mut self, max_body_limit: usize) -> Self {
132        self.max_body_limit = Some(max_body_limit);
133        self
134    }
135
136    /// Builds the configured axum router.
137    ///
138    /// ### Examples
139    ///
140    /// ```ignore
141    /// use mistralrs_server_core::mistralrs_server_router_builder::MistralRsServerRouterBuilder;
142    ///
143    /// let router = MistralRsServerRouterBuilder::new()
144    ///     .with_mistralrs(mistralrs_instance)
145    ///     .build()
146    ///     .await?;
147    /// ```
148    pub async fn build(self) -> Result<Router> {
149        let mistralrs = self.mistralrs.ok_or_else(|| {
150            anyhow::anyhow!("`mistralrs` instance must be set. Use `with_mistralrs`.")
151        })?;
152
153        let mistralrs_server_router = init_router(
154            mistralrs,
155            self.include_swagger_routes,
156            self.base_path.as_deref(),
157            self.allowed_origins,
158            self.max_body_limit,
159        );
160
161        mistralrs_server_router
162    }
163}
164
165/// Initializes and configures the underlying axum router with MistralRs API endpoints.
166///
167/// This function creates a router with all the necessary API endpoints,
168/// CORS configuration, body size limits, and optional Swagger documentation.
169fn init_router(
170    state: SharedMistralRsState,
171    include_swagger_routes: bool,
172    base_path: Option<&str>,
173    allowed_origins: Option<Vec<String>>,
174    max_body_limit: Option<usize>,
175) -> Result<Router> {
176    let allow_origin = if let Some(origins) = allowed_origins {
177        let parsed_origins: Result<Vec<_>, _> = origins.into_iter().map(|o| o.parse()).collect();
178
179        match parsed_origins {
180            Ok(origins) => AllowOrigin::list(origins),
181            Err(_) => anyhow::bail!("Invalid origin format"),
182        }
183    } else {
184        AllowOrigin::any()
185    };
186
187    let router_max_body_limit = max_body_limit.unwrap_or(DEFAULT_MAX_BODY_LIMIT);
188
189    let cors_layer = CorsLayer::new()
190        .allow_methods([Method::GET, Method::POST])
191        .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
192        .allow_origin(allow_origin);
193
194    // Use the provided base path or default to ""
195    let prefix = base_path.unwrap_or("");
196
197    let mut router = Router::new()
198        .route("/v1/chat/completions", post(chatcompletions))
199        .route("/v1/completions", post(completions))
200        .route("/v1/embeddings", post(embeddings))
201        .route("/v1/models", get(models))
202        .route("/health", get(health))
203        .route("/", get(health))
204        .route("/re_isq", post(re_isq))
205        .route("/v1/images/generations", post(image_generation))
206        .route("/v1/audio/speech", post(speech_generation))
207        .route("/v1/responses", post(create_response))
208        .route(
209            "/v1/responses/{response_id}",
210            get(get_response).delete(delete_response),
211        )
212        .layer(cors_layer)
213        .layer(DefaultBodyLimit::max(router_max_body_limit))
214        .with_state(state);
215
216    if include_swagger_routes {
217        let doc = get_openapi_doc(None);
218
219        router = router.merge(
220            SwaggerUi::new(format!("{prefix}/docs"))
221                .url(format!("{prefix}/api-doc/openapi.json"), doc),
222        );
223    }
224
225    Ok(router)
226}