mistralrs_server/
main.rs
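//! OpenAI-compatible HTTP server (and optional MCP server) for mistral.rs.
//!
//! Illustrative invocations (hedged: the binary name and the model
//! subcommands come from `ModelSelected`, so `plain` and its flags are
//! assumptions, not confirmed by this file):
//!
//! ```text
//! # Serve an OpenAI-compatible API on port 1234
//! mistralrs-server --port 1234 plain -m <model-id>
//!
//! # Interactive chat instead of serving
//! mistralrs-server -i plain -m <model-id>
//!
//! # Split layers across two GPUs (ORD:NUM;... syntax from the Args docs below)
//! mistralrs-server --port 1234 -n "0:16;1:16" plain -m <model-id>
//! ```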

use anyhow::Result;
use clap::Parser;
use mistralrs_core::{
    initialize_logging, McpClientConfig, ModelSelected, PagedCacheType, TokenSource,
};
use rust_mcp_sdk::schema::LATEST_PROTOCOL_VERSION;
use std::collections::HashMap;
use tokio::join;
use tracing::{error, info};

use mistralrs_server_core::{
    mistralrs_for_server_builder::{
        configure_paged_attn_from_flags, defaults, get_bert_model, MistralRsForServerBuilder,
        ModelConfig,
    },
    mistralrs_server_router_builder::MistralRsServerRouterBuilder,
};

mod interactive_mode;
use interactive_mode::interactive_mode;
mod mcp_server;

#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
    /// IP to serve on. Defaults to "0.0.0.0"
    #[arg(long)]
    serve_ip: Option<String>,

    /// Integer seed to ensure reproducible random number generation.
    #[arg(short, long)]
    seed: Option<u64>,

    /// Port to serve on.
    #[arg(short, long)]
    port: Option<u16>,

    /// Log all responses and requests to this file
    #[clap(long, short)]
    log: Option<String>,

    /// If a sequence is longer than the maximum model length, truncate the number
    /// of tokens so that the sequence fits within the maximum length.
    /// If `max_tokens` is not specified in the request, space for 10 tokens is reserved instead.
    #[clap(long, short, action)]
    truncate_sequence: bool,

    /// Model selector
    #[clap(subcommand)]
    model: ModelSelected,

    /// Maximum running sequences at any time. If the `tgt_non_granular_index` flag is set for X-LoRA models, this will be set to 1.
    #[arg(long, default_value_t = defaults::MAX_SEQS)]
    max_seqs: usize,

    /// Use no KV cache.
    #[arg(long, default_value_t = defaults::NO_KV_CACHE)]
    no_kv_cache: bool,

    /// Chat template file: a JINJA template with `messages`, `add_generation_prompt`, `bos_token`, `eos_token`, and `unk_token` as inputs.
    /// Used if automatic deserialization fails. If this ends with `.json` (i.e., it is a file), then that template is loaded.
    #[arg(short, long)]
    chat_template: Option<String>,

    /// Explicit JINJA chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
    #[arg(short, long)]
    jinja_explicit: Option<String>,

    /// Source of the token for authentication.
    /// Can be in the formats: `literal:<value>`, `env:<value>`, `path:<value>`, `cache` to use a cached token, or `none` to use no token.
    /// Defaults to `cache`.
    #[arg(long, default_value_t = defaults::TOKEN_SOURCE, value_parser = parse_token_source)]
    token_source: TokenSource,

    /// Enter interactive mode instead of serving a chat server.
    #[clap(long, short, action)]
    interactive_mode: bool,

    /// Number of prefix caches to hold on the device. Other caches are evicted to the CPU based on an LRU strategy.
    #[arg(long, default_value_t = defaults::PREFIX_CACHE_N)]
    prefix_cache_n: usize,

    /// NOTE: This can be omitted to use automatic device mapping!
    /// Number of device layers to load and run on GPU(s). All others will be on the CPU.
    /// If one GPU is used, then this value should be an integer. Otherwise, it follows this pattern:
    /// ORD:NUM;... where ORD is a unique device ordinal and NUM is the number of layers for that device.
    #[arg(short, long, value_parser, value_delimiter = ';')]
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply.
    #[arg(long = "isq")]
    in_situ_quant: Option<String>,

    /// GPU memory to allocate for KV cache with PagedAttention in MBs.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem")]
    paged_attn_gpu_mem: Option<usize>,

    /// Percentage of GPU memory to utilize after allocation of KV cache with PagedAttention, from 0 to 1.
    /// If this is not set and the device is CUDA, it will default to `0.9`.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem-usage")]
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length to allocate the KV cache for (total number of tokens which the KV cache can hold).
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    /// This is the default setting, and it defaults to the `max-seq-len` specified after the model type.
    #[arg(long = "pa-ctxt-len")]
    paged_ctxt_len: Option<usize>,

    /// PagedAttention KV cache type (auto or f8e4m3).
    /// Defaults to `auto`.
    #[arg(long = "pa-cache-type", value_parser = parse_cache_type)]
    cache_type: Option<PagedCacheType>,

    /// Block size (number of tokens per block) for PagedAttention. If this is not set and the device is CUDA, it will default to 32.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    #[arg(long = "pa-blk-size")]
    paged_attn_block_size: Option<usize>,

    /// Disable PagedAttention on CUDA. Because PagedAttention is already disabled on Metal, this is only applicable on CUDA.
    #[arg(
        long = "no-paged-attn",
        default_value_t = false,
        conflicts_with = "paged_attn"
    )]
    no_paged_attn: bool,

    /// Enable PagedAttention on Metal. Because PagedAttention is already enabled on CUDA, this is only applicable on Metal.
    #[arg(
        long = "paged-attn",
        default_value_t = false,
        conflicts_with_all = ["no_paged_attn", "cpu"]
    )]
    paged_attn: bool,

    /// Number of tokens to batch the prompt step into. This can help with OOM errors during the prompt step, but reduces performance.
    #[arg(long = "prompt-batchsize")]
    prompt_chunksize: Option<usize>,

    /// Use CPU only
    #[arg(long)]
    cpu: bool,

    /// Enable searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified below or the default.
    #[arg(long = "enable-search")]
    enable_search: bool,

    /// Specify a Hugging Face model ID for a BERT model to assist web searching. Defaults to Snowflake Arctic Embed L.
    #[arg(long = "search-bert-model")]
    search_bert_model: Option<String>,

    /// Enable thinking for interactive mode and models that support it.
    #[arg(long = "enable-thinking")]
    enable_thinking: bool,

    /// Port to serve MCP protocol on
    #[arg(long)]
    mcp_port: Option<u16>,

    /// MCP client configuration file path
    #[arg(long)]
    mcp_config: Option<String>,
}
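
// Example values accepted by `--token-source` and handled by `parse_token_source`
// below (the formats are documented on the `token_source` flag above; the token
// value shown is a placeholder):
//   literal:hf_xxx   env:HF_TOKEN   path:/path/to/token   cache   none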

fn parse_token_source(s: &str) -> Result<TokenSource, String> {
    s.parse()
}

fn parse_cache_type(s: &str) -> Result<PagedCacheType, String> {
    s.parse()
}

/// Load MCP configuration from file path or environment variable
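///
/// A minimal illustrative config. The `servers`, `tool_timeout_secs`, and
/// `max_concurrent_calls` fields match the validation in this file, but the
/// exact serde layout of `McpServerSource` (the `"type"`/`"url"` tagging shown
/// here) is an assumption:
///
/// ```json
/// {
///   "servers": [
///     { "id": "web-search", "source": { "type": "Http", "url": "https://example.com/mcp" } }
///   ],
///   "tool_timeout_secs": 30,
///   "max_concurrent_calls": 4
/// }
/// ```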
fn load_mcp_config(mcp_config_path: Option<&str>) -> Result<Option<McpClientConfig>> {
    let config_path = if let Some(path) = mcp_config_path {
        Some(path.to_string())
    } else {
        // Check environment variable if no CLI arg provided
        std::env::var("MCP_CONFIG_PATH").ok()
    };

    if let Some(path) = config_path {
        match std::fs::read_to_string(&path) {
            Ok(config_content) => {
                match serde_json::from_str::<McpClientConfig>(&config_content) {
                    Ok(config) => {
                        // Validate configuration
                        if let Err(e) = validate_mcp_config(&config) {
                            error!("MCP configuration validation failed: {}", e);
                            anyhow::bail!("Invalid MCP configuration: {}", e);
                        }

                        info!("Loaded and validated MCP configuration from {}", path);
                        info!("Configured {} MCP servers", config.servers.len());
                        Ok(Some(config))
                    }
                    Err(e) => {
                        error!("Failed to parse MCP configuration: {}", e);
                        error!("Please check your JSON syntax and ensure it matches the MCP configuration schema");
                        anyhow::bail!("Invalid MCP configuration format: {}", e);
                    }
                }
            }
            Err(e) => {
                error!("Failed to read MCP configuration file {}: {}", path, e);
                error!("Please ensure the file exists and is readable");
                anyhow::bail!("Cannot read MCP configuration file: {}", e);
            }
        }
    } else {
        Ok(None)
    }
}

/// Validate MCP configuration for common issues
fn validate_mcp_config(config: &McpClientConfig) -> Result<()> {
    use std::collections::HashSet;

    // Check for duplicate server IDs
    let mut seen_ids = HashSet::new();
    for server in &config.servers {
        if !seen_ids.insert(&server.id) {
            anyhow::bail!("Duplicate server ID: {}", server.id);
        }

        // Validate server ID format
        if !server
            .id
            .chars()
            .all(|c| c.is_alphanumeric() || c == '-' || c == '_')
        {
            anyhow::bail!(
                "Invalid server ID '{}': must contain only alphanumeric characters, hyphens, or underscores",
                server.id
            );
        }

        // Validate URLs for HTTP/WebSocket sources
        match &server.source {
            mistralrs_core::McpServerSource::Http { url, .. }
            | mistralrs_core::McpServerSource::WebSocket { url, .. } => {
                // Basic URL validation - check for scheme
                if !url.starts_with("http://")
                    && !url.starts_with("https://")
                    && !url.starts_with("ws://")
                    && !url.starts_with("wss://")
                {
                    anyhow::bail!("Invalid URL for server '{}': must start with http://, https://, ws://, or wss://", server.id);
                }
                if url.len() < 10 {
                    anyhow::bail!("Invalid URL for server '{}': URL too short", server.id);
                }
            }
            mistralrs_core::McpServerSource::Process { command, .. } => {
                if command.is_empty() {
                    anyhow::bail!("Empty command for server '{}'", server.id);
                }
            }
        }
    }

    // Validate global settings
    if let Some(timeout) = config.tool_timeout_secs {
        if timeout == 0 {
            anyhow::bail!("tool_timeout_secs must be greater than 0");
        }
    }

    if let Some(max_calls) = config.max_concurrent_calls {
        if max_calls == 0 {
            anyhow::bail!("max_concurrent_calls must be greater than 0");
        }
    }

    Ok(())
}

/// Configuration for a single model in a multi-model setup (parsing format)
#[derive(Clone, serde::Deserialize)]
struct ModelConfigParsed {
    /// Model selector
    #[serde(flatten)]
    model: ModelSelected,
    /// Model-specific chat template
    chat_template: Option<String>,
    /// Model-specific JINJA template
    jinja_explicit: Option<String>,
    /// Model-specific device layers
    num_device_layers: Option<Vec<String>>,
    /// Model-specific in-situ quantization
    in_situ_quant: Option<String>,
}

/// Load multi-model configuration from file
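///
/// Schematic example: the top-level keys become model IDs, and each value is
/// a `ModelConfigParsed`. The flattened `ModelSelected` fields depend on its
/// serde layout and are elided here, and `"Q4K"` is shown only as an assumed
/// ISQ value:
///
/// ```text
/// {
///   "model-a": { ...ModelSelected fields..., "in_situ_quant": "Q4K" },
///   "model-b": { ...ModelSelected fields..., "chat_template": "path/to/template.json" }
/// }
/// ```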
fn load_multi_model_config(config_path: &str) -> Result<Vec<ModelConfig>> {
    let config_content = std::fs::read_to_string(config_path).map_err(|e| {
        anyhow::anyhow!(
            "Failed to read multi-model config file {}: {}",
            config_path,
            e
        )
    })?;

    let configs_parsed: HashMap<String, ModelConfigParsed> = serde_json::from_str(&config_content)
        .map_err(|e| anyhow::anyhow!("Failed to parse multi-model config: {}", e))?;

    if configs_parsed.is_empty() {
        anyhow::bail!("Multi-model configuration file is empty");
    }

    let mut configs = Vec::new();
    for (model_id, parsed_config) in configs_parsed {
        let config = ModelConfig {
            model_id,
            model: parsed_config.model,
            chat_template: parsed_config.chat_template,
            jinja_explicit: parsed_config.jinja_explicit,
            num_device_layers: parsed_config.num_device_layers,
            in_situ_quant: parsed_config.in_situ_quant,
        };
        configs.push(config);
    }

    info!(
        "Loaded multi-model configuration with {} models",
        configs.len()
    );
    Ok(configs)
}

#[tokio::main]
async fn main() -> Result<()> {
    let args = Args::parse();

    initialize_logging();

    // Load MCP configuration if provided
    let mcp_config = load_mcp_config(args.mcp_config.as_deref())?;

    let paged_attn = configure_paged_attn_from_flags(args.paged_attn, args.no_paged_attn)?;

    let mistralrs = match args.model {
        ModelSelected::MultiModel {
            config,
            default_model_id,
        } => {
            // Multi-model mode
            let model_configs = load_multi_model_config(&config)?;

            let mut builder = MistralRsForServerBuilder::new()
                .with_truncate_sequence(args.truncate_sequence)
                .with_max_seqs(args.max_seqs)
                .with_no_kv_cache(args.no_kv_cache)
                .with_token_source(args.token_source)
                .with_interactive_mode(args.interactive_mode)
                .with_prefix_cache_n(args.prefix_cache_n)
                .set_paged_attn(paged_attn)
                .with_cpu(args.cpu)
                .with_enable_search(args.enable_search)
                .with_seed_optional(args.seed)
                .with_log_optional(args.log)
                .with_prompt_chunksize_optional(args.prompt_chunksize)
                .with_mcp_config_optional(mcp_config)
                .with_paged_attn_cache_type(args.cache_type.unwrap_or_default());

            // Add models to builder
            for config in model_configs {
                builder = builder.add_model_config(config);
            }

            // Set default model if specified
            if let Some(default_id) = default_model_id {
                builder = builder.with_default_model_id(default_id);
            }

            builder.build_multi_model().await?
        }
        model => {
            // Single-model mode
            MistralRsForServerBuilder::new()
                .with_truncate_sequence(args.truncate_sequence)
                .with_model(model)
                .with_max_seqs(args.max_seqs)
                .with_no_kv_cache(args.no_kv_cache)
                .with_token_source(args.token_source)
                .with_interactive_mode(args.interactive_mode)
                .with_prefix_cache_n(args.prefix_cache_n)
                .set_paged_attn(paged_attn)
                .with_cpu(args.cpu)
                .with_enable_search(args.enable_search)
                .with_seed_optional(args.seed)
                .with_log_optional(args.log)
                .with_chat_template_optional(args.chat_template)
                .with_jinja_explicit_optional(args.jinja_explicit)
                .with_num_device_layers_optional(args.num_device_layers)
                .with_in_situ_quant_optional(args.in_situ_quant)
                .with_paged_attn_gpu_mem_optional(args.paged_attn_gpu_mem)
                .with_paged_attn_gpu_mem_usage_optional(args.paged_attn_gpu_mem_usage)
                .with_paged_ctxt_len_optional(args.paged_ctxt_len)
                .with_paged_attn_block_size_optional(args.paged_attn_block_size)
                .with_prompt_chunksize_optional(args.prompt_chunksize)
                .with_mcp_config_optional(mcp_config)
                .with_paged_attn_cache_type(args.cache_type.unwrap_or_default())
                .build()
                .await?
        }
    };

    // TODO: refactor this
    let bert_model = get_bert_model(args.enable_search, args.search_bert_model);

    if args.interactive_mode {
        interactive_mode(
            mistralrs,
            bert_model.is_some(),
            args.enable_thinking.then_some(true),
        )
        .await;
        return Ok(());
    }

    if !args.interactive_mode && args.port.is_none() && args.mcp_port.is_none() {
        anyhow::bail!("Interactive mode was not specified, so a port must be specified. Perhaps you forgot `-i`, `--port`, or `--mcp-port`?")
    }

    // These variables hold JoinHandles for the spawned servers (renamed from
    // `mcp_port`/`oai_port` for clarity; behavior is unchanged).
    let mcp_handle = if let Some(port) = args.mcp_port {
        let host = args
            .serve_ip
            .clone()
            .unwrap_or_else(|| "0.0.0.0".to_string());
        info!("MCP server listening on http://{host}:{port}/mcp.");
        info!("MCP protocol version is {}.", LATEST_PROTOCOL_VERSION);
        let mcp_server = mcp_server::create_http_mcp_server(mistralrs.clone(), host, port);

        tokio::spawn(async move {
            if let Err(e) = mcp_server.await {
                eprintln!("MCP server error: {e}");
            }
        })
    } else {
        tokio::spawn(async {})
    };

    let oai_handle = if let Some(port) = args.port {
        let ip = args
            .serve_ip
            .clone()
            .unwrap_or_else(|| "0.0.0.0".to_string());

        // Create listener early to validate address before model loading
        let listener = tokio::net::TcpListener::bind(format!("{ip}:{port}")).await?;

        let app = MistralRsServerRouterBuilder::new()
            .with_mistralrs(mistralrs)
            .build()
            .await?;

        info!("OpenAI-compatible server listening on http://{ip}:{port}.");

        tokio::spawn(async move {
            if let Err(e) = axum::serve(listener, app).await {
                eprintln!("OpenAI server error: {e}");
            }
        })
    } else {
        tokio::spawn(async {})
    };

    let (_, _) = join!(oai_handle, mcp_handle);

    Ok(())
}