use anyhow::Result;
use axum::{
    extract::{DefaultBodyLimit, Json, State},
    http::{self, Method},
    routing::{get, post},
    Router,
};
use candle_core::Device;
use clap::Parser;
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, initialize_logging,
    paged_attn_supported, parse_isq_value, BertEmbeddingModel, DefaultSchedulerMethod,
    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, IsqType, Loader, LoaderBuilder,
    MemoryGpuConfig, MistralRs, MistralRsBuilder, ModelSelected, PagedAttentionConfig, Request,
    SchedulerConfig, TokenSource,
};
use openai::{
    ChatCompletionRequest, CompletionRequest, ImageGenerationRequest, Message, ModelObjects,
    StopTokens,
};
use serde::{Deserialize, Serialize};
use std::{num::NonZeroUsize, sync::Arc};

mod chat_completion;
mod completions;
mod image_generation;
mod interactive_mode;
mod openai;
mod util;

use crate::openai::ModelObject;
use crate::{
    chat_completion::{__path_chatcompletions, chatcompletions},
    completions::completions,
    image_generation::image_generation,
};

use interactive_mode::interactive_mode;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info, warn};
use utoipa::{OpenApi, ToSchema};
use utoipa_swagger_ui::SwaggerUi;

// Maximum accepted request body size, in MB (applied via `DefaultBodyLimit` below).
const N_INPUT_SIZE: usize = 50;
// Bytes per megabyte, for converting `N_INPUT_SIZE` into a byte limit.
const MB_TO_B: usize = 1024 * 1024;

/// Parse a `TokenSource` from its string representation.
fn parse_token_source(s: &str) -> Result<TokenSource, String> {
    s.parse()
}

#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
    /// IP address to serve on. Defaults to "0.0.0.0".
    #[arg(long)]
    serve_ip: Option<String>,

    /// Integer seed to ensure reproducible random number generation.
    #[arg(short, long)]
    seed: Option<u64>,

    /// Port to serve on; required unless `--interactive-mode` is set.
    #[arg(short, long)]
    port: Option<String>,

    /// Log all requests and responses to this file.
    #[clap(long, short)]
    log: Option<String>,

    /// If a sequence is larger than the maximum model length, truncate it to fit.
    #[clap(long, short, action)]
    truncate_sequence: bool,

    /// Which model to load; see the subcommand help for the supported model kinds.
    #[clap(subcommand)]
    model: ModelSelected,

    /// Maximum number of sequences that may run concurrently.
    #[arg(long, default_value_t = 16)]
    max_seqs: usize,

    /// Disable the KV cache.
    #[arg(long, default_value_t = false)]
    no_kv_cache: bool,

    /// JINJA chat template file, used if automatic chat template detection fails.
    #[arg(short, long)]
    chat_template: Option<String>,

    /// Explicit JINJA chat template file (.jinja); overrides any other chat template.
    #[arg(short, long)]
    jinja_explicit: Option<String>,

    /// Source of the Hugging Face token: `cache`, `none`, `literal:<token>`,
    /// `env:<var>`, or `path:<file>`.
    #[arg(long, default_value_t = TokenSource::CacheToken, value_parser = parse_token_source)]
    token_source: TokenSource,

    /// Enter interactive mode instead of serving over HTTP.
    #[clap(long, short, action)]
    interactive_mode: bool,

    /// Number of prefix caches to hold on the device.
    #[arg(long, default_value_t = 16)]
    prefix_cache_n: usize,

    /// Number of layers to load on each device: a single integer for one GPU, or
    /// `;`-separated `ORD:NUM` pairs. Omit to use automatic device mapping.
    #[arg(short, long, value_parser, value_delimiter = ';')]
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply, e.g. `Q4K`.
    #[arg(long = "isq", value_parser = parse_isq_value)]
    in_situ_quant: Option<IsqType>,

    /// GPU memory to allocate for the PagedAttention KV cache, in MB.
    #[arg(long = "pa-gpu-mem")]
    paged_attn_gpu_mem: Option<usize>,

    /// Fraction of GPU memory (0 to 1) to use for the PagedAttention KV cache.
    #[arg(long = "pa-gpu-mem-usage")]
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length (in tokens) to size the PagedAttention KV cache for.
    #[arg(long = "pa-ctxt-len")]
    paged_ctxt_len: Option<usize>,

    /// PagedAttention block size (number of tokens per block).
    #[arg(long = "pa-blk-size")]
    paged_attn_block_size: Option<usize>,

    /// Disable PagedAttention (it is enabled by default on CUDA).
    #[arg(long = "no-paged-attn", default_value_t = false)]
    no_paged_attn: bool,

    /// Enable PagedAttention (it is disabled by default on Metal).
    #[arg(long = "paged-attn", default_value_t = false)]
    paged_attn: bool,

    /// Log throughput in interactive mode.
    #[arg(long = "throughput", default_value_t = false)]
    throughput_log: bool,

    /// Number of tokens to batch the prompt step into; can help avoid OOM errors
    /// at some cost to performance.
    #[arg(long = "prompt-batchsize")]
    prompt_chunksize: Option<usize>,

    /// Run on the CPU only.
    #[arg(long)]
    cpu: bool,

    /// Enable searching in interactive mode.
    #[arg(long = "interactive-search")]
    interactive_search: bool,

    /// Enable web search support.
    #[arg(long = "enable-search")]
    enable_search: bool,

    /// BERT embedding model to assist searching; a built-in default is used if unset.
    #[arg(long = "search-bert-model")]
    search_bert_model: Option<String>,
}

#[utoipa::path(
    get,
    tag = "Mistral.rs",
    path = "/v1/models",
    responses((status = 200, description = "Served model info", body = ModelObjects))
)]
async fn models(State(state): State<Arc<MistralRs>>) -> Json<ModelObjects> {
    Json(ModelObjects {
        object: "list",
        data: vec![ModelObject {
            id: state.get_id(),
            object: "model",
            created: state.get_creation_time(),
            owned_by: "local",
        }],
    })
}

#[utoipa::path(
    get,
    tag = "Mistral.rs",
    path = "/health",
    responses((status = 200, description = "Server is healthy"))
)]
async fn health() -> &'static str {
    "OK"
}

#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
struct AdapterActivationRequest {
    #[schema(example = json!(vec!["adapter_1", "adapter_2"]))]
    adapter_names: Vec<String>,
}

#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/activate_adapters",
    request_body = AdapterActivationRequest,
    responses((status = 200, description = "Activate a set of pre-loaded LoRA adapters"))
)]
async fn activate_adapters(
    State(state): State<Arc<MistralRs>>,
    Json(request): Json<AdapterActivationRequest>,
) -> String {
    let repr = format!("Adapter activation: {:?}", request.adapter_names);
    MistralRs::maybe_log_request(state.clone(), repr.clone());
    let request = Request::ActivateAdapters(request.adapter_names);
    state.get_sender().unwrap().send(request).await.unwrap();
    repr
}

#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
struct ReIsqRequest {
    #[schema(example = "Q4K")]
    ggml_type: String,
}

#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/re_isq",
    request_body = ReIsqRequest,
    responses((status = 200, description = "Reapply ISQ to a non-GGUF/GGML model."))
)]
async fn re_isq(
    State(state): State<Arc<MistralRs>>,
    Json(request): Json<ReIsqRequest>,
) -> Result<String, String> {
    let repr = format!("Re ISQ: {:?}", request.ggml_type);
    MistralRs::maybe_log_request(state.clone(), repr.clone());
    let request = Request::ReIsq(parse_isq_value(&request.ggml_type)?);
    state.get_sender().unwrap().send(request).await.unwrap();
    Ok(repr)
}

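// Build the axum `Router`: Swagger UI at `/docs`, the OpenAI-compatible
// endpoints under `/v1/`, and the health and admin routes, with permissive
// CORS and a request body limit applied to the whole app.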
fn get_router(state: Arc<MistralRs>) -> Router {
    #[derive(OpenApi)]
    #[openapi(
        paths(models, health, chatcompletions),
        components(schemas(
            ModelObjects,
            ModelObject,
            ChatCompletionRequest,
            CompletionRequest,
            ImageGenerationRequest,
            StopTokens,
            Message
        )),
        tags((name = "Mistral.rs", description = "Mistral.rs API")),
        info(title = "Mistral.rs", license(name = "MIT"))
    )]
    struct ApiDoc;

    let doc = ApiDoc::openapi();

    let allow_origin = AllowOrigin::any();
    let cors_layer = CorsLayer::new()
        .allow_methods([Method::GET, Method::POST])
        .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
        .allow_origin(allow_origin);

    Router::new()
        .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc))
        .route("/v1/chat/completions", post(chatcompletions))
        .route("/v1/completions", post(completions))
        .route("/v1/models", get(models))
        .route("/health", get(health))
        .route("/", get(health))
        .route("/activate_adapters", post(activate_adapters))
        .route("/re_isq", post(re_isq))
        .route("/v1/images/generations", post(image_generation))
        .layer(cors_layer)
        .layer(DefaultBodyLimit::max(N_INPUT_SIZE * MB_TO_B))
        .with_state(state)
}

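// Entrypoint: parse CLI arguments, resolve the device and device mapping,
// load the model, then either drop into interactive mode or serve HTTP.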
#[tokio::main]
async fn main() -> Result<()> {
    let mut args = Args::parse();
    initialize_logging();

    let use_flash_attn = mistralrs_core::using_flash_attn();

    let tgt_non_granular_index = get_tgt_non_granular_index(&args.model);
    let dtype = get_model_dtype(&args.model)?;
    let auto_device_map_params = get_auto_device_map_params(&args.model)?;

    // Models with a target non-granular index only support one running sequence.
    if tgt_non_granular_index.is_some() {
        args.max_seqs = 1;
    }

    let prompt_chunksize = match args.prompt_chunksize {
        Some(0) => {
            anyhow::bail!("`prompt_chunksize` must be a strictly positive integer, got 0.")
        }
        Some(x) => Some(NonZeroUsize::new(x).unwrap()),
        None => None,
    };

    let max_seq_len = auto_device_map_params.max_seq_len();

    let loader: Box<dyn Loader> = LoaderBuilder::new(args.model)
        .with_no_kv_cache(args.no_kv_cache)
        .with_chat_template(args.chat_template)
        .with_use_flash_attn(use_flash_attn)
        .with_prompt_chunksize(prompt_chunksize)
        .with_jinja_explicit(args.jinja_explicit)
        .build()?;

    #[cfg(feature = "metal")]
    let device = Device::new_metal(0)?;
    #[cfg(not(feature = "metal"))]
    let device = if args.cpu {
        args.no_paged_attn = true;
        Device::Cpu
    } else if mistralrs_core::distributed::use_nccl() {
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = args.seed {
        device.set_seed(seed)?;
    }

    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );
    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    if use_flash_attn {
        info!("Using flash attention.");
    }
    if use_flash_attn && loader.get_kind().is_quantized() {
        warn!("Using flash attention with a quantized model has no effect!")
    }
    info!("Model kind is: {}", loader.get_kind());

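    // Device mapping: either an explicit per-device layer count from
    // `--num-device-layers` (a bare integer for a single GPU, or `ORD:NUM`
    // pairs separated by `;`), or automatic mapping derived from the model.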
    let mapper = if let Some(device_layers) = args.num_device_layers {
        // A single bare integer puts all `layers` on device ordinal 0.
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params)
    };

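    // PagedAttention is on by default for CUDA (opt out with `--no-paged-attn`),
    // off by default for Metal (opt in with `--paged-attn`), and unavailable
    // on other devices.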
    let no_paged_attn = if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        args.no_paged_attn
    } else if device.is_metal() {
        !args.paged_attn
    } else {
        true
    };

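    // Configure the PagedAttention KV cache. When several sizing flags are given
    // at once, the precedence implemented by the arms below is:
    // `pa-gpu-mem-usage` > `pa-ctxt-len` > `pa-gpu-mem`.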
    let cache_config = match (
        args.paged_attn_block_size,
        args.paged_attn_gpu_mem,
        args.paged_attn_gpu_mem_usage,
        args.paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(max_seq_len),
        )?),
        (block_size, None, None, Some(ctxt), true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(ctxt),
        )?),
        (block_size, None, Some(f), None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::Utilization(f),
        )?),
        (block_size, Some(m), None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::MbAmount(m),
        )?),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both memory size and usage were specified, defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
            )?)
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both memory size and context length were specified, defaulting to the context length value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::ContextSize(ctxt),
            )?)
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both context length and usage were specified, defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
            )?)
        }
        _ => None,
    };

    let pipeline = loader.load_model_from_hf(
        None,
        args.token_source,
        &dtype,
        &device,
        false,
        mapper,
        args.in_situ_quant,
        cache_config,
    )?;
    info!("Model loaded.");

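    // Prefer the PagedAttention scheduler when a KV-cache config was built;
    // otherwise fall back to the fixed-size default scheduler.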
    let scheduler_config = if cache_config.is_some() {
        // The pipeline may have dropped or overridden the config (e.g. with
        // device mapping), so consult its metadata rather than our own copy.
        if let Some(ref cache_config) = pipeline.lock().await.get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: args.max_seqs,
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(args.max_seqs.try_into().unwrap()),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(args.max_seqs.try_into().unwrap()),
        }
    };
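    // Optional BERT embedding model backing `--enable-search`;
    // `--search-bert-model` overrides the built-in default.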
    let bert_model = if args.enable_search {
        Some(
            args.search_bert_model
                .map(BertEmbeddingModel::Custom)
                .unwrap_or_default(),
        )
    } else {
        None
    };
    let mistralrs = MistralRsBuilder::new(
        pipeline,
        scheduler_config,
        !args.interactive_mode,
        bert_model,
    )
    .with_opt_log(args.log)
    .with_truncate_sequence(args.truncate_sequence)
    .with_no_kv_cache(args.no_kv_cache)
    .with_prefix_cache_n(args.prefix_cache_n)
    .build();

    if args.interactive_mode {
        interactive_mode(mistralrs, args.throughput_log, args.interactive_search).await;
        return Ok(());
    }

    let setting_server = if !args.interactive_mode {
        let port = args.port.expect("Interactive mode was not specified, so expected port to be specified. Perhaps you forgot `-i` or `--port`?");
        let ip = args.serve_ip.unwrap_or_else(|| "0.0.0.0".to_string());

        let listener = tokio::net::TcpListener::bind(format!("{ip}:{port}")).await?;
        Some((listener, ip, port))
    } else {
        None
    };

    let app = get_router(mistralrs);
    if let Some((listener, ip, port)) = setting_server {
        info!("Serving on http://{ip}:{port}.");
        axum::serve(listener, app).await?;
    }

    Ok(())
}