mistralrs_server/main.rs
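//! OpenAI-compatible HTTP server binary (with optional MCP server) for mistral.rs.
//!
//! Flow: parse CLI arguments, optionally load an MCP client configuration, build the
//! `MistralRs` instance (single-model or multi-model), then either enter interactive
//! mode or serve the OpenAI-compatible router and/or the MCP server.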

use anyhow::Result;
use clap::Parser;
use mistralrs_core::{
    initialize_logging, McpClientConfig, ModelSelected, PagedCacheType, SearchEmbeddingModel,
    TokenSource,
};
use rust_mcp_sdk::schema::LATEST_PROTOCOL_VERSION;
use std::collections::HashMap;
use tokio::join;
use tracing::{error, info};

use mistralrs_server_core::{
    mistralrs_for_server_builder::{
        configure_paged_attn_from_flags, defaults, get_search_embedding_model,
        MistralRsForServerBuilder, ModelConfig,
    },
    mistralrs_server_router_builder::MistralRsServerRouterBuilder,
};

mod interactive_mode;
use interactive_mode::interactive_mode;
mod mcp_server;

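// Illustrative invocations (assuming the binary is built as `mistralrs-server`; the model
// subcommands and their flags come from `ModelSelected` in mistralrs_core, so check
// `--help` for the exact set exposed by this build):
//
//   mistralrs-server --port 1234 plain -m <model-id>
//   mistralrs-server -i plain -m <model-id>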
#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
    /// IP to serve on. Defaults to "0.0.0.0"
    #[arg(long)]
    serve_ip: Option<String>,

    /// Integer seed to ensure reproducible random number generation.
    #[arg(short, long)]
    seed: Option<u64>,

    /// Port to serve on.
    #[arg(short, long)]
    port: Option<u16>,

    /// Log all responses and requests to this file
    #[clap(long, short)]
    log: Option<String>,

    /// Model selector
    #[clap(subcommand)]
    model: ModelSelected,

    /// Maximum running sequences at any time. If the `tgt_non_granular_index` flag is set for X-LoRA models, this will be set to 1.
    #[arg(long, default_value_t = defaults::MAX_SEQS)]
    max_seqs: usize,

    /// Use no KV cache.
    #[arg(long, default_value_t = defaults::NO_KV_CACHE)]
    no_kv_cache: bool,

    /// Chat template: a JINJA template with `messages`, `add_generation_prompt`, `bos_token`, `eos_token`, and `unk_token` as inputs.
    /// Used if the automatic deserialization fails. If this ends with `.json` (i.e., it is a file), that template file is loaded.
    #[arg(short, long)]
    chat_template: Option<String>,

    /// Explicit JINJA chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
    #[arg(short, long)]
    jinja_explicit: Option<String>,

    /// Source of the token for authentication.
    /// Can be in the formats: `literal:<value>`, `env:<value>`, `path:<value>`, `cache` to use a cached token, or `none` to use no token.
    /// Defaults to `cache`.
    #[arg(long, default_value_t = defaults::TOKEN_SOURCE, value_parser = parse_token_source)]
    token_source: TokenSource,

    /// Enter interactive mode instead of serving a chat server.
    #[clap(long, short, action)]
    interactive_mode: bool,

    /// Number of prefix caches to hold on the device. Other caches are evicted to the CPU based on an LRU strategy.
    #[arg(long, default_value_t = defaults::PREFIX_CACHE_N)]
    prefix_cache_n: usize,

    /// NOTE: This can be omitted to use automatic device mapping!
    /// Number of device layers to load and run on GPU(s). All others will be on the CPU.
    /// If one GPU is used, this value should be an integer. Otherwise, it follows the pattern
    /// ORD:NUM;... (e.g. `0:16;1:16`), where ORD is a unique device ordinal and NUM is the number of layers for that device.
    #[arg(short, long, value_parser, value_delimiter = ';')]
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply.
    #[arg(long = "isq")]
    in_situ_quant: Option<String>,

    /// GPU memory to allocate for KV cache with PagedAttention in MBs.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem")]
    paged_attn_gpu_mem: Option<usize>,

    /// Percentage of GPU memory to utilize after allocation of KV cache with PagedAttention, from 0 to 1.
    /// If this is not set and the device is CUDA, it will default to `0.9`.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem-usage")]
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length to allocate the KV cache for (total number of tokens which the KV cache can hold).
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    /// This is the default setting, and it defaults to the `max-seq-len` specified after the model type.
    #[arg(long = "pa-ctxt-len")]
    paged_ctxt_len: Option<usize>,

    /// PagedAttention KV cache type (auto or f8e4m3).
    /// Defaults to `auto`.
    #[arg(long = "pa-cache-type", value_parser = parse_cache_type)]
    cache_type: Option<PagedCacheType>,

    /// Block size (number of tokens per block) for PagedAttention. If this is not set and the device is CUDA, it will default to 32.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    #[arg(long = "pa-blk-size")]
    paged_attn_block_size: Option<usize>,

    /// Disable PagedAttention on CUDA. Because PagedAttention is already disabled on Metal, this is only applicable on CUDA.
    #[arg(
        long = "no-paged-attn",
        default_value_t = false,
        conflicts_with = "paged_attn"
    )]
    no_paged_attn: bool,

    /// Enable PagedAttention on Metal. Because PagedAttention is already enabled on CUDA, this is only applicable on Metal.
    #[arg(
        long = "paged-attn",
        default_value_t = false,
        conflicts_with_all = ["no_paged_attn", "cpu"]
    )]
    paged_attn: bool,

    /// Use CPU only
    #[arg(long)]
    cpu: bool,

    /// Enable searching compatible with the OpenAI `web_search_options` setting. This loads the selected search embedding model for reranking web results.
    #[arg(long = "enable-search")]
    enable_search: bool,

    /// Select which built-in search embedding model to load (e.g., `embedding_gemma`).
    #[arg(long = "search-embedding-model")]
    search_embedding_model: Option<SearchEmbeddingModel>,

    /// Enable thinking for interactive mode and models that support it.
    #[arg(long = "enable-thinking")]
    enable_thinking: bool,

    /// Port to serve MCP protocol on
    #[arg(long)]
    mcp_port: Option<u16>,

    /// MCP client configuration file path
    #[arg(long)]
    mcp_config: Option<String>,
}

fn parse_token_source(s: &str) -> Result<TokenSource, String> {
    s.parse()
}

fn parse_cache_type(s: &str) -> Result<PagedCacheType, String> {
    s.parse()
}

/// Load MCP configuration from a file path or the `MCP_CONFIG_PATH` environment variable.
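///
/// The file is JSON that deserializes into `McpClientConfig`. A minimal sketch, based only
/// on the fields validated below; the exact serde layout of `McpServerSource` and any
/// additional optional fields are defined in mistralrs_core, so treat this shape as
/// illustrative rather than authoritative:
///
/// ```json
/// {
///   "servers": [
///     { "id": "example-http", "source": { "type": "Http", "url": "https://example.com/mcp" } },
///     { "id": "example-proc", "source": { "type": "Process", "command": "my-mcp-server" } }
///   ],
///   "tool_timeout_secs": 30,
///   "max_concurrent_calls": 4
/// }
/// ```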
fn load_mcp_config(mcp_config_path: Option<&str>) -> Result<Option<McpClientConfig>> {
    let config_path = if let Some(path) = mcp_config_path {
        Some(path.to_string())
    } else {
        // Check environment variable if no CLI arg provided
        std::env::var("MCP_CONFIG_PATH").ok()
    };

    if let Some(path) = config_path {
        match std::fs::read_to_string(&path) {
            Ok(config_content) => {
                match serde_json::from_str::<McpClientConfig>(&config_content) {
                    Ok(config) => {
                        // Validate configuration
                        if let Err(e) = validate_mcp_config(&config) {
                            error!("MCP configuration validation failed: {}", e);
                            anyhow::bail!("Invalid MCP configuration: {}", e);
                        }

                        info!("Loaded and validated MCP configuration from {}", path);
                        info!("Configured {} MCP servers", config.servers.len());
                        Ok(Some(config))
                    }
                    Err(e) => {
                        error!("Failed to parse MCP configuration: {}", e);
                        error!("Please check your JSON syntax and ensure it matches the MCP configuration schema");
                        anyhow::bail!("Invalid MCP configuration format: {}", e);
                    }
                }
            }
            Err(e) => {
                error!("Failed to read MCP configuration file {}: {}", path, e);
                error!("Please ensure the file exists and is readable");
                anyhow::bail!("Cannot read MCP configuration file: {}", e);
            }
        }
    } else {
        Ok(None)
    }
}

/// Validate MCP configuration for common issues
fn validate_mcp_config(config: &McpClientConfig) -> Result<()> {
    use std::collections::HashSet;

    // Check for duplicate server IDs
    let mut seen_ids = HashSet::new();
    for server in &config.servers {
        if !seen_ids.insert(&server.id) {
            anyhow::bail!("Duplicate server ID: {}", server.id);
        }

        // Validate server ID format
        if !server
            .id
            .chars()
            .all(|c| c.is_alphanumeric() || c == '-' || c == '_')
        {
            anyhow::bail!(
                "Invalid server ID '{}': must contain only alphanumeric, hyphen, underscore",
                server.id
            );
        }

        // Validate URLs for HTTP/WebSocket sources
        match &server.source {
            mistralrs_core::McpServerSource::Http { url, .. }
            | mistralrs_core::McpServerSource::WebSocket { url, .. } => {
                // Basic URL validation - check for scheme
                if !url.starts_with("http://")
                    && !url.starts_with("https://")
                    && !url.starts_with("ws://")
                    && !url.starts_with("wss://")
                {
                    anyhow::bail!("Invalid URL for server '{}': must start with http://, https://, ws://, or wss://", server.id);
                }
                if url.len() < 10 {
                    anyhow::bail!("Invalid URL for server '{}': URL too short", server.id);
                }
            }
            mistralrs_core::McpServerSource::Process { command, .. } => {
                if command.is_empty() {
                    anyhow::bail!("Empty command for server '{}'", server.id);
                }
            }
        }
    }

    // Validate global settings
    if let Some(timeout) = config.tool_timeout_secs {
        if timeout == 0 {
            anyhow::bail!("tool_timeout_secs must be greater than 0");
        }
    }

    if let Some(max_calls) = config.max_concurrent_calls {
        if max_calls == 0 {
            anyhow::bail!("max_concurrent_calls must be greater than 0");
        }
    }

    Ok(())
}

/// Configuration for a single model in a multi-model setup (parsing format)
#[derive(Clone, serde::Deserialize)]
struct ModelConfigParsed {
    /// Model selector
    #[serde(flatten)]
    model: ModelSelected,
    /// Model-specific chat template
    chat_template: Option<String>,
    /// Model-specific JINJA template
    jinja_explicit: Option<String>,
    /// Model-specific device layers
    num_device_layers: Option<Vec<String>>,
    /// Model-specific in-situ quantization
    in_situ_quant: Option<String>,
}

/// Load multi-model configuration from file
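///
/// The file is a JSON object keyed by model ID, where each entry deserializes into
/// `ModelConfigParsed` above. A minimal sketch; the inner model selector uses the serde
/// representation of `ModelSelected` from mistralrs_core, so the variant spelling shown
/// here (`"Plain"`) is an assumption to illustrate the overall shape:
///
/// ```json
/// {
///   "my-model": {
///     "Plain": { "model_id": "some-org/some-model" },
///     "in_situ_quant": "Q4K"
///   }
/// }
/// ```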
fn load_multi_model_config(config_path: &str) -> Result<Vec<ModelConfig>> {
    let config_content = std::fs::read_to_string(config_path).map_err(|e| {
        anyhow::anyhow!(
            "Failed to read multi-model config file {}: {}",
            config_path,
            e
        )
    })?;

    let configs_parsed: HashMap<String, ModelConfigParsed> = serde_json::from_str(&config_content)
        .map_err(|e| anyhow::anyhow!("Failed to parse multi-model config: {}", e))?;

    if configs_parsed.is_empty() {
        anyhow::bail!("Multi-model configuration file is empty");
    }

    let mut configs = Vec::new();
    for (model_id, parsed_config) in configs_parsed {
        let config = ModelConfig {
            model_id,
            model: parsed_config.model,
            chat_template: parsed_config.chat_template,
            jinja_explicit: parsed_config.jinja_explicit,
            num_device_layers: parsed_config.num_device_layers,
            in_situ_quant: parsed_config.in_situ_quant,
        };
        configs.push(config);
    }

    info!(
        "Loaded multi-model configuration with {} models",
        configs.len()
    );
    Ok(configs)
}

#[tokio::main]
async fn main() -> Result<()> {
    let args = Args::parse();

    initialize_logging();

    // Load MCP configuration if provided
    let mcp_config = load_mcp_config(args.mcp_config.as_deref())?;

    let paged_attn = configure_paged_attn_from_flags(args.paged_attn, args.no_paged_attn)?;

    let mistralrs = match args.model {
        ModelSelected::MultiModel {
            config,
            default_model_id,
        } => {
            // Multi-model mode
            let model_configs = load_multi_model_config(&config)?;

            let mut builder = MistralRsForServerBuilder::new()
                .with_max_seqs(args.max_seqs)
                .with_no_kv_cache(args.no_kv_cache)
                .with_token_source(args.token_source)
                .with_interactive_mode(args.interactive_mode)
                .with_prefix_cache_n(args.prefix_cache_n)
                .set_paged_attn(paged_attn)
                .with_cpu(args.cpu)
                .with_enable_search(args.enable_search)
                .with_seed_optional(args.seed)
                .with_log_optional(args.log)
                .with_mcp_config_optional(mcp_config)
                .with_paged_attn_cache_type(args.cache_type.unwrap_or_default());

            // Add models to builder
            for config in model_configs {
                builder = builder.add_model_config(config);
            }

            // Set default model if specified
            if let Some(default_id) = default_model_id {
                builder = builder.with_default_model_id(default_id);
            }

            if let Some(model) = args.search_embedding_model {
                builder = builder.with_search_embedding_model(model);
            }

            builder.build_multi_model().await?
        }
        model => {
            // Single-model mode
            let mut builder = MistralRsForServerBuilder::new()
                .with_model(model)
                .with_max_seqs(args.max_seqs)
                .with_no_kv_cache(args.no_kv_cache)
                .with_token_source(args.token_source)
                .with_interactive_mode(args.interactive_mode)
                .with_prefix_cache_n(args.prefix_cache_n)
                .set_paged_attn(paged_attn)
                .with_cpu(args.cpu)
                .with_enable_search(args.enable_search)
                .with_seed_optional(args.seed)
                .with_log_optional(args.log)
                .with_chat_template_optional(args.chat_template)
                .with_jinja_explicit_optional(args.jinja_explicit)
                .with_num_device_layers_optional(args.num_device_layers)
                .with_in_situ_quant_optional(args.in_situ_quant)
                .with_paged_attn_gpu_mem_optional(args.paged_attn_gpu_mem)
                .with_paged_attn_gpu_mem_usage_optional(args.paged_attn_gpu_mem_usage)
                .with_paged_ctxt_len_optional(args.paged_ctxt_len)
                .with_paged_attn_block_size_optional(args.paged_attn_block_size)
                .with_mcp_config_optional(mcp_config)
                .with_paged_attn_cache_type(args.cache_type.unwrap_or_default());

            if let Some(model) = args.search_embedding_model {
                builder = builder.with_search_embedding_model(model);
            }

            builder.build().await?
        }
    };

    // TODO: refactor this
    let search_embedding_model =
        get_search_embedding_model(args.enable_search, args.search_embedding_model);

    if args.interactive_mode {
        interactive_mode(
            mistralrs,
            search_embedding_model.is_some(),
            args.enable_thinking.then_some(true),
        )
        .await;
        return Ok(());
    }

    if !args.interactive_mode && args.port.is_none() && args.mcp_port.is_none() {
        anyhow::bail!("Interactive mode was not specified, so expected port to be specified. Perhaps you forgot `-i` or `--port` or `--mcp-port`?")
    }

    let mcp_port = if let Some(port) = args.mcp_port {
        let host = args
            .serve_ip
            .clone()
            .unwrap_or_else(|| "0.0.0.0".to_string());
        info!("MCP server listening on http://{host}:{port}/mcp.");
        info!("MCP protocol version is {}.", LATEST_PROTOCOL_VERSION);
        let mcp_server = mcp_server::create_http_mcp_server(mistralrs.clone(), host, port);

        tokio::spawn(async move {
            if let Err(e) = mcp_server.await {
                eprintln!("MCP server error: {e}");
            }
        })
    } else {
        tokio::spawn(async {})
    };

    let oai_port = if let Some(port) = args.port {
        let ip = args
            .serve_ip
            .clone()
            .unwrap_or_else(|| "0.0.0.0".to_string());

        // Create listener early to validate address before model loading
        let listener = tokio::net::TcpListener::bind(format!("{ip}:{port}")).await?;

        let app = MistralRsServerRouterBuilder::new()
            .with_mistralrs(mistralrs)
            .build()
            .await?;

        info!("OpenAI-compatible server listening on http://{ip}:{port}.");

        tokio::spawn(async move {
            if let Err(e) = axum::serve(listener, app).await {
                eprintln!("OpenAI server error: {e}");
            }
        })
    } else {
        tokio::spawn(async {})
    };

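    // Await both handles; when a server was not started, the corresponding placeholder
    // task (`tokio::spawn(async {})`) completes immediately.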
    let (_, _) = join!(oai_port, mcp_port);

    Ok(())
}