use anyhow::Result;
use axum::{
    extract::{DefaultBodyLimit, Json, State},
    http::{self, Method},
    routing::{get, post},
    Router,
};
use candle_core::Device;
use clap::Parser;
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, initialize_logging,
    paged_attn_supported, parse_isq_value, BertEmbeddingModel, DefaultSchedulerMethod,
    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, IsqType, Loader, LoaderBuilder,
    MemoryGpuConfig, MistralRs, MistralRsBuilder, ModelSelected, PagedAttentionConfig, Request,
    SchedulerConfig, TokenSource,
};
use openai::{
    ChatCompletionRequest, CompletionRequest, ImageGenerationRequest, Message, ModelObjects,
    StopTokens,
};
use serde::{Deserialize, Serialize};
use std::{num::NonZeroUsize, sync::Arc};

mod chat_completion;
mod completions;
mod image_generation;
mod interactive_mode;
mod openai;
mod util;

use crate::openai::ModelObject;
use crate::{
    chat_completion::{__path_chatcompletions, chatcompletions},
    completions::completions,
    image_generation::image_generation,
};

use interactive_mode::interactive_mode;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info, warn};
use utoipa::{OpenApi, ToSchema};
use utoipa_swagger_ui::SwaggerUi;
// Maximum request body size in MB (enforced via the `DefaultBodyLimit` layer in `get_router`).
const N_INPUT_SIZE: usize = 50;
const MB_TO_B: usize = 1024 * 1024;

fn parse_token_source(s: &str) -> Result<TokenSource, String> {
    s.parse()
}

#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
    /// IP to serve on. Defaults to "0.0.0.0".
    #[arg(long)]
    serve_ip: Option<String>,

    /// Integer seed for reproducible random number generation.
    #[arg(short, long)]
    seed: Option<u64>,

    /// Port to serve on. Required unless running in interactive mode.
    #[arg(short, long)]
    port: Option<String>,

    /// Log all responses and requests to this file.
    #[clap(long, short)]
    log: Option<String>,

    /// If a sequence is larger than the maximum model length, truncate it to fit.
    #[clap(long, short, action)]
    truncate_sequence: bool,

    /// Model selector.
    #[clap(subcommand)]
    model: ModelSelected,

    /// Maximum number of concurrently running sequences (forced to 1 for models
    /// with a target non-granular index).
    #[arg(long, default_value_t = 16)]
    max_seqs: usize,

    /// Use no KV cache.
    #[arg(long, default_value_t = false)]
    no_kv_cache: bool,

    /// JINJA chat template file, used if automatic template detection fails.
    #[arg(short, long)]
    chat_template: Option<String>,

    /// Explicit JINJA chat template file.
    #[arg(short, long)]
    jinja_explicit: Option<String>,

    /// Source of the Hugging Face token. Defaults to the cached token.
    #[arg(long, default_value_t = TokenSource::CacheToken, value_parser = parse_token_source)]
    token_source: TokenSource,

    /// Enter interactive mode instead of serving a chat server.
    #[clap(long, short, action)]
    interactive_mode: bool,

    /// Number of prefix caches to hold on the device.
    #[arg(long, default_value_t = 16)]
    prefix_cache_n: usize,

    /// Number of model layers to run on the device(s): either a single count for
    /// ordinal 0, or `ORD:NUM` entries separated by ';'.
    #[arg(short, long, value_parser, value_delimiter = ';')]
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization (ISQ) to apply.
    #[arg(long = "isq", value_parser = parse_isq_value)]
    in_situ_quant: Option<IsqType>,

    /// GPU memory to allocate for the PagedAttention KV cache, in MB.
    #[arg(long = "pa-gpu-mem")]
    paged_attn_gpu_mem: Option<usize>,

    /// Fraction of GPU memory (0 to 1) to use for the PagedAttention KV cache.
    #[arg(long = "pa-gpu-mem-usage")]
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length, in tokens, to size the PagedAttention KV cache for.
    #[arg(long = "pa-ctxt-len")]
    paged_ctxt_len: Option<usize>,

    /// PagedAttention block size (tokens per block).
    #[arg(long = "pa-blk-size")]
    paged_attn_block_size: Option<usize>,

    /// Disable PagedAttention on CUDA, where it is otherwise on by default.
    #[arg(long = "no-paged-attn", default_value_t = false)]
    no_paged_attn: bool,

    /// Enable PagedAttention on Metal, where it is otherwise off by default.
    #[arg(long = "paged-attn", default_value_t = false)]
    paged_attn: bool,

    /// Number of tokens to batch the prompt step into; must be strictly positive.
    #[arg(long = "prompt-batchsize")]
    prompt_chunksize: Option<usize>,

    /// Run on the CPU only (implies `--no-paged-attn`).
    #[arg(long)]
    cpu: bool,

    /// Enable web search in interactive mode.
    #[arg(long = "interactive-search")]
    interactive_search: bool,

    /// Enable web search for the server, backed by a BERT embedding model.
    #[arg(long = "enable-search")]
    enable_search: bool,

    /// Hugging Face model ID of a custom BERT model to assist web searching;
    /// a default model is used if unset.
    #[arg(long = "search-bert-model")]
    search_bert_model: Option<String>,
}
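
// Illustrative invocations, assuming the binary name `mistralrs-server` and the
// `plain` model subcommand from `ModelSelected`; the model ID is a placeholder:
//
//   mistralrs-server --port 1234 plain -m <hf-model-id>
//   mistralrs-server -i --isq Q4K plain -m <hf-model-id>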

#[utoipa::path(
    get,
    tag = "Mistral.rs",
    path = "/v1/models",
    responses((status = 200, description = "Served model info", body = ModelObjects))
)]
async fn models(State(state): State<Arc<MistralRs>>) -> Json<ModelObjects> {
    Json(ModelObjects {
        object: "list",
        data: vec![ModelObject {
            id: state.get_id(),
            object: "model",
            created: state.get_creation_time(),
            owned_by: "local",
        }],
    })
}
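
// Illustrative response shape for GET /v1/models (field values are placeholders):
//   {"object":"list","data":[{"id":"<model-id>","object":"model","created":<timestamp>,"owned_by":"local"}]}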

#[utoipa::path(
    get,
    tag = "Mistral.rs",
    path = "/health",
    responses((status = 200, description = "Server is healthy"))
)]
async fn health() -> &'static str {
    "OK"
}

#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
struct ReIsqRequest {
    #[schema(example = "Q4K")]
    ggml_type: String,
}

#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/re_isq",
    request_body = ReIsqRequest,
    responses((status = 200, description = "Reapply ISQ to a non GGUF or GGML model."))
)]
async fn re_isq(
    State(state): State<Arc<MistralRs>>,
    Json(request): Json<ReIsqRequest>,
) -> Result<String, String> {
    let repr = format!("Re ISQ: {:?}", request.ggml_type);
    MistralRs::maybe_log_request(state.clone(), repr.clone());
    let request = Request::ReIsq(parse_isq_value(&request.ggml_type)?);
    state.get_sender().unwrap().send(request).await.unwrap();
    Ok(repr)
}
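
// Illustrative request (host and port are placeholders):
//   curl -X POST http://localhost:1234/re_isq \
//        -H "Content-Type: application/json" -d '{"ggml_type": "Q4K"}'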

fn get_router(state: Arc<MistralRs>) -> Router {
    #[derive(OpenApi)]
    #[openapi(
        paths(models, health, chatcompletions),
        components(schemas(
            ModelObjects,
            ModelObject,
            ChatCompletionRequest,
            CompletionRequest,
            ImageGenerationRequest,
            StopTokens,
            Message
        )),
        tags(
            (name = "Mistral.rs", description = "Mistral.rs API")
        ),
        info(
            title = "Mistral.rs",
            license(
                name = "MIT",
            )
        )
    )]
    struct ApiDoc;

    let doc = ApiDoc::openapi();

    let allow_origin = AllowOrigin::any();
    let cors_layer = CorsLayer::new()
        .allow_methods([Method::GET, Method::POST])
        .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
        .allow_origin(allow_origin);

    Router::new()
        .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc))
        .route("/v1/chat/completions", post(chatcompletions))
        .route("/v1/completions", post(completions))
        .route("/v1/models", get(models))
        .route("/health", get(health))
        .route("/", get(health))
        .route("/re_isq", post(re_isq))
        .route("/v1/images/generations", post(image_generation))
        .layer(cors_layer)
        .layer(DefaultBodyLimit::max(N_INPUT_SIZE * MB_TO_B))
        .with_state(state)
}
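
// Once serving, the OpenAI-compatible endpoints can be exercised directly; host,
// port, and model ID below are placeholders:
//   curl http://localhost:1234/v1/chat/completions \
//        -H "Content-Type: application/json" \
//        -d '{"model":"<model-id>","messages":[{"role":"user","content":"Hello!"}]}'
// Interactive API docs are served by SwaggerUi at /docs.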

#[tokio::main]
async fn main() -> Result<()> {
    let mut args = Args::parse();
    initialize_logging();

    let use_flash_attn = mistralrs_core::using_flash_attn();

    let tgt_non_granular_index = get_tgt_non_granular_index(&args.model);
    let dtype = get_model_dtype(&args.model)?;
    let auto_device_map_params = get_auto_device_map_params(&args.model)?;

    if tgt_non_granular_index.is_some() {
        args.max_seqs = 1;
    }

    let prompt_chunksize = match args.prompt_chunksize {
        Some(0) => {
            anyhow::bail!("`prompt_chunksize` must be a strictly positive integer, got 0.")
        }
        Some(x) => Some(NonZeroUsize::new(x).unwrap()),
        None => None,
    };

    let max_seq_len = auto_device_map_params.max_seq_len();

    let loader: Box<dyn Loader> = LoaderBuilder::new(args.model)
        .with_no_kv_cache(args.no_kv_cache)
        .with_chat_template(args.chat_template)
        .with_use_flash_attn(use_flash_attn)
        .with_prompt_chunksize(prompt_chunksize)
        .with_jinja_explicit(args.jinja_explicit)
        .build()?;

    #[cfg(feature = "metal")]
    let device = Device::new_metal(0)?;
    #[cfg(not(feature = "metal"))]
    let device = if args.cpu {
        args.no_paged_attn = true;
        Device::Cpu
    } else if mistralrs_core::distributed::use_nccl() {
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = args.seed {
        device.set_seed(seed)?;
    }

    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );
    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    if use_flash_attn {
        info!("Using flash attention.");
    }
    if use_flash_attn && loader.get_kind().is_quantized() {
        warn!("Using flash attention with a quantized model has no effect!");
    }
    info!("Model kind is: {}", loader.get_kind());

    let mapper = if let Some(device_layers) = args.num_device_layers {
        // A single bare integer maps that many layers onto ordinal 0.
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            // Otherwise every entry must be of the form ORD:NUM, with distinct ordinals.
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params)
    };
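
    // Example: `-n 24` maps 24 layers onto GPU ordinal 0, while `-n "0:16;1:8"` maps 16
    // layers onto ordinal 0 and 8 onto ordinal 1; placement of any remaining layers is
    // left to `DeviceMapMetadata::from_num_device_layers` (typically the CPU).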

    let no_paged_attn = if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        args.no_paged_attn
    } else if device.is_metal() {
        !args.paged_attn
    } else {
        true
    };

    // The `512` in every arm below reserves 0.5 GB of CPU memory for the KV cache
    // as a placeholder.
    let cache_config = match (
        args.paged_attn_block_size,
        args.paged_attn_gpu_mem,
        args.paged_attn_gpu_mem_usage,
        args.paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(max_seq_len),
        )?),
        (block_size, None, None, Some(ctxt), true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(ctxt),
        )?),
        (block_size, None, Some(f), None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::Utilization(f),
        )?),
        (block_size, Some(m), None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::MbAmount(m),
        )?),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both memory size and usage were specified, defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
            )?)
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both memory size and context length were specified, defaulting to the context length value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::ContextSize(ctxt),
            )?)
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both context length and usage were specified, defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
            )?)
        }
        (_, _, _, _, _, _) => None,
    };
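
    // Net effect of the arms above: usage takes precedence over context length, which
    // takes precedence over the MB amount; with no sizing flag the KV cache is sized to
    // `max_seq_len`. Any unmatched combination (including all three sizing flags at
    // once) falls through to `None`, disabling PagedAttention.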

    let pipeline = loader.load_model_from_hf(
        None,
        args.token_source,
        &dtype,
        &device,
        false,
        mapper,
        args.in_situ_quant,
        cache_config,
    )?;
    info!("Model loaded.");

    let scheduler_config = if cache_config.is_some() {
        // Use the PagedAttention scheduler only if the loaded pipeline actually
        // carries a cache config; otherwise fall back to the default scheduler.
        if let Some(ref cache_config) = pipeline.lock().await.get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: args.max_seqs,
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(args.max_seqs.try_into().unwrap()),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(args.max_seqs.try_into().unwrap()),
        }
    };
    let bert_model = if args.enable_search {
        Some(
            args.search_bert_model
                .map(BertEmbeddingModel::Custom)
                .unwrap_or_default(),
        )
    } else {
        None
    };
    let mistralrs = MistralRsBuilder::new(
        pipeline,
        scheduler_config,
        !args.interactive_mode,
        bert_model,
    )
    .with_opt_log(args.log)
    .with_truncate_sequence(args.truncate_sequence)
    .with_no_kv_cache(args.no_kv_cache)
    .with_prefix_cache_n(args.prefix_cache_n)
    .build();

    if args.interactive_mode {
        interactive_mode(mistralrs, args.interactive_search).await;
        return Ok(());
    }

    let setting_server = if !args.interactive_mode {
        let port = args.port.expect("Interactive mode was not specified, so expected port to be specified. Perhaps you forgot `-i` or `--port`?");
        let ip = args.serve_ip.unwrap_or_else(|| "0.0.0.0".to_string());

        let listener = tokio::net::TcpListener::bind(format!("{ip}:{port}")).await?;
        Some((listener, ip, port))
    } else {
        None
    };

    let app = get_router(mistralrs);
    if let Some((listener, ip, port)) = setting_server {
        info!("Serving on http://{ip}:{port}.");
        axum::serve(listener, app).await?;
    }

    Ok(())
}