mistralrs_server/
main.rs

use anyhow::Result;
use clap::Parser;
use mistralrs_core::{
    initialize_logging, McpClientConfig, ModelSelected, PagedCacheType, TokenSource,
};
use rust_mcp_sdk::schema::LATEST_PROTOCOL_VERSION;
use std::collections::HashMap;
use tokio::join;
use tracing::{error, info};

use mistralrs_server_core::{
    mistralrs_for_server_builder::{
        configure_paged_attn_from_flags, defaults, get_bert_model, MistralRsForServerBuilder,
        ModelConfig,
    },
    mistralrs_server_router_builder::MistralRsServerRouterBuilder,
};

mod interactive_mode;
use interactive_mode::interactive_mode;
mod mcp_server;

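// Example invocations (illustrative; the model subcommand and its flags come from
// `ModelSelected`, so consult `--help` for the exact names and variants):
//
//   mistralrs-server --port 1234 plain -m <model-id>
//   mistralrs-server -i --isq Q4K plain -m <model-id>
//   mistralrs-server --mcp-port 4321 --mcp-config mcp.json plain -m <model-id>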
#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
    /// IP to serve on. Defaults to "0.0.0.0"
    #[arg(long)]
    serve_ip: Option<String>,

    /// Integer seed to ensure reproducible random number generation.
    #[arg(short, long)]
    seed: Option<u64>,

    /// Port to serve on.
    #[arg(short, long)]
    port: Option<u16>,

    /// Log all responses and requests to this file
    #[clap(long, short)]
    log: Option<String>,

    /// If a sequence is larger than the maximum model length, truncate the number
    /// of tokens such that the sequence will fit at most the maximum length.
    /// If `max_tokens` is not specified in the request, space for 10 tokens will be reserved instead.
    #[clap(long, short, action)]
    truncate_sequence: bool,

    /// Model selector
    #[clap(subcommand)]
    model: ModelSelected,

    /// Maximum running sequences at any time. If the `tgt_non_granular_index` flag is set for X-LoRA models, this will be set to 1.
    #[arg(long, default_value_t = defaults::MAX_SEQS)]
    max_seqs: usize,

    /// Use no KV cache.
    #[arg(long, default_value_t = defaults::NO_KV_CACHE)]
    no_kv_cache: bool,

    /// Chat template file: a JINJA template with `messages`, `add_generation_prompt`, `bos_token`, `eos_token`, and `unk_token` as inputs.
    /// Used if the automatic deserialization fails. If this ends with `.json` (i.e., it is a file), then that template is loaded.
    #[arg(short, long)]
    chat_template: Option<String>,

    /// Explicit JINJA chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
    #[arg(short, long)]
    jinja_explicit: Option<String>,

    /// Source of the token for authentication.
    /// Can be in the formats: `literal:<value>`, `env:<value>`, `path:<value>`, `cache` to use a cached token, or `none` to use no token.
    /// Defaults to `cache`.
    #[arg(long, default_value_t = defaults::TOKEN_SOURCE, value_parser = parse_token_source)]
    token_source: TokenSource,

    /// Enter interactive mode instead of serving a chat server.
    #[clap(long, short, action)]
    interactive_mode: bool,

    /// Number of prefix caches to hold on the device. Other caches are evicted to the CPU based on an LRU strategy.
    #[arg(long, default_value_t = defaults::PREFIX_CACHE_N)]
    prefix_cache_n: usize,

    /// NOTE: This can be omitted to use automatic device mapping!
    /// Number of device layers to load and run on GPU(s). All others will be on the CPU.
    /// If one GPU is used, this value should be an integer. Otherwise, it follows the pattern
    /// `ORD:NUM;...` (e.g. `0:16;1:16`), where ORD is a unique device ordinal and NUM is the number of layers for that device.
    #[arg(short, long, value_parser, value_delimiter = ';')]
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply.
    #[arg(long = "isq")]
    in_situ_quant: Option<String>,

    /// GPU memory to allocate for KV cache with PagedAttention in MBs.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem")]
    paged_attn_gpu_mem: Option<usize>,

    /// Percentage of GPU memory to utilize after allocation of KV cache with PagedAttention, from 0 to 1.
    /// If this is not set and the device is CUDA, it will default to `0.9`.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem-usage")]
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length to allocate the KV cache for (total number of tokens which the KV cache can hold).
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    /// This is the default setting, and it defaults to the `max-seq-len` specified after the model type.
    #[arg(long = "pa-ctxt-len")]
    paged_ctxt_len: Option<usize>,

    /// PagedAttention KV cache type (auto or f8e4m3).
    /// Defaults to `auto`.
    #[arg(long = "pa-cache-type", value_parser = parse_cache_type)]
    cache_type: Option<PagedCacheType>,

    /// Block size (number of tokens per block) for PagedAttention. If this is not set and the device is CUDA, it will default to 32.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    #[arg(long = "pa-blk-size")]
    paged_attn_block_size: Option<usize>,

    /// Disable PagedAttention on CUDA. Because PagedAttention is already disabled on Metal, this is only applicable on CUDA.
    #[arg(
        long = "no-paged-attn",
        default_value_t = false,
        conflicts_with = "paged_attn"
    )]
    no_paged_attn: bool,

    /// Enable PagedAttention on Metal. Because PagedAttention is already enabled on CUDA, this is only applicable on Metal.
    #[arg(
        long = "paged-attn",
        default_value_t = false,
        conflicts_with_all = ["no_paged_attn", "cpu"]
    )]
    paged_attn: bool,

    /// Use CPU only
    #[arg(long)]
    cpu: bool,

    /// Enable searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified below or the default.
    #[arg(long = "enable-search")]
    enable_search: bool,

    /// Specify a Hugging Face model ID for a BERT model to assist web searching. Defaults to Snowflake Arctic Embed L.
    #[arg(long = "search-bert-model")]
    search_bert_model: Option<String>,

    /// Enable thinking for interactive mode and models that support it.
    #[arg(long = "enable-thinking")]
    enable_thinking: bool,

    /// Port to serve MCP protocol on
    #[arg(long)]
    mcp_port: Option<u16>,

    /// MCP client configuration file path
    #[arg(long)]
    mcp_config: Option<String>,
}

fn parse_token_source(s: &str) -> Result<TokenSource, String> {
    s.parse()
}

fn parse_cache_type(s: &str) -> Result<PagedCacheType, String> {
    s.parse()
}

/// Load MCP configuration from a file path or, if none is given, from the
/// `MCP_CONFIG_PATH` environment variable.
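///
/// An illustrative configuration file (a sketch: the field names are inferred from
/// `validate_mcp_config` below, and the exact serde layout, including how the
/// server source is tagged, is defined by `McpClientConfig` in `mistralrs_core`):
///
/// ```json
/// {
///   "servers": [
///     { "id": "web-search", "source": { "type": "Http", "url": "https://example.com/mcp" } },
///     { "id": "local-tools", "source": { "type": "Process", "command": "my-mcp-server" } }
///   ],
///   "tool_timeout_secs": 30,
///   "max_concurrent_calls": 4
/// }
/// ```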
fn load_mcp_config(mcp_config_path: Option<&str>) -> Result<Option<McpClientConfig>> {
    let config_path = if let Some(path) = mcp_config_path {
        Some(path.to_string())
    } else {
        // Check environment variable if no CLI arg provided
        std::env::var("MCP_CONFIG_PATH").ok()
    };

    if let Some(path) = config_path {
        match std::fs::read_to_string(&path) {
            Ok(config_content) => {
                match serde_json::from_str::<McpClientConfig>(&config_content) {
                    Ok(config) => {
                        // Validate configuration
                        if let Err(e) = validate_mcp_config(&config) {
                            error!("MCP configuration validation failed: {}", e);
                            anyhow::bail!("Invalid MCP configuration: {}", e);
                        }

                        info!("Loaded and validated MCP configuration from {}", path);
                        info!("Configured {} MCP servers", config.servers.len());
                        Ok(Some(config))
                    }
                    Err(e) => {
                        error!("Failed to parse MCP configuration: {}", e);
                        error!("Please check your JSON syntax and ensure it matches the MCP configuration schema");
                        anyhow::bail!("Invalid MCP configuration format: {}", e);
                    }
                }
            }
            Err(e) => {
                error!("Failed to read MCP configuration file {}: {}", path, e);
                error!("Please ensure the file exists and is readable");
                anyhow::bail!("Cannot read MCP configuration file: {}", e);
            }
        }
    } else {
        Ok(None)
    }
}

/// Validate MCP configuration for common issues
fn validate_mcp_config(config: &McpClientConfig) -> Result<()> {
    use std::collections::HashSet;

    // Check for duplicate server IDs
    let mut seen_ids = HashSet::new();
    for server in &config.servers {
        if !seen_ids.insert(&server.id) {
            anyhow::bail!("Duplicate server ID: {}", server.id);
        }

        // Validate server ID format
        if !server
            .id
            .chars()
            .all(|c| c.is_alphanumeric() || c == '-' || c == '_')
        {
            anyhow::bail!(
                "Invalid server ID '{}': must contain only alphanumeric characters, hyphens, or underscores",
                server.id
            );
        }

        // Validate URLs for HTTP/WebSocket sources
        match &server.source {
            mistralrs_core::McpServerSource::Http { url, .. }
            | mistralrs_core::McpServerSource::WebSocket { url, .. } => {
                // Basic URL validation - check for scheme
                if !url.starts_with("http://")
                    && !url.starts_with("https://")
                    && !url.starts_with("ws://")
                    && !url.starts_with("wss://")
                {
                    anyhow::bail!("Invalid URL for server '{}': must start with http://, https://, ws://, or wss://", server.id);
                }
                if url.len() < 10 {
                    anyhow::bail!("Invalid URL for server '{}': URL too short", server.id);
                }
            }
            mistralrs_core::McpServerSource::Process { command, .. } => {
                if command.is_empty() {
                    anyhow::bail!("Empty command for server '{}'", server.id);
                }
            }
        }
    }

    // Validate global settings
    if let Some(timeout) = config.tool_timeout_secs {
        if timeout == 0 {
            anyhow::bail!("tool_timeout_secs must be greater than 0");
        }
    }

    if let Some(max_calls) = config.max_concurrent_calls {
        if max_calls == 0 {
            anyhow::bail!("max_concurrent_calls must be greater than 0");
        }
    }

    Ok(())
}

/// Configuration for a single model in a multi-model setup (parsing format)
#[derive(Clone, serde::Deserialize)]
struct ModelConfigParsed {
    /// Model selector
    #[serde(flatten)]
    model: ModelSelected,
    /// Model-specific chat template
    chat_template: Option<String>,
    /// Model-specific JINJA template
    jinja_explicit: Option<String>,
    /// Model-specific device layers
    num_device_layers: Option<Vec<String>>,
    /// Model-specific in-situ quantization
    in_situ_quant: Option<String>,
}

/// Load a multi-model configuration: a JSON file mapping model IDs to per-model settings.
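///
/// An illustrative file (a sketch: keys are model IDs; each value flattens the
/// fields of a `ModelSelected` variant, whose exact serde tag and field names
/// are defined in `mistralrs_core`, plus the optional per-model overrides above):
///
/// ```json
/// {
///   "model-a": { "...ModelSelected fields...": "...", "in_situ_quant": "Q4K" },
///   "model-b": { "...ModelSelected fields...": "...", "chat_template": "template.json" }
/// }
/// ```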
fn load_multi_model_config(config_path: &str) -> Result<Vec<ModelConfig>> {
    let config_content = std::fs::read_to_string(config_path).map_err(|e| {
        anyhow::anyhow!(
            "Failed to read multi-model config file {}: {}",
            config_path,
            e
        )
    })?;

    let configs_parsed: HashMap<String, ModelConfigParsed> = serde_json::from_str(&config_content)
        .map_err(|e| anyhow::anyhow!("Failed to parse multi-model config: {}", e))?;

    if configs_parsed.is_empty() {
        anyhow::bail!("Multi-model configuration file is empty");
    }

    let mut configs = Vec::new();
    for (model_id, parsed_config) in configs_parsed {
        let config = ModelConfig {
            model_id,
            model: parsed_config.model,
            chat_template: parsed_config.chat_template,
            jinja_explicit: parsed_config.jinja_explicit,
            num_device_layers: parsed_config.num_device_layers,
            in_situ_quant: parsed_config.in_situ_quant,
        };
        configs.push(config);
    }

    info!(
        "Loaded multi-model configuration with {} models",
        configs.len()
    );
    Ok(configs)
}

#[tokio::main]
async fn main() -> Result<()> {
    let args = Args::parse();

    initialize_logging();

    // Load MCP configuration if provided
    let mcp_config = load_mcp_config(args.mcp_config.as_deref())?;

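    // Resolve the PagedAttention tri-state (force-enable via `--paged-attn`,
    // force-disable via `--no-paged-attn`, or the platform default) from the
    // mutually exclusive flags.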
    let paged_attn = configure_paged_attn_from_flags(args.paged_attn, args.no_paged_attn)?;

    let mistralrs = match args.model {
        ModelSelected::MultiModel {
            config,
            default_model_id,
        } => {
            // Multi-model mode
            let model_configs = load_multi_model_config(&config)?;

            let mut builder = MistralRsForServerBuilder::new()
                .with_truncate_sequence(args.truncate_sequence)
                .with_max_seqs(args.max_seqs)
                .with_no_kv_cache(args.no_kv_cache)
                .with_token_source(args.token_source)
                .with_interactive_mode(args.interactive_mode)
                .with_prefix_cache_n(args.prefix_cache_n)
                .set_paged_attn(paged_attn)
                .with_cpu(args.cpu)
                .with_enable_search(args.enable_search)
                .with_seed_optional(args.seed)
                .with_log_optional(args.log)
                .with_mcp_config_optional(mcp_config)
                .with_paged_attn_cache_type(args.cache_type.unwrap_or_default());

            // Add models to builder
            for config in model_configs {
                builder = builder.add_model_config(config);
            }

            // Set default model if specified
            if let Some(default_id) = default_model_id {
                builder = builder.with_default_model_id(default_id);
            }

            builder.build_multi_model().await?
        }
        model => {
            // Single-model mode
            MistralRsForServerBuilder::new()
                .with_truncate_sequence(args.truncate_sequence)
                .with_model(model)
                .with_max_seqs(args.max_seqs)
                .with_no_kv_cache(args.no_kv_cache)
                .with_token_source(args.token_source)
                .with_interactive_mode(args.interactive_mode)
                .with_prefix_cache_n(args.prefix_cache_n)
                .set_paged_attn(paged_attn)
                .with_cpu(args.cpu)
                .with_enable_search(args.enable_search)
                .with_seed_optional(args.seed)
                .with_log_optional(args.log)
                .with_chat_template_optional(args.chat_template)
                .with_jinja_explicit_optional(args.jinja_explicit)
                .with_num_device_layers_optional(args.num_device_layers)
                .with_in_situ_quant_optional(args.in_situ_quant)
                .with_paged_attn_gpu_mem_optional(args.paged_attn_gpu_mem)
                .with_paged_attn_gpu_mem_usage_optional(args.paged_attn_gpu_mem_usage)
                .with_paged_ctxt_len_optional(args.paged_ctxt_len)
                .with_paged_attn_block_size_optional(args.paged_attn_block_size)
                .with_mcp_config_optional(mcp_config)
                .with_paged_attn_cache_type(args.cache_type.unwrap_or_default())
                .build()
                .await?
        }
    };

    // TODO: refactor this
    let bert_model = get_bert_model(args.enable_search, args.search_bert_model);

    if args.interactive_mode {
        interactive_mode(
            mistralrs,
            bert_model.is_some(),
            args.enable_thinking.then_some(true),
        )
        .await;
        return Ok(());
    }

    if !args.interactive_mode && args.port.is_none() && args.mcp_port.is_none() {
        anyhow::bail!("Interactive mode was not enabled and no port was specified. Perhaps you forgot `-i`, `--port`, or `--mcp-port`?")
    }

    let mcp_handle = if let Some(port) = args.mcp_port {
        let host = args
            .serve_ip
            .clone()
            .unwrap_or_else(|| "0.0.0.0".to_string());
        info!("MCP server listening on http://{host}:{port}/mcp.");
        info!("MCP protocol version is {}.", LATEST_PROTOCOL_VERSION);
        let mcp_server = mcp_server::create_http_mcp_server(mistralrs.clone(), host, port);

        tokio::spawn(async move {
            if let Err(e) = mcp_server.await {
                eprintln!("MCP server error: {e}");
            }
        })
    } else {
        tokio::spawn(async {})
    };

    let oai_handle = if let Some(port) = args.port {
        let ip = args
            .serve_ip
            .clone()
            .unwrap_or_else(|| "0.0.0.0".to_string());

        // Create the listener early so an invalid address or occupied port fails
        // before the router is built
        let listener = tokio::net::TcpListener::bind(format!("{ip}:{port}")).await?;

        let app = MistralRsServerRouterBuilder::new()
            .with_mistralrs(mistralrs)
            .build()
            .await?;

        info!("OpenAI-compatible server listening on http://{ip}:{port}.");

        tokio::spawn(async move {
            if let Err(e) = axum::serve(listener, app).await {
                eprintln!("OpenAI server error: {e}");
            }
        })
    } else {
        tokio::spawn(async {})
    };

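    // Await both server tasks; each branch above produced a JoinHandle (a
    // completed no-op task when the corresponding port was not requested).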
    let (_, _) = join!(oai_handle, mcp_handle);

    Ok(())
}