mistralrs_server/main.rs

use anyhow::Result;
use axum::{
    extract::{DefaultBodyLimit, Json, State},
    http::{self, Method},
    routing::{get, post},
    Router,
};
use candle_core::Device;
use clap::Parser;
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, initialize_logging,
    paged_attn_supported, parse_isq_value, BertEmbeddingModel, DefaultSchedulerMethod,
    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, IsqType, Loader, LoaderBuilder,
    MemoryGpuConfig, MistralRs, MistralRsBuilder, ModelSelected, PagedAttentionConfig, Request,
    SchedulerConfig, TokenSource,
};
use openai::{
    ChatCompletionRequest, CompletionRequest, ImageGenerationRequest, Message, ModelObjects,
    StopTokens,
};
use serde::{Deserialize, Serialize};
use std::{num::NonZeroUsize, sync::Arc};

mod chat_completion;
mod completions;
mod image_generation;
mod interactive_mode;
mod openai;
mod util;

use crate::openai::ModelObject;
use crate::{
    chat_completion::{__path_chatcompletions, chatcompletions},
    completions::completions,
    image_generation::image_generation,
};

use interactive_mode::interactive_mode;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info, warn};
use utoipa::{OpenApi, ToSchema};
use utoipa_swagger_ui::SwaggerUi;

// NOTE(EricLBuehler): Accept up to 50 MB input
const N_INPUT_SIZE: usize = 50;
const MB_TO_B: usize = 1024 * 1024; // Bytes per MB

fn parse_token_source(s: &str) -> Result<TokenSource, String> {
    s.parse()
}

#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
    /// IP to serve on. Defaults to "0.0.0.0"
    #[arg(long)]
    serve_ip: Option<String>,

    /// Integer seed to ensure reproducible random number generation.
    #[arg(short, long)]
    seed: Option<u64>,

    /// Port to serve on.
    #[arg(short, long)]
    port: Option<String>,

    /// Log all responses and requests to this file
    #[clap(long, short)]
    log: Option<String>,

    /// If a sequence is larger than the maximum model length, truncate the number
    /// of tokens such that the sequence will fit at most the maximum length.
    /// If `max_tokens` is not specified in the request, space for 10 tokens will be reserved instead.
    #[clap(long, short, action)]
    truncate_sequence: bool,

    /// Model selector
    #[clap(subcommand)]
    model: ModelSelected,

    /// Maximum running sequences at any time. If the `tgt_non_granular_index` flag is set for X-LoRA models, this will be set to 1.
    #[arg(long, default_value_t = 16)]
    max_seqs: usize,

    /// Use no KV cache.
    #[arg(long, default_value_t = false)]
    no_kv_cache: bool,

    /// Chat template file: a JINJA template with `messages`, `add_generation_prompt`, `bos_token`, `eos_token`, and `unk_token` as inputs.
    /// Used if the automatic deserialization fails. If this ends with `.json` (i.e., it is a file), that template is loaded.
    #[arg(short, long)]
    chat_template: Option<String>,

    /// Explicit JINJA chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
    #[arg(short, long)]
    jinja_explicit: Option<String>,

    /// Source of the token for authentication.
    /// Can be in the formats: `literal:<value>`, `env:<value>`, `path:<value>`, `cache` to use a cached token, or `none` to use no token.
    /// Defaults to `cache`.
    #[arg(long, default_value_t = TokenSource::CacheToken, value_parser = parse_token_source)]
    token_source: TokenSource,

    /// Enter interactive mode instead of serving a chat server.
    #[clap(long, short, action)]
    interactive_mode: bool,

    /// Number of prefix caches to hold on the device. Other caches are evicted to the CPU based on an LRU strategy.
    #[arg(long, default_value_t = 16)]
    prefix_cache_n: usize,

    /// NOTE: This can be omitted to use automatic device mapping!
    /// Number of device layers to load and run on GPU(s). All others will be on the CPU.
    /// If one GPU is used, this value should be an integer. Otherwise, it follows the pattern
    /// ORD:NUM;... where ORD is a unique device ordinal and NUM is the number of layers for that device.
    #[arg(short, long, value_parser, value_delimiter = ';')]
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply.
    #[arg(long = "isq", value_parser = parse_isq_value)]
    in_situ_quant: Option<IsqType>,

    /// GPU memory to allocate for KV cache with PagedAttention in MBs.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem")]
    paged_attn_gpu_mem: Option<usize>,

    /// Percentage of GPU memory to utilize after allocation of KV cache with PagedAttention, from 0 to 1.
    /// If this is not set and the device is CUDA, it will default to `0.9`.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem-usage")]
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length to allocate the KV cache for (total number of tokens which the KV cache can hold).
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    /// This is the default setting, and it defaults to the `max-seq-len` specified after the model type.
    #[arg(long = "pa-ctxt-len")]
    paged_ctxt_len: Option<usize>,

    /// Block size (number of tokens per block) for PagedAttention. If this is not set and the device is CUDA, it will default to 32.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    #[arg(long = "pa-blk-size")]
    paged_attn_block_size: Option<usize>,

    /// Disable PagedAttention on CUDA. Because PagedAttention is already disabled on Metal, this is only applicable on CUDA.
    #[arg(long = "no-paged-attn", default_value_t = false)]
    no_paged_attn: bool,

    /// Enable PagedAttention on Metal. Because PagedAttention is already enabled on CUDA, this is only applicable on Metal.
    #[arg(long = "paged-attn", default_value_t = false)]
    paged_attn: bool,

    /// Enable throughput logging, supported both in the server and in interactive mode.
    #[arg(long = "throughput", default_value_t = false)]
    throughput_log: bool,

    /// Number of tokens to batch the prompt step into. This can help with OOM errors during the prompt step, but reduces performance.
    #[arg(long = "prompt-batchsize")]
    prompt_chunksize: Option<usize>,

    /// Use CPU only
    #[arg(long)]
    cpu: bool,

    /// Enable web searching for interactive mode.
    #[arg(long = "interactive-search")]
    interactive_search: bool,

    /// Enable searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified below or the default.
    #[arg(long = "enable-search")]
    enable_search: bool,

    /// Specify a Hugging Face model ID for a BERT model to assist web searching. Defaults to Snowflake Arctic Embed L.
    #[arg(long = "search-bert-model")]
    search_bert_model: Option<String>,
}

#[utoipa::path(
    get,
    tag = "Mistral.rs",
    path = "/v1/models",
    responses((status = 200, description = "Served model info", body = ModelObjects))
)]
async fn models(State(state): State<Arc<MistralRs>>) -> Json<ModelObjects> {
    Json(ModelObjects {
        object: "list",
        data: vec![ModelObject {
            id: state.get_id(),
            object: "model",
            created: state.get_creation_time(),
            owned_by: "local",
        }],
    })
}

#[utoipa::path(
    get,
    tag = "Mistral.rs",
    path = "/health",
    responses((status = 200, description = "Server is healthy"))
)]
async fn health() -> &'static str {
    "OK"
}

#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
struct AdapterActivationRequest {
    #[schema(example = json!(vec!["adapter_1","adapter_2"]))]
    adapter_names: Vec<String>,
}

#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/activate_adapters",
    request_body = AdapterActivationRequest,
    responses((status = 200, description = "Activate a set of pre-loaded LoRA adapters"))
)]
async fn activate_adapters(
    State(state): State<Arc<MistralRs>>,
    Json(request): Json<AdapterActivationRequest>,
) -> String {
    let repr = format!("Adapter activation: {:?}", request.adapter_names);
    MistralRs::maybe_log_request(state.clone(), repr.clone());
    let request = Request::ActivateAdapters(request.adapter_names);
    state.get_sender().unwrap().send(request).await.unwrap();
    repr
}

#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
struct ReIsqRequest {
    #[schema(example = "Q4K")]
    ggml_type: String,
}

#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/re_isq",
    request_body = ReIsqRequest,
    responses((status = 200, description = "Reapply ISQ to a non GGUF or GGML model."))
)]
async fn re_isq(
    State(state): State<Arc<MistralRs>>,
    Json(request): Json<ReIsqRequest>,
) -> Result<String, String> {
    let repr = format!("Re ISQ: {:?}", request.ggml_type);
    MistralRs::maybe_log_request(state.clone(), repr.clone());
    let request = Request::ReIsq(parse_isq_value(&request.ggml_type)?);
    state.get_sender().unwrap().send(request).await.unwrap();
    Ok(repr)
}

fn get_router(state: Arc<MistralRs>) -> Router {
    #[derive(OpenApi)]
    #[openapi(
        paths(models, health, chatcompletions),
        components(
            schemas(ModelObjects, ModelObject, ChatCompletionRequest, CompletionRequest, ImageGenerationRequest, StopTokens, Message)),
        tags(
            (name = "Mistral.rs", description = "Mistral.rs API")
        ),
        info(
            title = "Mistral.rs",
            license(
                name = "MIT",
            )
        )
    )]
    struct ApiDoc;

    let doc = ApiDoc::openapi();

    let allow_origin = AllowOrigin::any();
    let cors_layer = CorsLayer::new()
        .allow_methods([Method::GET, Method::POST])
        .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
        .allow_origin(allow_origin);

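    // Assemble the router: OpenAI-compatible endpoints plus Swagger UI docs, CORS, and a request body size limit.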
    Router::new()
        .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc))
        .route("/v1/chat/completions", post(chatcompletions))
        .route("/v1/completions", post(completions))
        .route("/v1/models", get(models))
        .route("/health", get(health))
        .route("/", get(health))
        .route("/activate_adapters", post(activate_adapters))
        .route("/re_isq", post(re_isq))
        .route("/v1/images/generations", post(image_generation))
        .layer(cors_layer)
        .layer(DefaultBodyLimit::max(N_INPUT_SIZE * MB_TO_B))
        .with_state(state)
}

#[tokio::main]
async fn main() -> Result<()> {
    let mut args = Args::parse();
    initialize_logging();

    let use_flash_attn = mistralrs_core::using_flash_attn();

    let tgt_non_granular_index = get_tgt_non_granular_index(&args.model);
    let dtype = get_model_dtype(&args.model)?;
    let auto_device_map_params = get_auto_device_map_params(&args.model)?;

    if tgt_non_granular_index.is_some() {
        args.max_seqs = 1;
    }

    let prompt_chunksize = match args.prompt_chunksize {
        Some(0) => {
            anyhow::bail!("`prompt_chunksize` must be a strictly positive integer, got 0.")
        }
        Some(x) => Some(NonZeroUsize::new(x).unwrap()),
        None => None,
    };

    let max_seq_len = auto_device_map_params.max_seq_len();

    let loader: Box<dyn Loader> = LoaderBuilder::new(args.model)
        .with_no_kv_cache(args.no_kv_cache)
        .with_chat_template(args.chat_template)
        .with_use_flash_attn(use_flash_attn)
        .with_prompt_chunksize(prompt_chunksize)
        .with_jinja_explicit(args.jinja_explicit)
        .build()?;

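    // Select the device: Metal when built with the `metal` feature; otherwise CPU (forced by `--cpu`
    // or when NCCL distributed mode is in use) or the first CUDA device if one is available.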
    #[cfg(feature = "metal")]
    let device = Device::new_metal(0)?;
    #[cfg(not(feature = "metal"))]
    let device = if args.cpu {
        args.no_paged_attn = true;
        Device::Cpu
    } else if mistralrs_core::distributed::use_nccl() {
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = args.seed {
        device.set_seed(seed)?;
    }

    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );
    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    if use_flash_attn {
        info!("Using flash attention.");
    }
    if use_flash_attn && loader.get_kind().is_quantized() {
        warn!("Using flash attention with a quantized model has no effect!")
    }
    info!("Model kind is: {}", loader.get_kind().to_string());

    // Parse device mapper
    let mapper = if let Some(device_layers) = args.num_device_layers {
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params)
    };

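    // PagedAttention policy: honor `--no-paged-attn` on CUDA (and with NCCL), require an explicit
    // `--paged-attn` opt-in on Metal, and disable it entirely on other devices (e.g. CPU).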
    let no_paged_attn = if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        args.no_paged_attn
    } else if device.is_metal() {
        !args.paged_attn
    } else {
        true
    };

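    // Match order: (block_size, gpu_mem_mb, gpu_mem_usage, ctxt_len, paged_attn_supported, no_paged_attn).
    // When no sizing flag is given, the KV cache is sized for the model's maximum sequence length.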
    // Allocate 0.5 GB of CPU memory just as a placeholder.
    // Nothing happens here as we have no `swap_out`, see `_preempt_by_swap`.
    let cache_config = match (
        args.paged_attn_block_size,
        args.paged_attn_gpu_mem,
        args.paged_attn_gpu_mem_usage,
        args.paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(max_seq_len),
        )?),
        (block_size, None, None, Some(ctxt), true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(ctxt),
        )?),
        (block_size, None, Some(f), None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::Utilization(f),
        )?),
        (block_size, Some(m), None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::MbAmount(m),
        )?),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both memory size and usage were specified; defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
            )?)
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both memory size and context length were specified; defaulting to the context length value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::ContextSize(ctxt),
            )?)
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both context length and usage were specified; defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
            )?)
        }
        (_, _, _, _, _, _) => None,
    };

    let pipeline = loader.load_model_from_hf(
        None,
        args.token_source,
        &dtype,
        &device,
        false,
        mapper,
        args.in_situ_quant,
        cache_config,
    )?;
    info!("Model loaded.");

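    // Choose the scheduler: use PagedAttention metadata when the pipeline actually built a KV-cache
    // config, otherwise fall back to the fixed-size default scheduler.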
    let scheduler_config = if cache_config.is_some() {
        // Handle case where we may have device mapping
        if let Some(ref cache_config) = pipeline.lock().await.get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: args.max_seqs,
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(args.max_seqs.try_into().unwrap()),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(args.max_seqs.try_into().unwrap()),
        }
    };
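    // Only construct the web-search embedding model when `--enable-search` is set; a custom
    // `--search-bert-model` overrides the default BERT model.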
    let bert_model = if args.enable_search {
        Some(
            args.search_bert_model
                .map(BertEmbeddingModel::Custom)
                .unwrap_or_default(),
        )
    } else {
        None
    };
    // Throughput logging in the server (enabled when not running in interactive mode).
    let mistralrs = MistralRsBuilder::new(
        pipeline,
        scheduler_config,
        !args.interactive_mode,
        bert_model,
    )
    .with_opt_log(args.log)
    .with_truncate_sequence(args.truncate_sequence)
    .with_no_kv_cache(args.no_kv_cache)
    .with_prefix_cache_n(args.prefix_cache_n)
    .build();

    if args.interactive_mode {
        interactive_mode(mistralrs, args.throughput_log, args.interactive_search).await;
        return Ok(());
    }

    // Needs to be after the .build call as that is where the daemon waits.
    let setting_server = if !args.interactive_mode {
        let port = args.port.expect("Interactive mode was not specified, so expected port to be specified. Perhaps you forgot `-i` or `--port`?");
        let ip = args.serve_ip.unwrap_or_else(|| "0.0.0.0".to_string());

        // Bind the listener now so that an invalid address or port fails before serving starts.
        let listener = tokio::net::TcpListener::bind(format!("{ip}:{port}")).await?;
        Some((listener, ip, port))
    } else {
        None
    };

    let app = get_router(mistralrs);
    if let Some((listener, ip, port)) = setting_server {
        info!("Serving on http://{ip}:{}.", port);
        axum::serve(listener, app).await?;
    }

    Ok(())
}