1use anyhow::Result;
2use clap::Parser;
3use mistralrs_core::{
4 initialize_logging, McpClientConfig, ModelSelected, PagedCacheType, TokenSource,
5};
6use rust_mcp_sdk::schema::LATEST_PROTOCOL_VERSION;
7use std::collections::HashMap;
8use tokio::join;
9use tracing::{error, info};
10
11use mistralrs_server_core::{
12 mistralrs_for_server_builder::{
13 configure_paged_attn_from_flags, defaults, get_bert_model, MistralRsForServerBuilder,
14 ModelConfig,
15 },
16 mistralrs_server_router_builder::MistralRsServerRouterBuilder,
17};
18
19mod interactive_mode;
20use interactive_mode::interactive_mode;
21mod mcp_server;
22
// Command-line interface for the server binary.
// NOTE: comments here are deliberately `//` (not `///`) — clap turns doc
// comments into --help text, and we must not change the CLI's help output.
#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
    // IP address to bind the HTTP servers to; defaults to 0.0.0.0 when unset.
    #[arg(long)]
    serve_ip: Option<String>,

    // RNG seed for reproducible sampling.
    #[arg(short, long)]
    seed: Option<u64>,

    // Port for the OpenAI-compatible HTTP server; required unless -i or --mcp-port is given.
    #[arg(short, long)]
    port: Option<u16>,

    // Log file path, forwarded to the server builder.
    #[clap(long, short)]
    log: Option<String>,

    // Truncate prompts that exceed the model context instead of erroring.
    #[clap(long, short, action)]
    truncate_sequence: bool,

    // Which model to load; `MultiModel` switches to the JSON multi-model config path.
    #[clap(subcommand)]
    model: ModelSelected,

    // Maximum number of concurrent sequences.
    #[arg(long, default_value_t = defaults::MAX_SEQS)]
    max_seqs: usize,

    // Disable the KV cache entirely.
    #[arg(long, default_value_t = defaults::NO_KV_CACHE)]
    no_kv_cache: bool,

    // Chat template override (literal or file), single-model mode only.
    #[arg(short, long)]
    chat_template: Option<String>,

    // Explicit Jinja chat-template file, single-model mode only.
    #[arg(short, long)]
    jinja_explicit: Option<String>,

    // Where to read the HF auth token from (parsed via FromStr).
    #[arg(long, default_value_t = defaults::TOKEN_SOURCE, value_parser = parse_token_source)]
    token_source: TokenSource,

    // Run an interactive REPL instead of serving HTTP.
    #[clap(long, short, action)]
    interactive_mode: bool,

    // Number of sequences kept in the prefix cache.
    #[arg(long, default_value_t = defaults::PREFIX_CACHE_N)]
    prefix_cache_n: usize,

    // Per-device layer counts, semicolon-separated (e.g. "0:16;1:16").
    #[arg(short, long, value_parser, value_delimiter = ';')]
    num_device_layers: Option<Vec<String>>,

    // In-situ quantization method applied at load time.
    #[arg(long = "isq")]
    in_situ_quant: Option<String>,

    // PagedAttention GPU KV-cache size in MB (mutually informative with the two flags below).
    #[arg(long = "pa-gpu-mem")]
    paged_attn_gpu_mem: Option<usize>,

    // PagedAttention GPU memory usage as a fraction (0..1).
    #[arg(long = "pa-gpu-mem-usage")]
    paged_attn_gpu_mem_usage: Option<f32>,

    // PagedAttention context length used to size the KV cache.
    #[arg(long = "pa-ctxt-len")]
    paged_ctxt_len: Option<usize>,

    // KV-cache element type for PagedAttention; `Default` when unset.
    #[arg(long = "pa-cache-type", value_parser = parse_cache_type)]
    cache_type: Option<PagedCacheType>,

    // PagedAttention block size.
    #[arg(long = "pa-blk-size")]
    paged_attn_block_size: Option<usize>,

    // Force-disable PagedAttention; conflicts with --paged-attn.
    #[arg(
        long = "no-paged-attn",
        default_value_t = false,
        conflicts_with = "paged_attn"
    )]
    no_paged_attn: bool,

    // Force-enable PagedAttention; conflicts with --no-paged-attn and --cpu.
    #[arg(
        long = "paged-attn",
        default_value_t = false,
        conflicts_with_all = ["no_paged_attn", "cpu"]
    )]
    paged_attn: bool,

    // Run on CPU only.
    #[arg(long)]
    cpu: bool,

    // Enable web search tooling (pairs with --search-bert-model).
    #[arg(long = "enable-search")]
    enable_search: bool,

    // BERT model used for search result ranking when search is enabled.
    #[arg(long = "search-bert-model")]
    search_bert_model: Option<String>,

    // Enable "thinking" mode for models that support it (interactive mode).
    #[arg(long = "enable-thinking")]
    enable_thinking: bool,

    // Port for the MCP protocol server; optional.
    #[arg(long)]
    mcp_port: Option<u16>,

    // Path to an MCP client JSON config; falls back to $MCP_CONFIG_PATH.
    #[arg(long)]
    mcp_config: Option<String>,
}
164
165fn parse_token_source(s: &str) -> Result<TokenSource, String> {
166 s.parse()
167}
168
169fn parse_cache_type(s: &str) -> Result<PagedCacheType, String> {
170 s.parse()
171}
172
173fn load_mcp_config(mcp_config_path: Option<&str>) -> Result<Option<McpClientConfig>> {
175 let config_path = if let Some(path) = mcp_config_path {
176 Some(path.to_string())
177 } else {
178 std::env::var("MCP_CONFIG_PATH").ok()
180 };
181
182 if let Some(path) = config_path {
183 match std::fs::read_to_string(&path) {
184 Ok(config_content) => {
185 match serde_json::from_str::<McpClientConfig>(&config_content) {
186 Ok(config) => {
187 if let Err(e) = validate_mcp_config(&config) {
189 error!("MCP configuration validation failed: {}", e);
190 anyhow::bail!("Invalid MCP configuration: {}", e);
191 }
192
193 info!("Loaded and validated MCP configuration from {}", path);
194 info!("Configured {} MCP servers", config.servers.len());
195 Ok(Some(config))
196 }
197 Err(e) => {
198 error!("Failed to parse MCP configuration: {}", e);
199 error!("Please check your JSON syntax and ensure it matches the MCP configuration schema");
200 anyhow::bail!("Invalid MCP configuration format: {}", e);
201 }
202 }
203 }
204 Err(e) => {
205 error!("Failed to read MCP configuration file {}: {}", path, e);
206 error!("Please ensure the file exists and is readable");
207 anyhow::bail!("Cannot read MCP configuration file: {}", e);
208 }
209 }
210 } else {
211 Ok(None)
212 }
213}
214
215fn validate_mcp_config(config: &McpClientConfig) -> Result<()> {
217 use std::collections::HashSet;
218
219 let mut seen_ids = HashSet::new();
221 for server in &config.servers {
222 if !seen_ids.insert(&server.id) {
223 anyhow::bail!("Duplicate server ID: {}", server.id);
224 }
225
226 if !server
228 .id
229 .chars()
230 .all(|c| c.is_alphanumeric() || c == '-' || c == '_')
231 {
232 anyhow::bail!(
233 "Invalid server ID '{}': must contain only alphanumeric, hyphen, underscore",
234 server.id
235 );
236 }
237
238 match &server.source {
240 mistralrs_core::McpServerSource::Http { url, .. }
241 | mistralrs_core::McpServerSource::WebSocket { url, .. } => {
242 if !url.starts_with("http://")
244 && !url.starts_with("https://")
245 && !url.starts_with("ws://")
246 && !url.starts_with("wss://")
247 {
248 anyhow::bail!("Invalid URL for server '{}': must start with http://, https://, ws://, or wss://", server.id);
249 }
250 if url.len() < 10 {
251 anyhow::bail!("Invalid URL for server '{}': URL too short", server.id);
252 }
253 }
254 mistralrs_core::McpServerSource::Process { command, .. } => {
255 if command.is_empty() {
256 anyhow::bail!("Empty command for server '{}'", server.id);
257 }
258 }
259 }
260 }
261
262 if let Some(timeout) = config.tool_timeout_secs {
264 if timeout == 0 {
265 anyhow::bail!("tool_timeout_secs must be greater than 0");
266 }
267 }
268
269 if let Some(max_calls) = config.max_concurrent_calls {
270 if max_calls == 0 {
271 anyhow::bail!("max_concurrent_calls must be greater than 0");
272 }
273 }
274
275 Ok(())
276}
277
/// One entry of the multi-model JSON configuration file, keyed by model ID
/// in the enclosing map; converted into a `ModelConfig` after parsing.
#[derive(Clone, serde::Deserialize)]
struct ModelConfigParsed {
    // Model selection fields are flattened into the entry object itself.
    #[serde(flatten)]
    model: ModelSelected,
    // Optional per-model chat template (literal or path).
    chat_template: Option<String>,
    // Optional explicit Jinja template file for this model.
    jinja_explicit: Option<String>,
    // Optional per-device layer distribution for this model.
    num_device_layers: Option<Vec<String>>,
    // Optional in-situ quantization method for this model.
    in_situ_quant: Option<String>,
}
293
294fn load_multi_model_config(config_path: &str) -> Result<Vec<ModelConfig>> {
296 let config_content = std::fs::read_to_string(config_path).map_err(|e| {
297 anyhow::anyhow!(
298 "Failed to read multi-model config file {}: {}",
299 config_path,
300 e
301 )
302 })?;
303
304 let configs_parsed: HashMap<String, ModelConfigParsed> = serde_json::from_str(&config_content)
305 .map_err(|e| anyhow::anyhow!("Failed to parse multi-model config: {}", e))?;
306
307 if configs_parsed.is_empty() {
308 anyhow::bail!("Multi-model configuration file is empty");
309 }
310
311 let mut configs = Vec::new();
312 for (model_id, parsed_config) in configs_parsed {
313 let config = ModelConfig {
314 model_id,
315 model: parsed_config.model,
316 chat_template: parsed_config.chat_template,
317 jinja_explicit: parsed_config.jinja_explicit,
318 num_device_layers: parsed_config.num_device_layers,
319 in_situ_quant: parsed_config.in_situ_quant,
320 };
321 configs.push(config);
322 }
323
324 info!(
325 "Loaded multi-model configuration with {} models",
326 configs.len()
327 );
328 Ok(configs)
329}
330
#[tokio::main]
async fn main() -> Result<()> {
    let args = Args::parse();

    initialize_logging();

    // Load (and validate) the optional MCP client config before model setup,
    // so config errors surface immediately.
    let mcp_config = load_mcp_config(args.mcp_config.as_deref())?;

    // Resolve the mutually exclusive --paged-attn / --no-paged-attn flags.
    let paged_attn = configure_paged_attn_from_flags(args.paged_attn, args.no_paged_attn)?;

    let mistralrs = match args.model {
        ModelSelected::MultiModel {
            config,
            default_model_id,
        } => {
            // Multi-model mode: per-model settings (chat template, ISQ,
            // device layers, ...) come from the JSON file, so the
            // corresponding single-model CLI flags are not applied here.
            let model_configs = load_multi_model_config(&config)?;

            let mut builder = MistralRsForServerBuilder::new()
                .with_truncate_sequence(args.truncate_sequence)
                .with_max_seqs(args.max_seqs)
                .with_no_kv_cache(args.no_kv_cache)
                .with_token_source(args.token_source)
                .with_interactive_mode(args.interactive_mode)
                .with_prefix_cache_n(args.prefix_cache_n)
                .set_paged_attn(paged_attn)
                .with_cpu(args.cpu)
                .with_enable_search(args.enable_search)
                .with_seed_optional(args.seed)
                .with_log_optional(args.log)
                .with_mcp_config_optional(mcp_config)
                .with_paged_attn_cache_type(args.cache_type.unwrap_or_default());

            for config in model_configs {
                builder = builder.add_model_config(config);
            }

            if let Some(default_id) = default_model_id {
                builder = builder.with_default_model_id(default_id);
            }

            builder.build_multi_model().await?
        }
        model => {
            // Single-model mode: all per-model flags come from the CLI.
            MistralRsForServerBuilder::new()
                .with_truncate_sequence(args.truncate_sequence)
                .with_model(model)
                .with_max_seqs(args.max_seqs)
                .with_no_kv_cache(args.no_kv_cache)
                .with_token_source(args.token_source)
                .with_interactive_mode(args.interactive_mode)
                .with_prefix_cache_n(args.prefix_cache_n)
                .set_paged_attn(paged_attn)
                .with_cpu(args.cpu)
                .with_enable_search(args.enable_search)
                .with_seed_optional(args.seed)
                .with_log_optional(args.log)
                .with_chat_template_optional(args.chat_template)
                .with_jinja_explicit_optional(args.jinja_explicit)
                .with_num_device_layers_optional(args.num_device_layers)
                .with_in_situ_quant_optional(args.in_situ_quant)
                .with_paged_attn_gpu_mem_optional(args.paged_attn_gpu_mem)
                .with_paged_attn_gpu_mem_usage_optional(args.paged_attn_gpu_mem_usage)
                .with_paged_ctxt_len_optional(args.paged_ctxt_len)
                .with_paged_attn_block_size_optional(args.paged_attn_block_size)
                .with_mcp_config_optional(mcp_config)
                .with_paged_attn_cache_type(args.cache_type.unwrap_or_default())
                .build()
                .await?
        }
    };

    let bert_model = get_bert_model(args.enable_search, args.search_bert_model);

    // Interactive mode runs the REPL to completion and never starts HTTP servers.
    if args.interactive_mode {
        interactive_mode(
            mistralrs,
            bert_model.is_some(),
            args.enable_thinking.then_some(true),
        )
        .await;
        return Ok(());
    }

    // Past this point we must serve something; require at least one port.
    // (The `!args.interactive_mode` term is always true here — interactive
    // mode returned above — kept for defensive clarity.)
    if !args.interactive_mode && args.port.is_none() && args.mcp_port.is_none() {
        anyhow::bail!("Interactive mode was not specified, so expected port to be specified. Perhaps you forgot `-i` or `--port` or `--mcp-port`?")
    }

    // Optionally start the MCP protocol server on its own task; when not
    // requested, spawn a no-op task so the join! below stays uniform.
    let mcp_port = if let Some(port) = args.mcp_port {
        let host = args
            .serve_ip
            .clone()
            .unwrap_or_else(|| "0.0.0.0".to_string());
        info!("MCP server listening on http://{host}:{port}/mcp.");
        info!("MCP protocol version is {}.", LATEST_PROTOCOL_VERSION);
        let mcp_server = mcp_server::create_http_mcp_server(mistralrs.clone(), host, port);

        tokio::spawn(async move {
            if let Err(e) = mcp_server.await {
                eprintln!("MCP server error: {e}");
            }
        })
    } else {
        tokio::spawn(async {})
    };

    // Optionally start the OpenAI-compatible HTTP server the same way.
    let oai_port = if let Some(port) = args.port {
        let ip = args
            .serve_ip
            .clone()
            .unwrap_or_else(|| "0.0.0.0".to_string());

        // Bind before building the router so bind errors are reported early.
        let listener = tokio::net::TcpListener::bind(format!("{ip}:{port}")).await?;

        let app = MistralRsServerRouterBuilder::new()
            .with_mistralrs(mistralrs)
            .build()
            .await?;

        info!("OpenAI-compatible server listening on http://{ip}:{port}.");

        tokio::spawn(async move {
            if let Err(e) = axum::serve(listener, app).await {
                eprintln!("OpenAI server error: {e}");
            }
        })
    } else {
        tokio::spawn(async {})
    };

    // Run both servers until they exit (normally: never).
    let (_, _) = join!(oai_port, mcp_port);

    Ok(())
}