use anyhow::Result;
use clap::Parser;
use mistralrs_core::{
    initialize_logging, McpClientConfig, ModelSelected, PagedCacheType, TokenSource,
};
use rust_mcp_sdk::schema::LATEST_PROTOCOL_VERSION;
use std::collections::HashMap;
use tokio::join;
use tracing::{error, info};

use mistralrs_server_core::{
    mistralrs_for_server_builder::{
        configure_paged_attn_from_flags, defaults, get_bert_model, MistralRsForServerBuilder,
        ModelConfig,
    },
    mistralrs_server_router_builder::MistralRsServerRouterBuilder,
};

mod interactive_mode;
use interactive_mode::interactive_mode;
mod mcp_server;

#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
    /// IP to serve on. Defaults to "0.0.0.0".
    #[arg(long)]
    serve_ip: Option<String>,

    /// Integer seed to ensure reproducible random number generation.
    #[arg(short, long)]
    seed: Option<u64>,

    /// Port to serve the OpenAI-compatible HTTP server on.
    #[arg(short, long)]
    port: Option<u16>,

    /// Log all responses and requests to this file.
    #[clap(long, short)]
    log: Option<String>,

    /// If a sequence exceeds the maximum model length, truncate it to fit
    /// instead of erroring.
    #[clap(long, short, action)]
    truncate_sequence: bool,

    /// Model selector.
    #[clap(subcommand)]
    model: ModelSelected,

    /// Maximum number of running sequences at any time.
    #[arg(long, default_value_t = defaults::MAX_SEQS)]
    max_seqs: usize,

    /// Disable the KV cache.
    #[arg(long, default_value_t = defaults::NO_KV_CACHE)]
    no_kv_cache: bool,

    /// JINJA chat template file, overriding the model default.
    #[arg(short, long)]
    chat_template: Option<String>,

    /// Explicit JINJA chat template file.
    #[arg(short, long)]
    jinja_explicit: Option<String>,

    /// Source of the Hugging Face token.
    #[arg(long, default_value_t = defaults::TOKEN_SOURCE, value_parser = parse_token_source)]
    token_source: TokenSource,

    /// Enter interactive mode instead of serving over HTTP.
    #[clap(long, short, action)]
    interactive_mode: bool,

    /// Number of prefix caches to hold on the device.
    #[arg(long, default_value_t = defaults::PREFIX_CACHE_N)]
    prefix_cache_n: usize,

    /// Number of layers to place on each device, semicolon-separated.
    #[arg(short, long, value_parser, value_delimiter = ';')]
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization (ISQ) to apply.
    #[arg(long = "isq")]
    in_situ_quant: Option<String>,

    /// GPU memory to allocate for the PagedAttention KV cache, in MBs.
    #[arg(long = "pa-gpu-mem")]
    paged_attn_gpu_mem: Option<usize>,

    /// Fraction of GPU memory (0.0 to 1.0) to utilize for the PagedAttention
    /// KV cache.
    #[arg(long = "pa-gpu-mem-usage")]
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length for which to allocate the PagedAttention KV cache.
    #[arg(long = "pa-ctxt-len")]
    paged_ctxt_len: Option<usize>,

    /// PagedAttention KV cache type.
    #[arg(long = "pa-cache-type", value_parser = parse_cache_type)]
    cache_type: Option<PagedCacheType>,

    /// PagedAttention block size (tokens per block).
    #[arg(long = "pa-blk-size")]
    paged_attn_block_size: Option<usize>,

    /// Disable PagedAttention.
    #[arg(
        long = "no-paged-attn",
        default_value_t = false,
        conflicts_with = "paged_attn"
    )]
    no_paged_attn: bool,

    /// Enable PagedAttention.
    #[arg(
        long = "paged-attn",
        default_value_t = false,
        conflicts_with_all = ["no_paged_attn", "cpu"]
    )]
    paged_attn: bool,

    /// Number of tokens to batch the prompt step into.
    #[arg(long = "prompt-batchsize")]
    prompt_chunksize: Option<usize>,

    /// Force the model to run on the CPU.
    #[arg(long)]
    cpu: bool,

    /// Enable web search tool calling.
    #[arg(long = "enable-search")]
    enable_search: bool,

    /// BERT model to aid web searching.
    #[arg(long = "search-bert-model")]
    search_bert_model: Option<String>,

    /// Enable thinking for models that support it.
    #[arg(long = "enable-thinking")]
    enable_thinking: bool,

    /// Port to serve the MCP protocol server on.
    #[arg(long)]
    mcp_port: Option<u16>,

    /// Path to an MCP client configuration JSON file. Falls back to the
    /// MCP_CONFIG_PATH environment variable when unset.
    #[arg(long)]
    mcp_config: Option<String>,
}

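// clap value parsers: `TokenSource` and `PagedCacheType` both implement
// `FromStr` with `String` errors, so parsing simply delegates to `str::parse`.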
fn parse_token_source(s: &str) -> Result<TokenSource, String> {
    s.parse()
}

fn parse_cache_type(s: &str) -> Result<PagedCacheType, String> {
    s.parse()
}

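/// Load the MCP client configuration from `mcp_config_path`, falling back to
/// the `MCP_CONFIG_PATH` environment variable when no path is given.
///
/// A minimal sketch of the expected JSON, based on the fields this file uses
/// (`servers`, `id`, `source`, `tool_timeout_secs`, `max_concurrent_calls`);
/// the exact tagging of `source` follows the serde layout of
/// `McpClientConfig` in mistralrs-core, so treat it as illustrative:
///
/// ```json
/// {
///   "servers": [
///     { "id": "web_tools", "source": { "type": "Http", "url": "https://example.com/mcp" } },
///     { "id": "local_tools", "source": { "type": "Process", "command": "my-mcp-server" } }
///   ],
///   "tool_timeout_secs": 30,
///   "max_concurrent_calls": 4
/// }
/// ```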
fn load_mcp_config(mcp_config_path: Option<&str>) -> Result<Option<McpClientConfig>> {
    let config_path = mcp_config_path
        .map(str::to_string)
        .or_else(|| std::env::var("MCP_CONFIG_PATH").ok());

    if let Some(path) = config_path {
        match std::fs::read_to_string(&path) {
            Ok(config_content) => {
                match serde_json::from_str::<McpClientConfig>(&config_content) {
                    Ok(config) => {
                        if let Err(e) = validate_mcp_config(&config) {
                            error!("MCP configuration validation failed: {}", e);
                            anyhow::bail!("Invalid MCP configuration: {}", e);
                        }

                        info!("Loaded and validated MCP configuration from {}", path);
                        info!("Configured {} MCP servers", config.servers.len());
                        Ok(Some(config))
                    }
                    Err(e) => {
                        error!("Failed to parse MCP configuration: {}", e);
                        error!("Please check your JSON syntax and ensure it matches the MCP configuration schema");
                        anyhow::bail!("Invalid MCP configuration format: {}", e);
                    }
                }
            }
            Err(e) => {
                error!("Failed to read MCP configuration file {}: {}", path, e);
                error!("Please ensure the file exists and is readable");
                anyhow::bail!("Cannot read MCP configuration file: {}", e);
            }
        }
    } else {
        Ok(None)
    }
}

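/// Validate an MCP configuration: server IDs must be unique and restricted to
/// alphanumerics, hyphens, and underscores; HTTP/WebSocket URLs must use a
/// supported scheme; process commands must be non-empty; and the timeout and
/// concurrency limits, when present, must be positive.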
fn validate_mcp_config(config: &McpClientConfig) -> Result<()> {
    use std::collections::HashSet;

    let mut seen_ids = HashSet::new();
    for server in &config.servers {
        if !seen_ids.insert(&server.id) {
            anyhow::bail!("Duplicate server ID: {}", server.id);
        }

        if !server
            .id
            .chars()
            .all(|c| c.is_alphanumeric() || c == '-' || c == '_')
        {
            anyhow::bail!(
                "Invalid server ID '{}': must contain only alphanumeric characters, hyphens, and underscores",
                server.id
            );
        }

        match &server.source {
            mistralrs_core::McpServerSource::Http { url, .. }
            | mistralrs_core::McpServerSource::WebSocket { url, .. } => {
                if !url.starts_with("http://")
                    && !url.starts_with("https://")
                    && !url.starts_with("ws://")
                    && !url.starts_with("wss://")
                {
                    anyhow::bail!("Invalid URL for server '{}': must start with http://, https://, ws://, or wss://", server.id);
                }
                if url.len() < 10 {
                    anyhow::bail!("Invalid URL for server '{}': URL too short", server.id);
                }
            }
            mistralrs_core::McpServerSource::Process { command, .. } => {
                if command.is_empty() {
                    anyhow::bail!("Empty command for server '{}'", server.id);
                }
            }
        }
    }

    if let Some(timeout) = config.tool_timeout_secs {
        if timeout == 0 {
            anyhow::bail!("tool_timeout_secs must be greater than 0");
        }
    }

    if let Some(max_calls) = config.max_concurrent_calls {
        if max_calls == 0 {
            anyhow::bail!("max_concurrent_calls must be greater than 0");
        }
    }

    Ok(())
}

/// Per-model options parsed from the multi-model JSON configuration file.
#[derive(Clone, serde::Deserialize)]
struct ModelConfigParsed {
    #[serde(flatten)]
    model: ModelSelected,
    chat_template: Option<String>,
    jinja_explicit: Option<String>,
    num_device_layers: Option<Vec<String>>,
    in_situ_quant: Option<String>,
}

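/// Load a multi-model configuration: a JSON object mapping each model ID to
/// its per-model options, with the `ModelSelected` fields flattened into the
/// entry. A sketch of the expected shape (the keys inside each entry follow
/// the serde representation of `ModelSelected`, so treat them as
/// illustrative):
///
/// ```json
/// {
///   "llama3": {
///     "model_id": "meta-llama/Llama-3.2-3B-Instruct",
///     "in_situ_quant": "Q4K"
///   }
/// }
/// ```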
fn load_multi_model_config(config_path: &str) -> Result<Vec<ModelConfig>> {
    let config_content = std::fs::read_to_string(config_path).map_err(|e| {
        anyhow::anyhow!(
            "Failed to read multi-model config file {}: {}",
            config_path,
            e
        )
    })?;

    let configs_parsed: HashMap<String, ModelConfigParsed> = serde_json::from_str(&config_content)
        .map_err(|e| anyhow::anyhow!("Failed to parse multi-model config: {}", e))?;

    if configs_parsed.is_empty() {
        anyhow::bail!("Multi-model configuration file is empty");
    }

    let mut configs = Vec::new();
    for (model_id, parsed_config) in configs_parsed {
        let config = ModelConfig {
            model_id,
            model: parsed_config.model,
            chat_template: parsed_config.chat_template,
            jinja_explicit: parsed_config.jinja_explicit,
            num_device_layers: parsed_config.num_device_layers,
            in_situ_quant: parsed_config.in_situ_quant,
        };
        configs.push(config);
    }

    info!(
        "Loaded multi-model configuration with {} models",
        configs.len()
    );
    Ok(configs)
}

#[tokio::main]
async fn main() -> Result<()> {
    let args = Args::parse();

    initialize_logging();

    let mcp_config = load_mcp_config(args.mcp_config.as_deref())?;

    let paged_attn = configure_paged_attn_from_flags(args.paged_attn, args.no_paged_attn)?;

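    // Build the engine. The `MultiModel` subcommand loads several models from
    // a JSON config file; any other `ModelSelected` variant builds a
    // single-model server.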
    let mistralrs = match args.model {
        ModelSelected::MultiModel {
            config,
            default_model_id,
        } => {
            let model_configs = load_multi_model_config(&config)?;

            let mut builder = MistralRsForServerBuilder::new()
                .with_truncate_sequence(args.truncate_sequence)
                .with_max_seqs(args.max_seqs)
                .with_no_kv_cache(args.no_kv_cache)
                .with_token_source(args.token_source)
                .with_interactive_mode(args.interactive_mode)
                .with_prefix_cache_n(args.prefix_cache_n)
                .set_paged_attn(paged_attn)
                .with_cpu(args.cpu)
                .with_enable_search(args.enable_search)
                .with_seed_optional(args.seed)
                .with_log_optional(args.log)
                .with_prompt_chunksize_optional(args.prompt_chunksize)
                .with_mcp_config_optional(mcp_config)
                .with_paged_attn_cache_type(args.cache_type.unwrap_or_default());

            for config in model_configs {
                builder = builder.add_model_config(config);
            }

            if let Some(default_id) = default_model_id {
                builder = builder.with_default_model_id(default_id);
            }

            builder.build_multi_model().await?
        }
        model => {
            MistralRsForServerBuilder::new()
                .with_truncate_sequence(args.truncate_sequence)
                .with_model(model)
                .with_max_seqs(args.max_seqs)
                .with_no_kv_cache(args.no_kv_cache)
                .with_token_source(args.token_source)
                .with_interactive_mode(args.interactive_mode)
                .with_prefix_cache_n(args.prefix_cache_n)
                .set_paged_attn(paged_attn)
                .with_cpu(args.cpu)
                .with_enable_search(args.enable_search)
                .with_seed_optional(args.seed)
                .with_log_optional(args.log)
                .with_chat_template_optional(args.chat_template)
                .with_jinja_explicit_optional(args.jinja_explicit)
                .with_num_device_layers_optional(args.num_device_layers)
                .with_in_situ_quant_optional(args.in_situ_quant)
                .with_paged_attn_gpu_mem_optional(args.paged_attn_gpu_mem)
                .with_paged_attn_gpu_mem_usage_optional(args.paged_attn_gpu_mem_usage)
                .with_paged_ctxt_len_optional(args.paged_ctxt_len)
                .with_paged_attn_block_size_optional(args.paged_attn_block_size)
                .with_prompt_chunksize_optional(args.prompt_chunksize)
                .with_mcp_config_optional(mcp_config)
                .with_paged_attn_cache_type(args.cache_type.unwrap_or_default())
                .build()
                .await?
        }
    };

    let bert_model = get_bert_model(args.enable_search, args.search_bert_model);

    if args.interactive_mode {
        interactive_mode(
            mistralrs,
            bert_model.is_some(),
            args.enable_thinking.then_some(true),
        )
        .await;
        return Ok(());
    }

    // Interactive mode returned above, so a server port is required from here on.
    if args.port.is_none() && args.mcp_port.is_none() {
        anyhow::bail!("Interactive mode was not specified, so a port is required. Perhaps you forgot `-i`, `--port`, or `--mcp-port`?")
    }

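    // The MCP and OpenAI-compatible servers run as independent tasks; when a
    // port is not requested, a no-op task is spawned instead so the final
    // `join!` stays unconditional.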
    let mcp_handle = if let Some(port) = args.mcp_port {
        let host = args
            .serve_ip
            .clone()
            .unwrap_or_else(|| "0.0.0.0".to_string());
        info!("MCP server listening on http://{host}:{port}/mcp.");
        info!("MCP protocol version is {}.", LATEST_PROTOCOL_VERSION);
        let mcp_server = mcp_server::create_http_mcp_server(mistralrs.clone(), host, port);

        tokio::spawn(async move {
            if let Err(e) = mcp_server.await {
                eprintln!("MCP server error: {e}");
            }
        })
    } else {
        tokio::spawn(async {})
    };

    let oai_handle = if let Some(port) = args.port {
        let ip = args
            .serve_ip
            .clone()
            .unwrap_or_else(|| "0.0.0.0".to_string());

        let listener = tokio::net::TcpListener::bind(format!("{ip}:{port}")).await?;

        let app = MistralRsServerRouterBuilder::new()
            .with_mistralrs(mistralrs)
            .build()
            .await?;

        info!("OpenAI-compatible server listening on http://{ip}:{port}.");

        tokio::spawn(async move {
            if let Err(e) = axum::serve(listener, app).await {
                eprintln!("OpenAI server error: {e}");
            }
        })
    } else {
        tokio::spawn(async {})
    };

    let (_, _) = join!(oai_handle, mcp_handle);

    Ok(())
}