1use crate::tools::{Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType};
2use crate::transport::{HttpTransport, McpTransport, ProcessTransport, WebSocketTransport};
3use crate::types::McpToolResult;
4use crate::{McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo};
5use anyhow::Result;
6use rust_mcp_schema::Resource;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::Semaphore;
12
13#[async_trait::async_trait]
15pub trait McpServerConnection: Send + Sync {
16 fn server_id(&self) -> &str;
18
19 fn server_name(&self) -> &str;
21
22 async fn list_tools(&self) -> Result<Vec<McpToolInfo>>;
24
25 async fn call_tool(&self, name: &str, arguments: serde_json::Value) -> Result<String>;
27
28 async fn list_resources(&self) -> Result<Vec<Resource>>;
30
31 async fn read_resource(&self, uri: &str) -> Result<String>;
33
34 async fn ping(&self) -> Result<()>;
36}
37
38pub struct McpClient {
72 config: McpClientConfig,
74 servers: HashMap<String, Arc<dyn McpServerConnection>>,
76 tools: HashMap<String, McpToolInfo>,
78 tool_callbacks: HashMap<String, Arc<ToolCallback>>,
80 tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
82 concurrency_semaphore: Arc<Semaphore>,
84}
85
86impl McpClient {
87 pub fn new(config: McpClientConfig) -> Self {
89 let max_concurrent = config.max_concurrent_calls.unwrap_or(10);
90 Self {
91 config,
92 servers: HashMap::new(),
93 tools: HashMap::new(),
94 tool_callbacks: HashMap::new(),
95 tool_callbacks_with_tools: HashMap::new(),
96 concurrency_semaphore: Arc::new(Semaphore::new(max_concurrent)),
97 }
98 }
99
100 pub async fn initialize(&mut self) -> Result<()> {
102 for server_config in &self.config.servers {
103 if server_config.enabled {
104 let connection = self.create_connection(server_config).await?;
105 self.servers.insert(server_config.id.clone(), connection);
106 }
107 }
108
109 if self.config.auto_register_tools {
110 self.discover_and_register_tools().await?;
111 }
112
113 Ok(())
114 }
115
116 pub fn get_tool_callbacks(&self) -> &HashMap<String, Arc<ToolCallback>> {
118 &self.tool_callbacks
119 }
120
121 pub fn get_tool_callbacks_with_tools(&self) -> &HashMap<String, ToolCallbackWithTool> {
123 &self.tool_callbacks_with_tools
124 }
125
126 pub fn get_tools(&self) -> &HashMap<String, McpToolInfo> {
128 &self.tools
129 }
130
131 async fn create_connection(
133 &self,
134 config: &McpServerConfig,
135 ) -> Result<Arc<dyn McpServerConnection>> {
136 match &config.source {
137 McpServerSource::Http {
138 url,
139 timeout_secs,
140 headers,
141 } => {
142 let mut merged_headers = headers.clone().unwrap_or_default();
144 if let Some(token) = &config.bearer_token {
145 merged_headers.insert("Authorization".to_string(), format!("Bearer {token}"));
146 }
147
148 let connection = HttpMcpConnection::new(
149 config.id.clone(),
150 config.name.clone(),
151 url.clone(),
152 *timeout_secs,
153 Some(merged_headers),
154 )
155 .await?;
156 Ok(Arc::new(connection))
157 }
158 McpServerSource::Process {
159 command,
160 args,
161 work_dir,
162 env,
163 } => {
164 let connection = ProcessMcpConnection::new(
165 config.id.clone(),
166 config.name.clone(),
167 command.clone(),
168 args.clone(),
169 work_dir.clone(),
170 env.clone(),
171 )
172 .await?;
173 Ok(Arc::new(connection))
174 }
175 McpServerSource::WebSocket {
176 url,
177 timeout_secs,
178 headers,
179 } => {
180 let mut merged_headers = headers.clone().unwrap_or_default();
182 if let Some(token) = &config.bearer_token {
183 merged_headers.insert("Authorization".to_string(), format!("Bearer {token}"));
184 }
185
186 let connection = WebSocketMcpConnection::new(
187 config.id.clone(),
188 config.name.clone(),
189 url.clone(),
190 *timeout_secs,
191 Some(merged_headers),
192 )
193 .await?;
194 Ok(Arc::new(connection))
195 }
196 }
197 }
198
199 async fn discover_and_register_tools(&mut self) -> Result<()> {
201 for (server_id, connection) in &self.servers {
202 let tools = connection.list_tools().await?;
203 let server_config = self
204 .config
205 .servers
206 .iter()
207 .find(|s| &s.id == server_id)
208 .ok_or_else(|| anyhow::anyhow!("Server config not found for {}", server_id))?;
209
210 for tool in tools {
211 let tool_name = if let Some(prefix) = &server_config.tool_prefix {
212 format!("{}_{}", prefix, tool.name)
213 } else {
214 tool.name.clone()
215 };
216
217 let connection_clone = Arc::clone(connection);
219 let original_tool_name = tool.name.clone();
220 let semaphore_clone = Arc::clone(&self.concurrency_semaphore);
221 let timeout_duration =
222 Duration::from_secs(self.config.tool_timeout_secs.unwrap_or(30));
223
224 let callback: Arc<ToolCallback> = Arc::new(move |called_function| {
225 let connection = Arc::clone(&connection_clone);
226 let tool_name = original_tool_name.clone();
227 let semaphore = Arc::clone(&semaphore_clone);
228 let arguments: serde_json::Value =
229 serde_json::from_str(&called_function.arguments)?;
230
231 let rt = tokio::runtime::Handle::current();
233 std::thread::spawn(move || {
234 rt.block_on(async move {
235 let _permit = semaphore.acquire().await.map_err(|_| {
237 anyhow::anyhow!("Failed to acquire concurrency permit")
238 })?;
239
240 match tokio::time::timeout(
242 timeout_duration,
243 connection.call_tool(&tool_name, arguments),
244 )
245 .await
246 {
247 Ok(result) => result,
248 Err(_) => Err(anyhow::anyhow!(
249 "Tool call timed out after {} seconds",
250 timeout_duration.as_secs()
251 )),
252 }
253 })
254 })
255 .join()
256 .map_err(|_| anyhow::anyhow!("Tool call thread panicked"))?
257 });
258
259 let function_def = Function {
261 name: tool_name.clone(),
262 description: tool.description.clone(),
263 parameters: Self::convert_mcp_schema_to_parameters(&tool.input_schema),
264 };
265
266 let tool_def = Tool {
267 tp: ToolType::Function,
268 function: function_def,
269 };
270
271 self.tool_callbacks
273 .insert(tool_name.clone(), callback.clone());
274 self.tool_callbacks_with_tools.insert(
275 tool_name.clone(),
276 ToolCallbackWithTool {
277 callback,
278 tool: tool_def,
279 },
280 );
281 self.tools.insert(tool_name, tool);
282 }
283 }
284
285 Ok(())
286 }
287
288 fn convert_mcp_schema_to_parameters(
290 schema: &serde_json::Value,
291 ) -> Option<HashMap<String, serde_json::Value>> {
292 match schema {
294 serde_json::Value::Object(obj) => {
295 let mut params = HashMap::new();
296
297 if let Some(properties) = obj.get("properties") {
299 if let serde_json::Value::Object(props) = properties {
300 for (key, value) in props {
301 params.insert(key.clone(), value.clone());
302 }
303 }
304 } else {
305 for (key, value) in obj {
307 params.insert(key.clone(), value.clone());
308 }
309 }
310
311 if params.is_empty() {
312 None
313 } else {
314 Some(params)
315 }
316 }
317 _ => {
318 None
320 }
321 }
322 }
323}
324
325pub struct HttpMcpConnection {
327 server_id: String,
328 server_name: String,
329 transport: Arc<dyn McpTransport>,
330}
331
332impl HttpMcpConnection {
333 pub async fn new(
334 server_id: String,
335 server_name: String,
336 url: String,
337 timeout_secs: Option<u64>,
338 headers: Option<HashMap<String, String>>,
339 ) -> Result<Self> {
340 let transport = HttpTransport::new(url, timeout_secs, headers)?;
341
342 let connection = Self {
343 server_id,
344 server_name,
345 transport: Arc::new(transport),
346 };
347
348 connection.initialize().await?;
350
351 Ok(connection)
352 }
353
354 async fn initialize(&self) -> Result<()> {
355 let init_params = serde_json::json!({
356 "protocolVersion": "2025-03-26",
357 "capabilities": {
358 "tools": {}
359 },
360 "clientInfo": {
361 "name": "mistral.rs",
362 "version": "0.6.0"
363 }
364 });
365
366 self.transport
367 .send_request("initialize", init_params)
368 .await?;
369 self.transport.send_initialization_notification().await?;
370 Ok(())
371 }
372}
373
374#[async_trait::async_trait]
375impl McpServerConnection for HttpMcpConnection {
376 fn server_id(&self) -> &str {
377 &self.server_id
378 }
379
380 fn server_name(&self) -> &str {
381 &self.server_name
382 }
383
384 async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
385 let result = self
386 .transport
387 .send_request("tools/list", Value::Null)
388 .await?;
389
390 let tools = result
391 .get("tools")
392 .and_then(|t| t.as_array())
393 .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
394
395 let mut tool_infos = Vec::new();
396 for tool in tools {
397 let name = tool
398 .get("name")
399 .and_then(|n| n.as_str())
400 .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
401 .to_string();
402
403 let description = tool
404 .get("description")
405 .and_then(|d| d.as_str())
406 .map(|s| s.to_string());
407
408 let input_schema = tool
409 .get("inputSchema")
410 .cloned()
411 .unwrap_or(Value::Object(serde_json::Map::new()));
412
413 tool_infos.push(McpToolInfo {
414 name,
415 description,
416 input_schema,
417 server_id: self.server_id.clone(),
418 server_name: self.server_name.clone(),
419 });
420 }
421
422 Ok(tool_infos)
423 }
424
425 async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
426 let params = serde_json::json!({
427 "name": name,
428 "arguments": arguments
429 });
430
431 let result = self.transport.send_request("tools/call", params).await?;
432
433 let tool_result: McpToolResult = serde_json::from_value(result)?;
435
436 if tool_result.is_error.unwrap_or(false) {
438 return Err(anyhow::anyhow!("Tool execution failed: {tool_result}"));
439 }
440
441 Ok(tool_result.to_string())
442 }
443
444 async fn list_resources(&self) -> Result<Vec<Resource>> {
445 let result = self
446 .transport
447 .send_request("resources/list", Value::Null)
448 .await?;
449
450 let resources = result
451 .get("resources")
452 .and_then(|r| r.as_array())
453 .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
454
455 let mut resource_list = Vec::new();
456 for resource in resources {
457 let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
458 resource_list.push(mcp_resource);
459 }
460
461 Ok(resource_list)
462 }
463
464 async fn read_resource(&self, uri: &str) -> Result<String> {
465 let params = serde_json::json!({ "uri": uri });
466 let result = self
467 .transport
468 .send_request("resources/read", params)
469 .await?;
470
471 if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
473 if let Some(first_content) = contents.first() {
474 if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
475 return Ok(text.to_string());
476 }
477 }
478 }
479
480 Err(anyhow::anyhow!("No readable content found in resource"))
481 }
482
483 async fn ping(&self) -> Result<()> {
484 self.transport.send_request("ping", Value::Null).await?;
486 Ok(())
487 }
488}
489
490pub struct ProcessMcpConnection {
492 server_id: String,
493 server_name: String,
494 transport: Arc<dyn McpTransport>,
495}
496
497impl ProcessMcpConnection {
498 pub async fn new(
499 server_id: String,
500 server_name: String,
501 command: String,
502 args: Vec<String>,
503 work_dir: Option<String>,
504 env: Option<HashMap<String, String>>,
505 ) -> Result<Self> {
506 let transport = ProcessTransport::new(command, args, work_dir, env).await?;
507
508 let connection = Self {
509 server_id,
510 server_name,
511 transport: Arc::new(transport),
512 };
513
514 connection.initialize().await?;
516
517 Ok(connection)
518 }
519
520 async fn initialize(&self) -> Result<()> {
521 let init_params = serde_json::json!({
522 "protocolVersion": "2025-03-26",
523 "capabilities": {
524 "tools": {}
525 },
526 "clientInfo": {
527 "name": "mistral.rs",
528 "version": "0.6.0"
529 }
530 });
531
532 self.transport
533 .send_request("initialize", init_params)
534 .await?;
535 self.transport.send_initialization_notification().await?;
536 Ok(())
537 }
538}
539
540#[async_trait::async_trait]
541impl McpServerConnection for ProcessMcpConnection {
542 fn server_id(&self) -> &str {
543 &self.server_id
544 }
545
546 fn server_name(&self) -> &str {
547 &self.server_name
548 }
549
550 async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
551 let result = self
552 .transport
553 .send_request("tools/list", Value::Null)
554 .await?;
555
556 let tools = result
557 .get("tools")
558 .and_then(|t| t.as_array())
559 .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
560
561 let mut tool_infos = Vec::new();
562 for tool in tools {
563 let name = tool
564 .get("name")
565 .and_then(|n| n.as_str())
566 .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
567 .to_string();
568
569 let description = tool
570 .get("description")
571 .and_then(|d| d.as_str())
572 .map(|s| s.to_string());
573
574 let input_schema = tool
575 .get("inputSchema")
576 .cloned()
577 .unwrap_or(Value::Object(serde_json::Map::new()));
578
579 tool_infos.push(McpToolInfo {
580 name,
581 description,
582 input_schema,
583 server_id: self.server_id.clone(),
584 server_name: self.server_name.clone(),
585 });
586 }
587
588 Ok(tool_infos)
589 }
590
591 async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
592 let params = serde_json::json!({
593 "name": name,
594 "arguments": arguments
595 });
596
597 let result = self.transport.send_request("tools/call", params).await?;
598
599 let tool_result: McpToolResult = serde_json::from_value(result)?;
601
602 if tool_result.is_error.unwrap_or(false) {
604 return Err(anyhow::anyhow!("Tool execution failed: {tool_result}"));
605 }
606
607 Ok(tool_result.to_string())
608 }
609
610 async fn list_resources(&self) -> Result<Vec<Resource>> {
611 let result = self
612 .transport
613 .send_request("resources/list", Value::Null)
614 .await?;
615
616 let resources = result
617 .get("resources")
618 .and_then(|r| r.as_array())
619 .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
620
621 let mut resource_list = Vec::new();
622 for resource in resources {
623 let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
624 resource_list.push(mcp_resource);
625 }
626
627 Ok(resource_list)
628 }
629
630 async fn read_resource(&self, uri: &str) -> Result<String> {
631 let params = serde_json::json!({ "uri": uri });
632 let result = self
633 .transport
634 .send_request("resources/read", params)
635 .await?;
636
637 if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
639 if let Some(first_content) = contents.first() {
640 if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
641 return Ok(text.to_string());
642 }
643 }
644 }
645
646 Err(anyhow::anyhow!("No readable content found in resource"))
647 }
648
649 async fn ping(&self) -> Result<()> {
650 self.transport.send_request("ping", Value::Null).await?;
652 Ok(())
653 }
654}
655
656pub struct WebSocketMcpConnection {
658 server_id: String,
659 server_name: String,
660 transport: Arc<dyn McpTransport>,
661}
662
663impl WebSocketMcpConnection {
664 pub async fn new(
665 server_id: String,
666 server_name: String,
667 url: String,
668 timeout_secs: Option<u64>,
669 headers: Option<HashMap<String, String>>,
670 ) -> Result<Self> {
671 let transport = WebSocketTransport::new(url, timeout_secs, headers).await?;
672
673 let connection = Self {
674 server_id,
675 server_name,
676 transport: Arc::new(transport),
677 };
678
679 connection.initialize().await?;
681
682 Ok(connection)
683 }
684
685 async fn initialize(&self) -> Result<()> {
686 let init_params = serde_json::json!({
687 "protocolVersion": "2025-03-26",
688 "capabilities": {
689 "tools": {}
690 },
691 "clientInfo": {
692 "name": "mistral.rs",
693 "version": "0.6.0"
694 }
695 });
696
697 self.transport
698 .send_request("initialize", init_params)
699 .await?;
700 self.transport.send_initialization_notification().await?;
701 Ok(())
702 }
703}
704
705#[async_trait::async_trait]
706impl McpServerConnection for WebSocketMcpConnection {
707 fn server_id(&self) -> &str {
708 &self.server_id
709 }
710
711 fn server_name(&self) -> &str {
712 &self.server_name
713 }
714
715 async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
716 let result = self
717 .transport
718 .send_request("tools/list", Value::Null)
719 .await?;
720
721 let tools = result
722 .get("tools")
723 .and_then(|t| t.as_array())
724 .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
725
726 let mut tool_infos = Vec::new();
727 for tool in tools {
728 let name = tool
729 .get("name")
730 .and_then(|n| n.as_str())
731 .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
732 .to_string();
733
734 let description = tool
735 .get("description")
736 .and_then(|d| d.as_str())
737 .map(|s| s.to_string());
738
739 let input_schema = tool
740 .get("inputSchema")
741 .cloned()
742 .unwrap_or(Value::Object(serde_json::Map::new()));
743
744 tool_infos.push(McpToolInfo {
745 name,
746 description,
747 input_schema,
748 server_id: self.server_id.clone(),
749 server_name: self.server_name.clone(),
750 });
751 }
752
753 Ok(tool_infos)
754 }
755
756 async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
757 let params = serde_json::json!({
758 "name": name,
759 "arguments": arguments
760 });
761
762 let result = self.transport.send_request("tools/call", params).await?;
763
764 let tool_result: McpToolResult = serde_json::from_value(result)?;
766
767 if tool_result.is_error.unwrap_or(false) {
769 return Err(anyhow::anyhow!("Tool execution failed: {tool_result}"));
770 }
771
772 Ok(tool_result.to_string())
773 }
774
775 async fn list_resources(&self) -> Result<Vec<Resource>> {
776 let result = self
777 .transport
778 .send_request("resources/list", Value::Null)
779 .await?;
780
781 let resources = result
782 .get("resources")
783 .and_then(|r| r.as_array())
784 .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
785
786 let mut resource_list = Vec::new();
787 for resource in resources {
788 let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
789 resource_list.push(mcp_resource);
790 }
791
792 Ok(resource_list)
793 }
794
795 async fn read_resource(&self, uri: &str) -> Result<String> {
796 let params = serde_json::json!({ "uri": uri });
797 let result = self
798 .transport
799 .send_request("resources/read", params)
800 .await?;
801
802 if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
804 if let Some(first_content) = contents.first() {
805 if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
806 return Ok(text.to_string());
807 }
808 }
809 }
810
811 Err(anyhow::anyhow!("No readable content found in resource"))
812 }
813
814 async fn ping(&self) -> Result<()> {
815 self.transport.send_request("ping", Value::Null).await?;
817 Ok(())
818 }
819}