mistralrs_server_core/
mistralrs_server_router_builder.rs1use anyhow::Result;
4use axum::{
5 extract::DefaultBodyLimit,
6 http::{self, Method},
7 routing::{get, post},
8 Router,
9};
10use tower_http::cors::{AllowOrigin, CorsLayer};
11use utoipa_swagger_ui::SwaggerUi;
12
13use crate::{
14 chat_completion::chatcompletions,
15 completions::completions,
16 handlers::{health, models, re_isq},
17 image_generation::image_generation,
18 openapi_doc::get_openapi_doc,
19 responses::{create_response, delete_response, get_response},
20 speech_generation::speech_generation,
21 types::SharedMistralRsState,
22};
23
/// Default request-body limit expressed in mebibytes.
const N_INPUT_SIZE: usize = 50;
/// Number of bytes in one mebibyte (2^20).
const MB_TO_B: usize = 1 << 20;
/// Default maximum request body size in bytes (50 MiB).
///
/// The router builder uses this value when no explicit limit is supplied via
/// `with_max_body_limit`.
pub const DEFAULT_MAX_BODY_LIMIT: usize = MB_TO_B * N_INPUT_SIZE;
30
31pub struct MistralRsServerRouterBuilder {
57 mistralrs: Option<SharedMistralRsState>,
59 include_swagger_routes: bool,
61 base_path: Option<String>,
63 allowed_origins: Option<Vec<String>>,
65 max_body_limit: Option<usize>,
67}
68
69impl Default for MistralRsServerRouterBuilder {
70 fn default() -> Self {
72 Self {
73 mistralrs: None,
74 include_swagger_routes: true,
75 base_path: None,
76 allowed_origins: None,
77 max_body_limit: None,
78 }
79 }
80}
81
82impl MistralRsServerRouterBuilder {
83 pub fn new() -> Self {
95 Default::default()
96 }
97
98 pub fn with_mistralrs(mut self, mistralrs: SharedMistralRsState) -> Self {
100 self.mistralrs = Some(mistralrs);
101 self
102 }
103
104 pub fn with_include_swagger_routes(mut self, include_swagger_routes: bool) -> Self {
110 self.include_swagger_routes = include_swagger_routes;
111 self
112 }
113
114 pub fn with_base_path(mut self, base_path: &str) -> Self {
119 self.base_path = Some(base_path.to_owned());
120 self
121 }
122
123 pub fn with_allowed_origins(mut self, origins: Vec<String>) -> Self {
125 self.allowed_origins = Some(origins);
126 self
127 }
128
129 pub fn with_max_body_limit(mut self, max_body_limit: usize) -> Self {
131 self.max_body_limit = Some(max_body_limit);
132 self
133 }
134
135 pub async fn build(self) -> Result<Router> {
148 let mistralrs = self.mistralrs.ok_or_else(|| {
149 anyhow::anyhow!("`mistralrs` instance must be set. Use `with_mistralrs`.")
150 })?;
151
152 let mistralrs_server_router = init_router(
153 mistralrs,
154 self.include_swagger_routes,
155 self.base_path.as_deref(),
156 self.allowed_origins,
157 self.max_body_limit,
158 );
159
160 mistralrs_server_router
161 }
162}
163
164fn init_router(
169 state: SharedMistralRsState,
170 include_swagger_routes: bool,
171 base_path: Option<&str>,
172 allowed_origins: Option<Vec<String>>,
173 max_body_limit: Option<usize>,
174) -> Result<Router> {
175 let allow_origin = if let Some(origins) = allowed_origins {
176 let parsed_origins: Result<Vec<_>, _> = origins.into_iter().map(|o| o.parse()).collect();
177
178 match parsed_origins {
179 Ok(origins) => AllowOrigin::list(origins),
180 Err(_) => anyhow::bail!("Invalid origin format"),
181 }
182 } else {
183 AllowOrigin::any()
184 };
185
186 let router_max_body_limit = max_body_limit.unwrap_or(DEFAULT_MAX_BODY_LIMIT);
187
188 let cors_layer = CorsLayer::new()
189 .allow_methods([Method::GET, Method::POST])
190 .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
191 .allow_origin(allow_origin);
192
193 let prefix = base_path.unwrap_or("");
195
196 let mut router = Router::new()
197 .route("/v1/chat/completions", post(chatcompletions))
198 .route("/v1/completions", post(completions))
199 .route("/v1/models", get(models))
200 .route("/health", get(health))
201 .route("/", get(health))
202 .route("/re_isq", post(re_isq))
203 .route("/v1/images/generations", post(image_generation))
204 .route("/v1/audio/speech", post(speech_generation))
205 .route("/v1/responses", post(create_response))
206 .route(
207 "/v1/responses/{response_id}",
208 get(get_response).delete(delete_response),
209 )
210 .layer(cors_layer)
211 .layer(DefaultBodyLimit::max(router_max_body_limit))
212 .with_state(state);
213
214 if include_swagger_routes {
215 let doc = get_openapi_doc(None);
216
217 router = router.merge(
218 SwaggerUi::new(format!("{prefix}/docs"))
219 .url(format!("{prefix}/api-doc/openapi.json"), doc),
220 );
221 }
222
223 Ok(router)
224}