mistralrs_server_core/
mistralrs_server_router_builder.rs

1//! ## mistral.rs server router builder.
2
3use anyhow::Result;
4use axum::{
5    extract::DefaultBodyLimit,
6    http::{self, Method},
7    routing::{get, post},
8    Router,
9};
10use tower_http::cors::{AllowOrigin, CorsLayer};
11use utoipa_swagger_ui::SwaggerUi;
12
13use crate::{
14    chat_completion::chatcompletions,
15    completions::completions,
16    handlers::{health, models, re_isq},
17    image_generation::image_generation,
18    openapi_doc::get_openapi_doc,
19    speech_generation::speech_generation,
20    types::SharedMistralRsState,
21};
22
// NOTE(EricLBuehler): Accept up to 50mb input
/// Default request body size limit, in mebibytes.
const N_INPUT_SIZE: usize = 50;
/// Number of bytes in one mebibyte (1024 * 1024 bytes).
const MB_TO_B: usize = 1024 * 1024;

/// Default axum request body limit for the router, in bytes: 50 MiB (52,428,800 bytes).
///
/// Override it with [`MistralRsServerRouterBuilder::with_max_body_limit`].
pub const DEFAULT_MAX_BODY_LIMIT: usize = N_INPUT_SIZE * MB_TO_B;
29
30/// A builder for creating a mistral.rs server router with configurable options.
31///
32/// ### Examples
33///
34/// Basic usage:
35/// ```ignore
36/// use mistralrs_server_core::mistralrs_server_router_builder::MistralRsServerRouterBuilder;
37///
38/// let router = MistralRsServerRouterBuilder::new()
39///     .with_mistralrs(mistralrs_instance)
40///     .build()
41///     .await?;
42/// ```
43///
44/// With custom configuration:
45/// ```ignore
46/// use mistralrs_server_core::mistralrs_server_router_builder::MistralRsServerRouterBuilder;
47///
48/// let router = MistralRsServerRouterBuilder::new()
49///     .with_mistralrs(mistralrs_instance)
50///     .with_include_swagger_routes(false)
51///     .with_base_path("/api/mistral")
52///     .build()
53///     .await?;
54/// ```
55pub struct MistralRsServerRouterBuilder {
56    /// The shared mistral.rs instance
57    mistralrs: Option<SharedMistralRsState>,
58    /// Whether to include Swagger/OpenAPI documentation routes
59    include_swagger_routes: bool,
60    /// Optional base path prefix for all routes
61    base_path: Option<String>,
62    /// Optional CORS allowed origins
63    allowed_origins: Option<Vec<String>>,
64    /// Optional axum default request body limit
65    max_body_limit: Option<usize>,
66}
67
68impl Default for MistralRsServerRouterBuilder {
69    /// Creates a new builder with default configuration.
70    fn default() -> Self {
71        Self {
72            mistralrs: None,
73            include_swagger_routes: true,
74            base_path: None,
75            allowed_origins: None,
76            max_body_limit: None,
77        }
78    }
79}
80
81impl MistralRsServerRouterBuilder {
82    /// Creates a new `MistralRsServerRouterBuilder` with default settings.
83    ///
84    /// This is equivalent to calling `Default::default()`.
85    ///
86    /// ### Examples
87    ///
88    /// ```ignore
89    /// use mistralrs_server_core::mistralrs_server_router_builder::MistralRsServerRouterBuilder;
90    ///
91    /// let builder = MistralRsServerRouterBuilder::new();
92    /// ```
93    pub fn new() -> Self {
94        Default::default()
95    }
96
97    /// Sets the shared mistral.rs instance
98    pub fn with_mistralrs(mut self, mistralrs: SharedMistralRsState) -> Self {
99        self.mistralrs = Some(mistralrs);
100        self
101    }
102
103    /// Configures whether to include OpenAPI doc routes.
104    ///
105    /// When enabled (default), the router will include routes for Swagger UI
106    /// at `/docs` and the OpenAPI specification at `/api-doc/openapi.json`.
107    /// These routes respect the configured base path if one is set.
108    pub fn with_include_swagger_routes(mut self, include_swagger_routes: bool) -> Self {
109        self.include_swagger_routes = include_swagger_routes;
110        self
111    }
112
113    /// Sets a base path prefix for all routes.
114    ///
115    /// When set, all routes will be prefixed with the given path. This is
116    /// useful when including the mistral.rs server instance in another axum project.
117    pub fn with_base_path(mut self, base_path: &str) -> Self {
118        self.base_path = Some(base_path.to_owned());
119        self
120    }
121
122    /// Sets the CORS allowed origins.
123    pub fn with_allowed_origins(mut self, origins: Vec<String>) -> Self {
124        self.allowed_origins = Some(origins);
125        self
126    }
127
128    /// Sets the axum default request body limit.
129    pub fn with_max_body_limit(mut self, max_body_limit: usize) -> Self {
130        self.max_body_limit = Some(max_body_limit);
131        self
132    }
133
134    /// Builds the configured axum router.
135    ///
136    /// ### Examples
137    ///
138    /// ```ignore
139    /// use mistralrs_server_core::mistralrs_server_router_builder::MistralRsServerRouterBuilder;
140    ///
141    /// let router = MistralRsServerRouterBuilder::new()
142    ///     .with_mistralrs(mistralrs_instance)
143    ///     .build()
144    ///     .await?;
145    /// ```
146    pub async fn build(self) -> Result<Router> {
147        let mistralrs = self.mistralrs.ok_or_else(|| {
148            anyhow::anyhow!("`mistralrs` instance must be set. Use `with_mistralrs`.")
149        })?;
150
151        let mistralrs_server_router = init_router(
152            mistralrs,
153            self.include_swagger_routes,
154            self.base_path.as_deref(),
155            self.allowed_origins,
156            self.max_body_limit,
157        );
158
159        mistralrs_server_router
160    }
161}
162
163/// Initializes and configures the underlying axum router with MistralRs API endpoints.
164///
165/// This function creates a router with all the necessary API endpoints,
166/// CORS configuration, body size limits, and optional Swagger documentation.
167fn init_router(
168    state: SharedMistralRsState,
169    include_swagger_routes: bool,
170    base_path: Option<&str>,
171    allowed_origins: Option<Vec<String>>,
172    max_body_limit: Option<usize>,
173) -> Result<Router> {
174    let allow_origin = if let Some(origins) = allowed_origins {
175        let parsed_origins: Result<Vec<_>, _> = origins.into_iter().map(|o| o.parse()).collect();
176
177        match parsed_origins {
178            Ok(origins) => AllowOrigin::list(origins),
179            Err(_) => anyhow::bail!("Invalid origin format"),
180        }
181    } else {
182        AllowOrigin::any()
183    };
184
185    let router_max_body_limit = max_body_limit.unwrap_or(DEFAULT_MAX_BODY_LIMIT);
186
187    let cors_layer = CorsLayer::new()
188        .allow_methods([Method::GET, Method::POST])
189        .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
190        .allow_origin(allow_origin);
191
192    // Use the provided base path or default to ""
193    let prefix = base_path.unwrap_or("");
194
195    let mut router = Router::new()
196        .route("/v1/chat/completions", post(chatcompletions))
197        .route("/v1/completions", post(completions))
198        .route("/v1/models", get(models))
199        .route("/health", get(health))
200        .route("/", get(health))
201        .route("/re_isq", post(re_isq))
202        .route("/v1/images/generations", post(image_generation))
203        .route("/v1/audio/speech", post(speech_generation))
204        .layer(cors_layer)
205        .layer(DefaultBodyLimit::max(router_max_body_limit))
206        .with_state(state);
207
208    if include_swagger_routes {
209        let doc = get_openapi_doc(None);
210
211        router = router.merge(
212            SwaggerUi::new(format!("{prefix}/docs"))
213                .url(format!("{prefix}/api-doc/openapi.json"), doc),
214        );
215    }
216
217    Ok(router)
218}