mistralrs_server_core/
mistralrs_server_router_builder.rs1use anyhow::Result;
4use axum::{
5 extract::DefaultBodyLimit,
6 http::{self, Method},
7 routing::{get, post},
8 Router,
9};
10use tower_http::cors::{AllowOrigin, CorsLayer};
11use utoipa_swagger_ui::SwaggerUi;
12
13use crate::{
14 chat_completion::chatcompletions,
15 completions::completions,
16 handlers::{health, models, re_isq},
17 image_generation::image_generation,
18 openapi_doc::get_openapi_doc,
19 speech_generation::speech_generation,
20 types::SharedMistralRsState,
21};
22
/// Size component of the default request-body limit, expressed in mebibytes.
const N_INPUT_SIZE: usize = 50;

/// Conversion factor: number of bytes in one mebibyte.
const MB_TO_B: usize = 1024 * 1024;

/// Request-body limit, in bytes, applied when the builder is not given an
/// explicit one (`N_INPUT_SIZE` MiB, i.e. 50 MiB).
pub const DEFAULT_MAX_BODY_LIMIT: usize = N_INPUT_SIZE * MB_TO_B;
29
30pub struct MistralRsServerRouterBuilder {
56 mistralrs: Option<SharedMistralRsState>,
58 include_swagger_routes: bool,
60 base_path: Option<String>,
62 allowed_origins: Option<Vec<String>>,
64 max_body_limit: Option<usize>,
66}
67
68impl Default for MistralRsServerRouterBuilder {
69 fn default() -> Self {
71 Self {
72 mistralrs: None,
73 include_swagger_routes: true,
74 base_path: None,
75 allowed_origins: None,
76 max_body_limit: None,
77 }
78 }
79}
80
81impl MistralRsServerRouterBuilder {
82 pub fn new() -> Self {
94 Default::default()
95 }
96
97 pub fn with_mistralrs(mut self, mistralrs: SharedMistralRsState) -> Self {
99 self.mistralrs = Some(mistralrs);
100 self
101 }
102
103 pub fn with_include_swagger_routes(mut self, include_swagger_routes: bool) -> Self {
109 self.include_swagger_routes = include_swagger_routes;
110 self
111 }
112
113 pub fn with_base_path(mut self, base_path: &str) -> Self {
118 self.base_path = Some(base_path.to_owned());
119 self
120 }
121
122 pub fn with_allowed_origins(mut self, origins: Vec<String>) -> Self {
124 self.allowed_origins = Some(origins);
125 self
126 }
127
128 pub fn with_max_body_limit(mut self, max_body_limit: usize) -> Self {
130 self.max_body_limit = Some(max_body_limit);
131 self
132 }
133
134 pub async fn build(self) -> Result<Router> {
147 let mistralrs = self.mistralrs.ok_or_else(|| {
148 anyhow::anyhow!("`mistralrs` instance must be set. Use `with_mistralrs`.")
149 })?;
150
151 let mistralrs_server_router = init_router(
152 mistralrs,
153 self.include_swagger_routes,
154 self.base_path.as_deref(),
155 self.allowed_origins,
156 self.max_body_limit,
157 );
158
159 mistralrs_server_router
160 }
161}
162
163fn init_router(
168 state: SharedMistralRsState,
169 include_swagger_routes: bool,
170 base_path: Option<&str>,
171 allowed_origins: Option<Vec<String>>,
172 max_body_limit: Option<usize>,
173) -> Result<Router> {
174 let allow_origin = if let Some(origins) = allowed_origins {
175 let parsed_origins: Result<Vec<_>, _> = origins.into_iter().map(|o| o.parse()).collect();
176
177 match parsed_origins {
178 Ok(origins) => AllowOrigin::list(origins),
179 Err(_) => anyhow::bail!("Invalid origin format"),
180 }
181 } else {
182 AllowOrigin::any()
183 };
184
185 let router_max_body_limit = max_body_limit.unwrap_or(DEFAULT_MAX_BODY_LIMIT);
186
187 let cors_layer = CorsLayer::new()
188 .allow_methods([Method::GET, Method::POST])
189 .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
190 .allow_origin(allow_origin);
191
192 let prefix = base_path.unwrap_or("");
194
195 let mut router = Router::new()
196 .route("/v1/chat/completions", post(chatcompletions))
197 .route("/v1/completions", post(completions))
198 .route("/v1/models", get(models))
199 .route("/health", get(health))
200 .route("/", get(health))
201 .route("/re_isq", post(re_isq))
202 .route("/v1/images/generations", post(image_generation))
203 .route("/v1/audio/speech", post(speech_generation))
204 .layer(cors_layer)
205 .layer(DefaultBodyLimit::max(router_max_body_limit))
206 .with_state(state);
207
208 if include_swagger_routes {
209 let doc = get_openapi_doc(None);
210
211 router = router.merge(
212 SwaggerUi::new(format!("{prefix}/docs"))
213 .url(format!("{prefix}/api-doc/openapi.json"), doc),
214 );
215 }
216
217 Ok(router)
218}