mistralrs_server_core/
mistralrs_server_router_builder.rs1use anyhow::Result;
4use axum::{
5 extract::DefaultBodyLimit,
6 http::{self, Method},
7 routing::{get, post},
8 Router,
9};
10use tower_http::cors::{AllowOrigin, CorsLayer};
11use utoipa_swagger_ui::SwaggerUi;
12
13use crate::{
14 chat_completion::chatcompletions,
15 completions::completions,
16 embeddings::embeddings,
17 handlers::{health, models, re_isq},
18 image_generation::image_generation,
19 openapi_doc::get_openapi_doc,
20 responses::{create_response, delete_response, get_response},
21 speech_generation::speech_generation,
22 types::SharedMistralRsState,
23};
24
/// Default request body limit, expressed in mebibytes.
const N_INPUT_SIZE: usize = 50;

/// Number of bytes in one mebibyte (2^20).
const MB_TO_B: usize = 1 << 20;

/// Default maximum request body size in bytes (50 MiB), used by the router when no explicit
/// limit is configured on the builder.
pub const DEFAULT_MAX_BODY_LIMIT: usize = MB_TO_B * N_INPUT_SIZE;
31
32pub struct MistralRsServerRouterBuilder {
58 mistralrs: Option<SharedMistralRsState>,
60 include_swagger_routes: bool,
62 base_path: Option<String>,
64 allowed_origins: Option<Vec<String>>,
66 max_body_limit: Option<usize>,
68}
69
70impl Default for MistralRsServerRouterBuilder {
71 fn default() -> Self {
73 Self {
74 mistralrs: None,
75 include_swagger_routes: true,
76 base_path: None,
77 allowed_origins: None,
78 max_body_limit: None,
79 }
80 }
81}
82
83impl MistralRsServerRouterBuilder {
84 pub fn new() -> Self {
96 Default::default()
97 }
98
99 pub fn with_mistralrs(mut self, mistralrs: SharedMistralRsState) -> Self {
101 self.mistralrs = Some(mistralrs);
102 self
103 }
104
105 pub fn with_include_swagger_routes(mut self, include_swagger_routes: bool) -> Self {
111 self.include_swagger_routes = include_swagger_routes;
112 self
113 }
114
115 pub fn with_base_path(mut self, base_path: &str) -> Self {
120 self.base_path = Some(base_path.to_owned());
121 self
122 }
123
124 pub fn with_allowed_origins(mut self, origins: Vec<String>) -> Self {
126 self.allowed_origins = Some(origins);
127 self
128 }
129
130 pub fn with_max_body_limit(mut self, max_body_limit: usize) -> Self {
132 self.max_body_limit = Some(max_body_limit);
133 self
134 }
135
136 pub async fn build(self) -> Result<Router> {
149 let mistralrs = self.mistralrs.ok_or_else(|| {
150 anyhow::anyhow!("`mistralrs` instance must be set. Use `with_mistralrs`.")
151 })?;
152
153 let mistralrs_server_router = init_router(
154 mistralrs,
155 self.include_swagger_routes,
156 self.base_path.as_deref(),
157 self.allowed_origins,
158 self.max_body_limit,
159 );
160
161 mistralrs_server_router
162 }
163}
164
165fn init_router(
170 state: SharedMistralRsState,
171 include_swagger_routes: bool,
172 base_path: Option<&str>,
173 allowed_origins: Option<Vec<String>>,
174 max_body_limit: Option<usize>,
175) -> Result<Router> {
176 let allow_origin = if let Some(origins) = allowed_origins {
177 let parsed_origins: Result<Vec<_>, _> = origins.into_iter().map(|o| o.parse()).collect();
178
179 match parsed_origins {
180 Ok(origins) => AllowOrigin::list(origins),
181 Err(_) => anyhow::bail!("Invalid origin format"),
182 }
183 } else {
184 AllowOrigin::any()
185 };
186
187 let router_max_body_limit = max_body_limit.unwrap_or(DEFAULT_MAX_BODY_LIMIT);
188
189 let cors_layer = CorsLayer::new()
190 .allow_methods([Method::GET, Method::POST])
191 .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
192 .allow_origin(allow_origin);
193
194 let prefix = base_path.unwrap_or("");
196
197 let mut router = Router::new()
198 .route("/v1/chat/completions", post(chatcompletions))
199 .route("/v1/completions", post(completions))
200 .route("/v1/embeddings", post(embeddings))
201 .route("/v1/models", get(models))
202 .route("/health", get(health))
203 .route("/", get(health))
204 .route("/re_isq", post(re_isq))
205 .route("/v1/images/generations", post(image_generation))
206 .route("/v1/audio/speech", post(speech_generation))
207 .route("/v1/responses", post(create_response))
208 .route(
209 "/v1/responses/{response_id}",
210 get(get_response).delete(delete_response),
211 )
212 .layer(cors_layer)
213 .layer(DefaultBodyLimit::max(router_max_body_limit))
214 .with_state(state);
215
216 if include_swagger_routes {
217 let doc = get_openapi_doc(None);
218
219 router = router.merge(
220 SwaggerUi::new(format!("{prefix}/docs"))
221 .url(format!("{prefix}/api-doc/openapi.json"), doc),
222 );
223 }
224
225 Ok(router)
226}