mistralrs_server_core/
mistralrs_server_router_builder.rs

1//! ## mistral.rs server router builder.
2
3use anyhow::Result;
4use axum::{
5    extract::DefaultBodyLimit,
6    http::{self, Method},
7    routing::{get, post},
8    Router,
9};
10use tower_http::cors::{AllowOrigin, CorsLayer};
11use utoipa_swagger_ui::SwaggerUi;
12
13use crate::{
14    chat_completion::chatcompletions,
15    completions::completions,
16    handlers::{health, models, re_isq},
17    image_generation::image_generation,
18    openapi_doc::get_openapi_doc,
19    responses::{create_response, delete_response, get_response},
20    speech_generation::speech_generation,
21    types::SharedMistralRsState,
22};
23
// NOTE(EricLBuehler): Accept up to 50mb input
const N_INPUT_SIZE: usize = 50;
// Bytes per mebibyte (2^20 = 1024 * 1024).
const MB_TO_B: usize = 1 << 20;

/// Request body limit (in bytes) the router uses when none is configured: 50 MiB.
pub const DEFAULT_MAX_BODY_LIMIT: usize = MB_TO_B * N_INPUT_SIZE;
30
31/// A builder for creating a mistral.rs server router with configurable options.
32///
33/// ### Examples
34///
35/// Basic usage:
36/// ```ignore
37/// use mistralrs_server_core::mistralrs_server_router_builder::MistralRsServerRouterBuilder;
38///
39/// let router = MistralRsServerRouterBuilder::new()
40///     .with_mistralrs(mistralrs_instance)
41///     .build()
42///     .await?;
43/// ```
44///
45/// With custom configuration:
46/// ```ignore
47/// use mistralrs_server_core::mistralrs_server_router_builder::MistralRsServerRouterBuilder;
48///
49/// let router = MistralRsServerRouterBuilder::new()
50///     .with_mistralrs(mistralrs_instance)
51///     .with_include_swagger_routes(false)
52///     .with_base_path("/api/mistral")
53///     .build()
54///     .await?;
55/// ```
56pub struct MistralRsServerRouterBuilder {
57    /// The shared mistral.rs instance
58    mistralrs: Option<SharedMistralRsState>,
59    /// Whether to include Swagger/OpenAPI documentation routes
60    include_swagger_routes: bool,
61    /// Optional base path prefix for all routes
62    base_path: Option<String>,
63    /// Optional CORS allowed origins
64    allowed_origins: Option<Vec<String>>,
65    /// Optional axum default request body limit
66    max_body_limit: Option<usize>,
67}
68
69impl Default for MistralRsServerRouterBuilder {
70    /// Creates a new builder with default configuration.
71    fn default() -> Self {
72        Self {
73            mistralrs: None,
74            include_swagger_routes: true,
75            base_path: None,
76            allowed_origins: None,
77            max_body_limit: None,
78        }
79    }
80}
81
82impl MistralRsServerRouterBuilder {
83    /// Creates a new `MistralRsServerRouterBuilder` with default settings.
84    ///
85    /// This is equivalent to calling `Default::default()`.
86    ///
87    /// ### Examples
88    ///
89    /// ```ignore
90    /// use mistralrs_server_core::mistralrs_server_router_builder::MistralRsServerRouterBuilder;
91    ///
92    /// let builder = MistralRsServerRouterBuilder::new();
93    /// ```
94    pub fn new() -> Self {
95        Default::default()
96    }
97
98    /// Sets the shared mistral.rs instance
99    pub fn with_mistralrs(mut self, mistralrs: SharedMistralRsState) -> Self {
100        self.mistralrs = Some(mistralrs);
101        self
102    }
103
104    /// Configures whether to include OpenAPI doc routes.
105    ///
106    /// When enabled (default), the router will include routes for Swagger UI
107    /// at `/docs` and the OpenAPI specification at `/api-doc/openapi.json`.
108    /// These routes respect the configured base path if one is set.
109    pub fn with_include_swagger_routes(mut self, include_swagger_routes: bool) -> Self {
110        self.include_swagger_routes = include_swagger_routes;
111        self
112    }
113
114    /// Sets a base path prefix for all routes.
115    ///
116    /// When set, all routes will be prefixed with the given path. This is
117    /// useful when including the mistral.rs server instance in another axum project.
118    pub fn with_base_path(mut self, base_path: &str) -> Self {
119        self.base_path = Some(base_path.to_owned());
120        self
121    }
122
123    /// Sets the CORS allowed origins.
124    pub fn with_allowed_origins(mut self, origins: Vec<String>) -> Self {
125        self.allowed_origins = Some(origins);
126        self
127    }
128
129    /// Sets the axum default request body limit.
130    pub fn with_max_body_limit(mut self, max_body_limit: usize) -> Self {
131        self.max_body_limit = Some(max_body_limit);
132        self
133    }
134
135    /// Builds the configured axum router.
136    ///
137    /// ### Examples
138    ///
139    /// ```ignore
140    /// use mistralrs_server_core::mistralrs_server_router_builder::MistralRsServerRouterBuilder;
141    ///
142    /// let router = MistralRsServerRouterBuilder::new()
143    ///     .with_mistralrs(mistralrs_instance)
144    ///     .build()
145    ///     .await?;
146    /// ```
147    pub async fn build(self) -> Result<Router> {
148        let mistralrs = self.mistralrs.ok_or_else(|| {
149            anyhow::anyhow!("`mistralrs` instance must be set. Use `with_mistralrs`.")
150        })?;
151
152        let mistralrs_server_router = init_router(
153            mistralrs,
154            self.include_swagger_routes,
155            self.base_path.as_deref(),
156            self.allowed_origins,
157            self.max_body_limit,
158        );
159
160        mistralrs_server_router
161    }
162}
163
164/// Initializes and configures the underlying axum router with MistralRs API endpoints.
165///
166/// This function creates a router with all the necessary API endpoints,
167/// CORS configuration, body size limits, and optional Swagger documentation.
168fn init_router(
169    state: SharedMistralRsState,
170    include_swagger_routes: bool,
171    base_path: Option<&str>,
172    allowed_origins: Option<Vec<String>>,
173    max_body_limit: Option<usize>,
174) -> Result<Router> {
175    let allow_origin = if let Some(origins) = allowed_origins {
176        let parsed_origins: Result<Vec<_>, _> = origins.into_iter().map(|o| o.parse()).collect();
177
178        match parsed_origins {
179            Ok(origins) => AllowOrigin::list(origins),
180            Err(_) => anyhow::bail!("Invalid origin format"),
181        }
182    } else {
183        AllowOrigin::any()
184    };
185
186    let router_max_body_limit = max_body_limit.unwrap_or(DEFAULT_MAX_BODY_LIMIT);
187
188    let cors_layer = CorsLayer::new()
189        .allow_methods([Method::GET, Method::POST])
190        .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
191        .allow_origin(allow_origin);
192
193    // Use the provided base path or default to ""
194    let prefix = base_path.unwrap_or("");
195
196    let mut router = Router::new()
197        .route("/v1/chat/completions", post(chatcompletions))
198        .route("/v1/completions", post(completions))
199        .route("/v1/models", get(models))
200        .route("/health", get(health))
201        .route("/", get(health))
202        .route("/re_isq", post(re_isq))
203        .route("/v1/images/generations", post(image_generation))
204        .route("/v1/audio/speech", post(speech_generation))
205        .route("/v1/responses", post(create_response))
206        .route(
207            "/v1/responses/{response_id}",
208            get(get_response).delete(delete_response),
209        )
210        .layer(cors_layer)
211        .layer(DefaultBodyLimit::max(router_max_body_limit))
212        .with_state(state);
213
214    if include_swagger_routes {
215        let doc = get_openapi_doc(None);
216
217        router = router.merge(
218            SwaggerUi::new(format!("{prefix}/docs"))
219                .url(format!("{prefix}/api-doc/openapi.json"), doc),
220        );
221    }
222
223    Ok(router)
224}