use anyhow::Result;
use clap::Parser;
use mistralrs_core::{
    initialize_logging, McpClientConfig, ModelSelected, PagedCacheType, SearchEmbeddingModel,
    TokenSource,
};
use rust_mcp_sdk::schema::LATEST_PROTOCOL_VERSION;
use std::collections::HashMap;
use tokio::join;
use tracing::{error, info};

use mistralrs_server_core::{
    mistralrs_for_server_builder::{
        configure_paged_attn_from_flags, defaults, get_search_embedding_model,
        MistralRsForServerBuilder, ModelConfig,
    },
    mistralrs_server_router_builder::MistralRsServerRouterBuilder,
};

mod interactive_mode;
use interactive_mode::interactive_mode;
mod mcp_server;

#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
    /// IP to serve on. Defaults to "0.0.0.0".
    #[arg(long)]
    serve_ip: Option<String>,

    /// Integer seed to ensure reproducible random number generation.
    #[arg(short, long)]
    seed: Option<u64>,

    /// Port to serve the OpenAI-compatible HTTP server on.
    #[arg(short, long)]
    port: Option<u16>,

    /// Log all responses and requests to this file.
    #[clap(long, short)]
    log: Option<String>,

    /// Model selector subcommand.
    #[clap(subcommand)]
    model: ModelSelected,

    /// Maximum number of sequences that may run concurrently.
    #[arg(long, default_value_t = defaults::MAX_SEQS)]
    max_seqs: usize,

    /// Disable the KV cache.
    #[arg(long, default_value_t = defaults::NO_KV_CACHE)]
    no_kv_cache: bool,

    /// Chat template file (JINJA) to override the model's default.
    #[arg(short, long)]
    chat_template: Option<String>,

    /// Explicit JINJA chat template file, taking precedence over `chat_template`.
    #[arg(short, long)]
    jinja_explicit: Option<String>,

    /// Where to load the Hugging Face token from; see `parse_token_source`.
    #[arg(long, default_value_t = defaults::TOKEN_SOURCE, value_parser = parse_token_source)]
    token_source: TokenSource,

    /// Run in interactive mode instead of serving HTTP.
    #[clap(long, short, action)]
    interactive_mode: bool,

    /// Number of prefix caches to hold on the device.
    #[arg(long, default_value_t = defaults::PREFIX_CACHE_N)]
    prefix_cache_n: usize,

    /// Number of model layers to place on each device, semicolon-separated.
    #[arg(short, long, value_parser, value_delimiter = ';')]
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply to the model.
    #[arg(long = "isq")]
    in_situ_quant: Option<String>,

    /// GPU memory to allocate for the PagedAttention KV cache, in MB.
    #[arg(long = "pa-gpu-mem")]
    paged_attn_gpu_mem: Option<usize>,

    /// Fraction of GPU memory (0.0 to 1.0) to use for the PagedAttention KV cache.
    #[arg(long = "pa-gpu-mem-usage")]
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length to size the PagedAttention KV cache for.
    #[arg(long = "pa-ctxt-len")]
    paged_ctxt_len: Option<usize>,

    /// Cache type for the PagedAttention KV cache.
    #[arg(long = "pa-cache-type", value_parser = parse_cache_type)]
    cache_type: Option<PagedCacheType>,

    /// Block size (tokens per block) for PagedAttention.
    #[arg(long = "pa-blk-size")]
    paged_attn_block_size: Option<usize>,

    /// Disable PagedAttention.
    #[arg(
        long = "no-paged-attn",
        default_value_t = false,
        conflicts_with = "paged_attn"
    )]
    no_paged_attn: bool,

    /// Force-enable PagedAttention.
    #[arg(
        long = "paged-attn",
        default_value_t = false,
        conflicts_with_all = ["no_paged_attn", "cpu"]
    )]
    paged_attn: bool,

    /// Run on the CPU.
    #[arg(long)]
    cpu: bool,

    /// Enable web search tooling.
    #[arg(long = "enable-search")]
    enable_search: bool,

    /// Embedding model to use for web search.
    #[arg(long = "search-embedding-model")]
    search_embedding_model: Option<SearchEmbeddingModel>,

    /// Enable "thinking" for models that support it.
    #[arg(long = "enable-thinking")]
    enable_thinking: bool,

    /// Port to serve the MCP server on; if unset, the MCP server is not started.
    #[arg(long)]
    mcp_port: Option<u16>,

    /// Path to an MCP client configuration JSON file (falls back to `MCP_CONFIG_PATH`).
    #[arg(long)]
    mcp_config: Option<String>,
}
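
// Example invocations (illustrative only: the model subcommand names and their
// flags are defined by `ModelSelected`, so run with `--help` for exact syntax):
//
//   mistralrs-server -i plain -m mistralai/Mistral-7B-Instruct-v0.3
//   mistralrs-server --port 1234 plain -m mistralai/Mistral-7B-Instruct-v0.3
//   mistralrs-server --port 1234 multi-model --config models.json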

/// Parse a `TokenSource` from its string form. The accepted forms are defined
/// by `TokenSource`'s `FromStr` impl in mistralrs_core; values such as `cache`,
/// `none`, `env:<VAR>`, or `literal:<token>` are typical, but treat that list
/// as an assumption and consult the `TokenSource` docs for the authoritative set.
fn parse_token_source(s: &str) -> Result<TokenSource, String> {
    s.parse()
}

/// Parse a `PagedCacheType` from its string form, as defined by its `FromStr` impl.
fn parse_cache_type(s: &str) -> Result<PagedCacheType, String> {
    s.parse()
}

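/// Load the MCP client configuration from the `--mcp-config` flag, falling back
/// to the `MCP_CONFIG_PATH` environment variable when the flag is absent.
///
/// A minimal sketch of the expected JSON, inferred from the fields this file
/// reads (`servers`, `tool_timeout_secs`, `max_concurrent_calls`); the exact
/// schema, in particular how each server's `source` variant (`Http`,
/// `WebSocket`, or `Process`) is tagged, is defined by `McpClientConfig` in
/// mistralrs_core:
///
/// ```json
/// {
///   "servers": [
///     { "id": "example-server", "source": { "type": "Http", "url": "https://example.com/mcp" } }
///   ],
///   "tool_timeout_secs": 30,
///   "max_concurrent_calls": 4
/// }
/// ```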
fn load_mcp_config(mcp_config_path: Option<&str>) -> Result<Option<McpClientConfig>> {
    // Prefer the CLI flag; otherwise fall back to the MCP_CONFIG_PATH env var.
    let config_path = if let Some(path) = mcp_config_path {
        Some(path.to_string())
    } else {
        std::env::var("MCP_CONFIG_PATH").ok()
    };

    if let Some(path) = config_path {
        match std::fs::read_to_string(&path) {
            Ok(config_content) => {
                match serde_json::from_str::<McpClientConfig>(&config_content) {
                    Ok(config) => {
                        // Validate the parsed configuration before using it.
                        if let Err(e) = validate_mcp_config(&config) {
                            error!("MCP configuration validation failed: {}", e);
                            anyhow::bail!("Invalid MCP configuration: {}", e);
                        }

                        info!("Loaded and validated MCP configuration from {}", path);
                        info!("Configured {} MCP servers", config.servers.len());
                        Ok(Some(config))
                    }
                    Err(e) => {
                        error!("Failed to parse MCP configuration: {}", e);
                        error!("Please check your JSON syntax and ensure it matches the MCP configuration schema");
                        anyhow::bail!("Invalid MCP configuration format: {}", e);
                    }
                }
            }
            Err(e) => {
                error!("Failed to read MCP configuration file {}: {}", path, e);
                error!("Please ensure the file exists and is readable");
                anyhow::bail!("Cannot read MCP configuration file: {}", e);
            }
        }
    } else {
        Ok(None)
    }
}

/// Sanity-check an MCP client configuration before it is handed to the engine.
fn validate_mcp_config(config: &McpClientConfig) -> Result<()> {
    use std::collections::HashSet;

    // Reject duplicate server IDs and enforce a conservative ID charset.
    let mut seen_ids = HashSet::new();
    for server in &config.servers {
        if !seen_ids.insert(&server.id) {
            anyhow::bail!("Duplicate server ID: {}", server.id);
        }

        if !server
            .id
            .chars()
            .all(|c| c.is_alphanumeric() || c == '-' || c == '_')
        {
            anyhow::bail!(
                "Invalid server ID '{}': must contain only alphanumeric, hyphen, underscore",
                server.id
            );
        }

        // Basic sanity checks on each server's transport.
        match &server.source {
            mistralrs_core::McpServerSource::Http { url, .. }
            | mistralrs_core::McpServerSource::WebSocket { url, .. } => {
                if !url.starts_with("http://")
                    && !url.starts_with("https://")
                    && !url.starts_with("ws://")
                    && !url.starts_with("wss://")
                {
                    anyhow::bail!(
                        "Invalid URL for server '{}': must start with http://, https://, ws://, or wss://",
                        server.id
                    );
                }
                if url.len() < 10 {
                    anyhow::bail!("Invalid URL for server '{}': URL too short", server.id);
                }
            }
            mistralrs_core::McpServerSource::Process { command, .. } => {
                if command.is_empty() {
                    anyhow::bail!("Empty command for server '{}'", server.id);
                }
            }
        }
    }

    // Zero would disable tool calls outright, so treat it as a configuration error.
    if let Some(timeout) = config.tool_timeout_secs {
        if timeout == 0 {
            anyhow::bail!("tool_timeout_secs must be greater than 0");
        }
    }

    if let Some(max_calls) = config.max_concurrent_calls {
        if max_calls == 0 {
            anyhow::bail!("max_concurrent_calls must be greater than 0");
        }
    }

    Ok(())
}

/// One entry in the multi-model configuration file; see `load_multi_model_config`.
#[derive(Clone, serde::Deserialize)]
struct ModelConfigParsed {
    /// Model selector, flattened into the entry itself.
    #[serde(flatten)]
    model: ModelSelected,
    /// Optional chat template override for this model.
    chat_template: Option<String>,
    /// Optional explicit JINJA chat template for this model.
    jinja_explicit: Option<String>,
    /// Optional device layer mapping for this model.
    num_device_layers: Option<Vec<String>>,
    /// Optional in-situ quantization for this model.
    in_situ_quant: Option<String>,
}

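/// Load a multi-model configuration file: a JSON object mapping model IDs to
/// per-model settings (`ModelConfigParsed`). The map keys become
/// `ModelConfig::model_id`, which is how clients address a specific model.
///
/// A minimal sketch, assuming the flattened `ModelSelected` contributes the
/// selector fields shown as placeholders here (its exact serde shape is
/// defined in mistralrs_core):
///
/// ```json
/// {
///   "model-a": { "...": "<flattened ModelSelected fields>", "chat_template": "template.json" },
///   "model-b": { "...": "<flattened ModelSelected fields>", "in_situ_quant": "Q4K" }
/// }
/// ```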
fn load_multi_model_config(config_path: &str) -> Result<Vec<ModelConfig>> {
    let config_content = std::fs::read_to_string(config_path).map_err(|e| {
        anyhow::anyhow!(
            "Failed to read multi-model config file {}: {}",
            config_path,
            e
        )
    })?;

    // The file maps model ID -> per-model settings.
    let configs_parsed: HashMap<String, ModelConfigParsed> = serde_json::from_str(&config_content)
        .map_err(|e| anyhow::anyhow!("Failed to parse multi-model config: {}", e))?;

    if configs_parsed.is_empty() {
        anyhow::bail!("Multi-model configuration file is empty");
    }

    let mut configs = Vec::new();
    for (model_id, parsed_config) in configs_parsed {
        let config = ModelConfig {
            model_id,
            model: parsed_config.model,
            chat_template: parsed_config.chat_template,
            jinja_explicit: parsed_config.jinja_explicit,
            num_device_layers: parsed_config.num_device_layers,
            in_situ_quant: parsed_config.in_situ_quant,
        };
        configs.push(config);
    }

    info!(
        "Loaded multi-model configuration with {} models",
        configs.len()
    );
    Ok(configs)
}

#[tokio::main]
async fn main() -> Result<()> {
    let args = Args::parse();

    initialize_logging();

    let mcp_config = load_mcp_config(args.mcp_config.as_deref())?;

    // Combine the mutually exclusive --paged-attn/--no-paged-attn flags.
    let paged_attn = configure_paged_attn_from_flags(args.paged_attn, args.no_paged_attn)?;

    let mistralrs = match args.model {
        // Multi-model mode: build every model listed in the config file.
        ModelSelected::MultiModel {
            config,
            default_model_id,
        } => {
            let model_configs = load_multi_model_config(&config)?;

            let mut builder = MistralRsForServerBuilder::new()
                .with_max_seqs(args.max_seqs)
                .with_no_kv_cache(args.no_kv_cache)
                .with_token_source(args.token_source)
                .with_interactive_mode(args.interactive_mode)
                .with_prefix_cache_n(args.prefix_cache_n)
                .set_paged_attn(paged_attn)
                .with_cpu(args.cpu)
                .with_enable_search(args.enable_search)
                .with_seed_optional(args.seed)
                .with_log_optional(args.log)
                .with_mcp_config_optional(mcp_config)
                .with_paged_attn_cache_type(args.cache_type.unwrap_or_default());

            for config in model_configs {
                builder = builder.add_model_config(config);
            }

            if let Some(default_id) = default_model_id {
                builder = builder.with_default_model_id(default_id);
            }

            if let Some(model) = args.search_embedding_model {
                builder = builder.with_search_embedding_model(model);
            }

            builder.build_multi_model().await?
        }
        // Single-model mode: per-model settings come straight from the CLI flags.
        model => {
            let mut builder = MistralRsForServerBuilder::new()
                .with_model(model)
                .with_max_seqs(args.max_seqs)
                .with_no_kv_cache(args.no_kv_cache)
                .with_token_source(args.token_source)
                .with_interactive_mode(args.interactive_mode)
                .with_prefix_cache_n(args.prefix_cache_n)
                .set_paged_attn(paged_attn)
                .with_cpu(args.cpu)
                .with_enable_search(args.enable_search)
                .with_seed_optional(args.seed)
                .with_log_optional(args.log)
                .with_chat_template_optional(args.chat_template)
                .with_jinja_explicit_optional(args.jinja_explicit)
                .with_num_device_layers_optional(args.num_device_layers)
                .with_in_situ_quant_optional(args.in_situ_quant)
                .with_paged_attn_gpu_mem_optional(args.paged_attn_gpu_mem)
                .with_paged_attn_gpu_mem_usage_optional(args.paged_attn_gpu_mem_usage)
                .with_paged_ctxt_len_optional(args.paged_ctxt_len)
                .with_paged_attn_block_size_optional(args.paged_attn_block_size)
                .with_mcp_config_optional(mcp_config)
                .with_paged_attn_cache_type(args.cache_type.unwrap_or_default());

            if let Some(model) = args.search_embedding_model {
                builder = builder.with_search_embedding_model(model);
            }

            builder.build().await?
        }
    };

    let search_embedding_model =
        get_search_embedding_model(args.enable_search, args.search_embedding_model);

    if args.interactive_mode {
        interactive_mode(
            mistralrs,
            search_embedding_model.is_some(),
            args.enable_thinking.then_some(true),
        )
        .await;
        return Ok(());
    }

    // Not interactive, so at least one server port must have been given.
    if args.port.is_none() && args.mcp_port.is_none() {
        anyhow::bail!(
            "Interactive mode was not specified, so a port is required. Perhaps you forgot `-i`, `--port`, or `--mcp-port`?"
        )
    }

    // Optionally start the MCP server on its own task.
    let mcp_handle = if let Some(port) = args.mcp_port {
        let host = args
            .serve_ip
            .clone()
            .unwrap_or_else(|| "0.0.0.0".to_string());
        info!("MCP server listening on http://{host}:{port}/mcp.");
        info!("MCP protocol version is {}.", LATEST_PROTOCOL_VERSION);
        let mcp_server = mcp_server::create_http_mcp_server(mistralrs.clone(), host, port);

        tokio::spawn(async move {
            if let Err(e) = mcp_server.await {
                eprintln!("MCP server error: {e}");
            }
        })
    } else {
        tokio::spawn(async {})
    };

    // Optionally start the OpenAI-compatible HTTP server on its own task.
    let oai_handle = if let Some(port) = args.port {
        let ip = args
            .serve_ip
            .clone()
            .unwrap_or_else(|| "0.0.0.0".to_string());

        let listener = tokio::net::TcpListener::bind(format!("{ip}:{port}")).await?;

        let app = MistralRsServerRouterBuilder::new()
            .with_mistralrs(mistralrs)
            .build()
            .await?;

        info!("OpenAI-compatible server listening on http://{ip}:{port}.");

        tokio::spawn(async move {
            if let Err(e) = axum::serve(listener, app).await {
                eprintln!("OpenAI server error: {e}");
            }
        })
    } else {
        tokio::spawn(async {})
    };

    // Run until both server tasks exit.
    let (_, _) = join!(oai_handle, mcp_handle);

    Ok(())
}