mistralrs_server/
main.rs

use anyhow::Result;
use axum::{
    extract::{DefaultBodyLimit, Json, State},
    http::{self, Method},
    routing::{get, post},
    Router,
};
use candle_core::Device;
use clap::Parser;
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, initialize_logging,
    paged_attn_supported, parse_isq_value, BertEmbeddingModel, DefaultSchedulerMethod,
    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, IsqType, Loader, LoaderBuilder,
    MemoryGpuConfig, MistralRs, MistralRsBuilder, ModelSelected, PagedAttentionConfig, Request,
    SchedulerConfig, TokenSource,
};
use openai::{
    ChatCompletionRequest, CompletionRequest, ImageGenerationRequest, Message, ModelObjects,
    StopTokens,
};
use serde::{Deserialize, Serialize};
use std::{num::NonZeroUsize, sync::Arc};

mod chat_completion;
mod completions;
mod image_generation;
mod interactive_mode;
mod openai;
mod util;

use crate::openai::ModelObject;
use crate::{
    chat_completion::{__path_chatcompletions, chatcompletions},
    completions::completions,
    image_generation::image_generation,
};

use interactive_mode::interactive_mode;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info, warn};
use utoipa::{OpenApi, ToSchema};
use utoipa_swagger_ui::SwaggerUi;

// NOTE(EricLBuehler): Accept up to 50 MB of input
const N_INPUT_SIZE: usize = 50;
const MB_TO_B: usize = 1024 * 1024; // bytes per MB

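/// Parse a `TokenSource` from its string form; used as the clap `value_parser` for `--token-source`.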
fn parse_token_source(s: &str) -> Result<TokenSource, String> {
    s.parse()
}

#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
    /// IP to serve on. Defaults to "0.0.0.0".
    #[arg(long)]
    serve_ip: Option<String>,

    /// Integer seed to ensure reproducible random number generation.
    #[arg(short, long)]
    seed: Option<u64>,

    /// Port to serve on.
    #[arg(short, long)]
    port: Option<String>,

    /// Log all responses and requests to this file.
    #[clap(long, short)]
    log: Option<String>,

    /// If a sequence is larger than the maximum model length, truncate the number
    /// of tokens such that the sequence will fit at most the maximum length.
    /// If `max_tokens` is not specified in the request, space for 10 tokens will be reserved instead.
    #[clap(long, short, action)]
    truncate_sequence: bool,

    /// Model selector.
    #[clap(subcommand)]
    model: ModelSelected,

    /// Maximum running sequences at any time. If the `tgt_non_granular_index` flag is set for X-LoRA models, this will be set to 1.
    #[arg(long, default_value_t = 16)]
    max_seqs: usize,

    /// Use no KV cache.
    #[arg(long, default_value_t = false)]
    no_kv_cache: bool,

    /// Chat template file: a JINJA template with `messages`, `add_generation_prompt`, `bos_token`, `eos_token`, and `unk_token` as inputs.
    /// Used if the automatic deserialization fails. If this ends with `.json` (i.e., it is a file), then that template is loaded.
    #[arg(short, long)]
    chat_template: Option<String>,

    /// Explicit JINJA chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
    #[arg(short, long)]
    jinja_explicit: Option<String>,

    /// Source of the token for authentication.
    /// Can be in the formats: `literal:<value>`, `env:<value>`, `path:<value>`, `cache` to use a cached token, or `none` to use no token.
    /// Defaults to `cache`.
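    /// For example, `env:HF_TOKEN` reads the token from the `HF_TOKEN` environment variable.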
    #[arg(long, default_value_t = TokenSource::CacheToken, value_parser = parse_token_source)]
    token_source: TokenSource,

    /// Enter interactive mode instead of serving a chat server.
    #[clap(long, short, action)]
    interactive_mode: bool,

    /// Number of prefix caches to hold on the device. Other caches are evicted to the CPU based on an LRU strategy.
    #[arg(long, default_value_t = 16)]
    prefix_cache_n: usize,

    /// NOTE: This can be omitted to use automatic device mapping!
    /// Number of device layers to load and run on GPU(s). All others will be on the CPU.
    /// If one GPU is used, then this value should be an integer. Otherwise, it follows this pattern:
    /// ORD:NUM;... where ORD is a unique device ordinal and NUM is the number of layers for that device.
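    /// For example, `0:16;1:16` places 16 layers on device ordinal 0 and 16 layers on device ordinal 1.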
    #[arg(short, long, value_parser, value_delimiter = ';')]
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply.
    #[arg(long = "isq", value_parser = parse_isq_value)]
    in_situ_quant: Option<IsqType>,

    /// GPU memory to allocate for KV cache with PagedAttention in MBs.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem")]
    paged_attn_gpu_mem: Option<usize>,

    /// Percentage of GPU memory to utilize after allocation of KV cache with PagedAttention, from 0 to 1.
    /// If this is not set and the device is CUDA, it will default to `0.9`.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem-usage")]
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length to allocate the KV cache for (total number of tokens which the KV cache can hold).
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    /// This is the default setting, and it defaults to the `max-seq-len` specified after the model type.
    #[arg(long = "pa-ctxt-len")]
    paged_ctxt_len: Option<usize>,

    /// Block size (number of tokens per block) for PagedAttention. If this is not set and the device is CUDA, it will default to 32.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    #[arg(long = "pa-blk-size")]
    paged_attn_block_size: Option<usize>,

    /// Disable PagedAttention on CUDA. Because PagedAttention is already disabled on Metal, this is only applicable on CUDA.
    #[arg(long = "no-paged-attn", default_value_t = false)]
    no_paged_attn: bool,

    /// Enable PagedAttention on Metal. Because PagedAttention is already enabled on CUDA, this is only applicable on Metal.
    #[arg(long = "paged-attn", default_value_t = false)]
    paged_attn: bool,

    /// Number of tokens to batch the prompt step into. This can help with OOM errors during the prompt step, but reduces performance.
    #[arg(long = "prompt-batchsize")]
    prompt_chunksize: Option<usize>,

    /// Use CPU only.
    #[arg(long)]
    cpu: bool,

    /// Enable web searching for interactive mode.
    #[arg(long = "interactive-search")]
    interactive_search: bool,

    /// Enable searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified below or the default.
    #[arg(long = "enable-search")]
    enable_search: bool,

    /// Specify a Hugging Face model ID for a BERT model to assist web searching. Defaults to Snowflake Arctic Embed L.
    #[arg(long = "search-bert-model")]
    search_bert_model: Option<String>,
}

#[utoipa::path(
    get,
    tag = "Mistral.rs",
    path = "/v1/models",
    responses((status = 200, description = "Served model info", body = ModelObjects))
)]
async fn models(State(state): State<Arc<MistralRs>>) -> Json<ModelObjects> {
    Json(ModelObjects {
        object: "list",
        data: vec![ModelObject {
            id: state.get_id(),
            object: "model",
            created: state.get_creation_time(),
            owned_by: "local",
        }],
    })
}

#[utoipa::path(
    get,
    tag = "Mistral.rs",
    path = "/health",
    responses((status = 200, description = "Server is healthy"))
)]
async fn health() -> &'static str {
    "OK"
}

#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
struct ReIsqRequest {
    #[schema(example = "Q4K")]
    ggml_type: String,
}

#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/re_isq",
    request_body = ReIsqRequest,
    responses((status = 200, description = "Reapply ISQ to a non-GGUF or non-GGML model."))
)]
async fn re_isq(
    State(state): State<Arc<MistralRs>>,
    Json(request): Json<ReIsqRequest>,
) -> Result<String, String> {
    let repr = format!("Re ISQ: {:?}", request.ggml_type);
    MistralRs::maybe_log_request(state.clone(), repr.clone());
    let request = Request::ReIsq(parse_isq_value(&request.ggml_type)?);
    state.get_sender().unwrap().send(request).await.unwrap();
    Ok(repr)
}

fn get_router(state: Arc<MistralRs>) -> Router {
    #[derive(OpenApi)]
    #[openapi(
        paths(models, health, chatcompletions),
        components(
            schemas(ModelObjects, ModelObject, ChatCompletionRequest, CompletionRequest, ImageGenerationRequest, StopTokens, Message)),
        tags(
            (name = "Mistral.rs", description = "Mistral.rs API")
        ),
        info(
            title = "Mistral.rs",
            license(
            name = "MIT",
        )
        )
    )]
    struct ApiDoc;

    let doc = { ApiDoc::openapi() };

    let allow_origin = AllowOrigin::any();
    let cors_layer = CorsLayer::new()
        .allow_methods([Method::GET, Method::POST])
        .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
        .allow_origin(allow_origin);

    Router::new()
        .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc))
        .route("/v1/chat/completions", post(chatcompletions))
        .route("/v1/completions", post(completions))
        .route("/v1/models", get(models))
        .route("/health", get(health))
        .route("/", get(health))
        .route("/re_isq", post(re_isq))
        .route("/v1/images/generations", post(image_generation))
        .layer(cors_layer)
        .layer(DefaultBodyLimit::max(N_INPUT_SIZE * MB_TO_B))
        .with_state(state)
}

#[tokio::main]
async fn main() -> Result<()> {
    let mut args = Args::parse();
    initialize_logging();

    let use_flash_attn = mistralrs_core::using_flash_attn();

    let tgt_non_granular_index = get_tgt_non_granular_index(&args.model);
    let dtype = get_model_dtype(&args.model)?;
    let auto_device_map_params = get_auto_device_map_params(&args.model)?;

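    // X-LoRA models with a target non-granular index only support a single running sequence.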
    if tgt_non_granular_index.is_some() {
        args.max_seqs = 1;
    }

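    // A prompt chunk size of zero is rejected up front since a NonZeroUsize is required downstream.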
    let prompt_chunksize = match args.prompt_chunksize {
        Some(0) => {
            anyhow::bail!("`prompt_chunksize` must be a strictly positive integer, got 0.",)
        }
        Some(x) => Some(NonZeroUsize::new(x).unwrap()),
        None => None,
    };

    let max_seq_len = auto_device_map_params.max_seq_len();

    let loader: Box<dyn Loader> = LoaderBuilder::new(args.model)
        .with_no_kv_cache(args.no_kv_cache)
        .with_chat_template(args.chat_template)
        .with_use_flash_attn(use_flash_attn)
        .with_prompt_chunksize(prompt_chunksize)
        .with_jinja_explicit(args.jinja_explicit)
        .build()?;

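    // Select the main device: Metal when the `metal` feature is enabled; otherwise the CPU when
    // `--cpu` is passed or NCCL is in use, else the first CUDA device if available.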
    #[cfg(feature = "metal")]
    let device = Device::new_metal(0)?;
    #[cfg(not(feature = "metal"))]
    let device = if args.cpu {
        args.no_paged_attn = true;
        Device::Cpu
    } else if mistralrs_core::distributed::use_nccl() {
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = args.seed {
        device.set_seed(seed)?;
    }

    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );
    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    if use_flash_attn {
        info!("Using flash attention.");
    }
    if use_flash_attn && loader.get_kind().is_quantized() {
        warn!("Using flash attention with a quantized model has no effect!")
    }
    info!("Model kind is: {}", loader.get_kind().to_string());

    // Parse device mapper
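    // Two forms are accepted: a single integer (all mapped layers go on device ordinal 0), or a
    // ';'-separated list of ORD:NUM entries, one per device.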
    let mapper = if let Some(device_layers) = args.num_device_layers {
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params)
    };

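    // PagedAttention is on by default for CUDA (and NCCL) unless `--no-paged-attn` is passed,
    // is opt-in on Metal via `--paged-attn`, and is unavailable on other devices.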
    let no_paged_attn = if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        args.no_paged_attn
    } else if device.is_metal() {
        !args.paged_attn
    } else {
        true
    };

    // Allocate 0.5 GB of CPU memory just as a placeholder.
    // Nothing happens here as we have no `swap_out`, see `_preempt_by_swap`.
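    // Precedence for sizing the GPU KV cache: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`,
    // falling back to a context size of `max_seq_len` when nothing is specified.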
    let cache_config = match (
        args.paged_attn_block_size,
        args.paged_attn_gpu_mem,
        args.paged_attn_gpu_mem_usage,
        args.paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(max_seq_len),
        )?),
        (block_size, None, None, Some(ctxt), true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(ctxt),
        )?),
        (block_size, None, Some(f), None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::Utilization(f),
        )?),
        (block_size, Some(m), None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::MbAmount(m),
        )?),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both memory size and usage were specified; defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
            )?)
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both memory size and context length were specified; defaulting to the context length value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::ContextSize(ctxt),
            )?)
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both context length and usage were specified; defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
            )?)
        }
        (_, _, _, _, _, _) => None,
    };

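    // Load the model through the configured loader, applying the device mapper, optional in-situ
    // quantization, and the PagedAttention KV cache configuration.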
    let pipeline = loader.load_model_from_hf(
        None,
        args.token_source,
        &dtype,
        &device,
        false,
        mapper,
        args.in_situ_quant,
        cache_config,
    )?;
    info!("Model loaded.");

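    // Use the PagedAttention scheduler when a KV cache config is active; otherwise fall back to
    // the fixed-size default scheduler.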
    let scheduler_config = if cache_config.is_some() {
        // Handle case where we may have device mapping
        if let Some(ref cache_config) = pipeline.lock().await.get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: args.max_seqs,
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(args.max_seqs.try_into().unwrap()),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(args.max_seqs.try_into().unwrap()),
        }
    };
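    // Optional BERT embedding model used to assist web search; defaults to Snowflake Arctic Embed L
    // unless `--search-bert-model` provides a custom Hugging Face model ID.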
    let bert_model = if args.enable_search {
        Some(
            args.search_bert_model
                .map(BertEmbeddingModel::Custom)
                .unwrap_or_default(),
        )
    } else {
        None
    };
    // Throughput logging in the server (disabled in interactive mode)
    let mistralrs = MistralRsBuilder::new(
        pipeline,
        scheduler_config,
        !args.interactive_mode,
        bert_model,
    )
    .with_opt_log(args.log)
    .with_truncate_sequence(args.truncate_sequence)
    .with_no_kv_cache(args.no_kv_cache)
    .with_prefix_cache_n(args.prefix_cache_n)
    .build();

    if args.interactive_mode {
        interactive_mode(mistralrs, args.interactive_search).await;
        return Ok(());
    }

    // Needs to be after the .build call as that is where the daemon waits.
    let setting_server = if !args.interactive_mode {
        let port = args.port.expect("Interactive mode was not specified, so expected port to be specified. Perhaps you forgot `-i` or `--port`?");
        let ip = args.serve_ip.unwrap_or_else(|| "0.0.0.0".to_string());

        // Create listener early to validate address before model loading
        let listener = tokio::net::TcpListener::bind(format!("{ip}:{port}")).await?;
        Some((listener, ip, port))
    } else {
        None
    };

    let app = get_router(mistralrs);
    if let Some((listener, ip, port)) = setting_server {
        info!("Serving on http://{ip}:{}.", port);
        axum::serve(listener, app).await?;
    };

    Ok(())
}