1use crate::tools::{Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType};
2use crate::transport::{HttpTransport, McpTransport, ProcessTransport, WebSocketTransport};
3use crate::types::McpToolResult;
4use crate::{McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo};
5use anyhow::Result;
6use rust_mcp_schema::Resource;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::Semaphore;
12
13#[async_trait::async_trait]
15pub trait McpServerConnection: Send + Sync {
16 fn server_id(&self) -> &str;
18
19 fn server_name(&self) -> &str;
21
22 async fn list_tools(&self) -> Result<Vec<McpToolInfo>>;
24
25 async fn call_tool(&self, name: &str, arguments: serde_json::Value) -> Result<String>;
27
28 async fn list_resources(&self) -> Result<Vec<Resource>>;
30
31 async fn read_resource(&self, uri: &str) -> Result<String>;
33
34 async fn ping(&self) -> Result<()>;
36}
37
38pub struct McpClient {
72 config: McpClientConfig,
74 servers: HashMap<String, Arc<dyn McpServerConnection>>,
76 tools: HashMap<String, McpToolInfo>,
78 tool_callbacks: HashMap<String, Arc<ToolCallback>>,
80 tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
82 concurrency_semaphore: Arc<Semaphore>,
84}
85
86impl McpClient {
87 pub fn new(config: McpClientConfig) -> Self {
89 let max_concurrent = config.max_concurrent_calls.unwrap_or(10);
90 Self {
91 config,
92 servers: HashMap::new(),
93 tools: HashMap::new(),
94 tool_callbacks: HashMap::new(),
95 tool_callbacks_with_tools: HashMap::new(),
96 concurrency_semaphore: Arc::new(Semaphore::new(max_concurrent)),
97 }
98 }
99
100 pub async fn initialize(&mut self) -> Result<()> {
102 for server_config in &self.config.servers {
103 if server_config.enabled {
104 let connection = self.create_connection(server_config).await?;
105 self.servers.insert(server_config.id.clone(), connection);
106 }
107 }
108
109 if self.config.auto_register_tools {
110 self.discover_and_register_tools().await?;
111 }
112
113 Ok(())
114 }
115
116 pub fn get_tool_callbacks(&self) -> &HashMap<String, Arc<ToolCallback>> {
118 &self.tool_callbacks
119 }
120
121 pub fn get_tool_callbacks_with_tools(&self) -> &HashMap<String, ToolCallbackWithTool> {
123 &self.tool_callbacks_with_tools
124 }
125
126 pub fn get_tools(&self) -> &HashMap<String, McpToolInfo> {
128 &self.tools
129 }
130
131 async fn create_connection(
133 &self,
134 config: &McpServerConfig,
135 ) -> Result<Arc<dyn McpServerConnection>> {
136 match &config.source {
137 McpServerSource::Http {
138 url,
139 timeout_secs,
140 headers,
141 } => {
142 let mut merged_headers = headers.clone().unwrap_or_default();
144 if let Some(token) = &config.bearer_token {
145 merged_headers.insert("Authorization".to_string(), format!("Bearer {token}"));
146 }
147
148 let connection = HttpMcpConnection::new(
149 config.id.clone(),
150 config.name.clone(),
151 url.clone(),
152 *timeout_secs,
153 Some(merged_headers),
154 )
155 .await?;
156 Ok(Arc::new(connection))
157 }
158 McpServerSource::Process {
159 command,
160 args,
161 work_dir,
162 env,
163 } => {
164 let connection = ProcessMcpConnection::new(
165 config.id.clone(),
166 config.name.clone(),
167 command.clone(),
168 args.clone(),
169 work_dir.clone(),
170 env.clone(),
171 )
172 .await?;
173 Ok(Arc::new(connection))
174 }
175 McpServerSource::WebSocket {
176 url,
177 timeout_secs,
178 headers,
179 } => {
180 let mut merged_headers = headers.clone().unwrap_or_default();
182 if let Some(token) = &config.bearer_token {
183 merged_headers.insert("Authorization".to_string(), format!("Bearer {token}"));
184 }
185
186 let connection = WebSocketMcpConnection::new(
187 config.id.clone(),
188 config.name.clone(),
189 url.clone(),
190 *timeout_secs,
191 Some(merged_headers),
192 )
193 .await?;
194 Ok(Arc::new(connection))
195 }
196 }
197 }
198
199 async fn discover_and_register_tools(&mut self) -> Result<()> {
201 for (server_id, connection) in &self.servers {
202 let tools = connection.list_tools().await?;
203 let server_config = self
204 .config
205 .servers
206 .iter()
207 .find(|s| &s.id == server_id)
208 .ok_or_else(|| anyhow::anyhow!("Server config not found for {}", server_id))?;
209
210 for tool in tools {
211 let tool_name = if let Some(prefix) = &server_config.tool_prefix {
212 format!("{}_{}", prefix, tool.name)
213 } else {
214 tool.name.clone()
215 };
216
217 let connection_clone = Arc::clone(connection);
219 let original_tool_name = tool.name.clone();
220 let semaphore_clone = Arc::clone(&self.concurrency_semaphore);
221 let timeout_duration =
222 Duration::from_secs(self.config.tool_timeout_secs.unwrap_or(30));
223
224 let callback: Arc<ToolCallback> = Arc::new(move |called_function| {
225 let connection = Arc::clone(&connection_clone);
226 let tool_name = original_tool_name.clone();
227 let semaphore = Arc::clone(&semaphore_clone);
228 let arguments: serde_json::Value =
229 serde_json::from_str(&called_function.arguments)?;
230
231 let rt = tokio::runtime::Handle::current();
233 std::thread::spawn(move || {
234 rt.block_on(async move {
235 let _permit = semaphore.acquire().await.map_err(|_| {
237 anyhow::anyhow!("Failed to acquire concurrency permit")
238 })?;
239
240 match tokio::time::timeout(
242 timeout_duration,
243 connection.call_tool(&tool_name, arguments),
244 )
245 .await
246 {
247 Ok(result) => result,
248 Err(_) => Err(anyhow::anyhow!(
249 "Tool call timed out after {} seconds",
250 timeout_duration.as_secs()
251 )),
252 }
253 })
254 })
255 .join()
256 .map_err(|_| anyhow::anyhow!("Tool call thread panicked"))?
257 });
258
259 let function_def = Function {
261 name: tool_name.clone(),
262 description: tool.description.clone(),
263 parameters: Self::convert_mcp_schema_to_parameters(&tool.input_schema),
264 };
265
266 let tool_def = Tool {
267 tp: ToolType::Function,
268 function: function_def,
269 };
270
271 self.tool_callbacks
273 .insert(tool_name.clone(), callback.clone());
274 self.tool_callbacks_with_tools.insert(
275 tool_name.clone(),
276 ToolCallbackWithTool {
277 callback,
278 tool: tool_def,
279 },
280 );
281 self.tools.insert(tool_name, tool);
282 }
283 }
284
285 Ok(())
286 }
287
288 fn convert_mcp_schema_to_parameters(
290 schema: &serde_json::Value,
291 ) -> Option<HashMap<String, serde_json::Value>> {
292 match schema {
294 serde_json::Value::Object(obj) => {
295 let mut params = HashMap::new();
296
297 if let Some(properties) = obj.get("properties") {
299 if let serde_json::Value::Object(props) = properties {
300 for (key, value) in props {
301 params.insert(key.clone(), value.clone());
302 }
303 }
304 } else {
305 for (key, value) in obj {
307 params.insert(key.clone(), value.clone());
308 }
309 }
310
311 if params.is_empty() {
312 None
313 } else {
314 Some(params)
315 }
316 }
317 _ => {
318 None
320 }
321 }
322 }
323}
324
325pub struct HttpMcpConnection {
327 server_id: String,
328 server_name: String,
329 transport: Arc<dyn McpTransport>,
330}
331
332impl HttpMcpConnection {
333 pub async fn new(
334 server_id: String,
335 server_name: String,
336 url: String,
337 timeout_secs: Option<u64>,
338 headers: Option<HashMap<String, String>>,
339 ) -> Result<Self> {
340 let transport = HttpTransport::new(url, timeout_secs, headers)?;
341
342 let connection = Self {
343 server_id,
344 server_name,
345 transport: Arc::new(transport),
346 };
347
348 connection.initialize().await?;
350
351 Ok(connection)
352 }
353
354 async fn initialize(&self) -> Result<()> {
355 let init_params = serde_json::json!({
356 "protocolVersion": "2025-03-26",
357 "capabilities": {
358 "tools": {}
359 },
360 "clientInfo": {
361 "name": "mistral.rs",
362 "version": "0.6.0"
363 }
364 });
365
366 self.transport
367 .send_request("initialize", init_params)
368 .await?;
369 self.transport.send_initialization_notification().await?;
370 Ok(())
371 }
372}
373
374#[async_trait::async_trait]
375impl McpServerConnection for HttpMcpConnection {
376 fn server_id(&self) -> &str {
377 &self.server_id
378 }
379
380 fn server_name(&self) -> &str {
381 &self.server_name
382 }
383
384 async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
385 let result = self
386 .transport
387 .send_request("tools/list", Value::Null)
388 .await?;
389
390 let tools = result
391 .get("tools")
392 .and_then(|t| t.as_array())
393 .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
394
395 let mut tool_infos = Vec::new();
396 for tool in tools {
397 let name = tool
398 .get("name")
399 .and_then(|n| n.as_str())
400 .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
401 .to_string();
402
403 let description = tool
404 .get("description")
405 .and_then(|d| d.as_str())
406 .map(|s| s.to_string());
407
408 let input_schema = tool
409 .get("inputSchema")
410 .cloned()
411 .unwrap_or(Value::Object(serde_json::Map::new()));
412
413 tool_infos.push(McpToolInfo {
414 name,
415 description,
416 input_schema,
417 server_id: self.server_id.clone(),
418 server_name: self.server_name.clone(),
419 });
420 }
421
422 Ok(tool_infos)
423 }
424
425 async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
426 let params = serde_json::json!({
427 "name": name,
428 "arguments": arguments
429 });
430
431 let result = self.transport.send_request("tools/call", params).await?;
432
433 let tool_result: McpToolResult = serde_json::from_value(result)?;
435
436 if tool_result.is_error.unwrap_or(false) {
438 return Err(anyhow::anyhow!(
439 "Tool execution failed: {}",
440 tool_result.to_string()
441 ));
442 }
443
444 Ok(tool_result.to_string())
445 }
446
447 async fn list_resources(&self) -> Result<Vec<Resource>> {
448 let result = self
449 .transport
450 .send_request("resources/list", Value::Null)
451 .await?;
452
453 let resources = result
454 .get("resources")
455 .and_then(|r| r.as_array())
456 .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
457
458 let mut resource_list = Vec::new();
459 for resource in resources {
460 let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
461 resource_list.push(mcp_resource);
462 }
463
464 Ok(resource_list)
465 }
466
467 async fn read_resource(&self, uri: &str) -> Result<String> {
468 let params = serde_json::json!({ "uri": uri });
469 let result = self
470 .transport
471 .send_request("resources/read", params)
472 .await?;
473
474 if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
476 if let Some(first_content) = contents.first() {
477 if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
478 return Ok(text.to_string());
479 }
480 }
481 }
482
483 Err(anyhow::anyhow!("No readable content found in resource"))
484 }
485
486 async fn ping(&self) -> Result<()> {
487 self.transport.send_request("ping", Value::Null).await?;
489 Ok(())
490 }
491}
492
493pub struct ProcessMcpConnection {
495 server_id: String,
496 server_name: String,
497 transport: Arc<dyn McpTransport>,
498}
499
500impl ProcessMcpConnection {
501 pub async fn new(
502 server_id: String,
503 server_name: String,
504 command: String,
505 args: Vec<String>,
506 work_dir: Option<String>,
507 env: Option<HashMap<String, String>>,
508 ) -> Result<Self> {
509 let transport = ProcessTransport::new(command, args, work_dir, env).await?;
510
511 let connection = Self {
512 server_id,
513 server_name,
514 transport: Arc::new(transport),
515 };
516
517 connection.initialize().await?;
519
520 Ok(connection)
521 }
522
523 async fn initialize(&self) -> Result<()> {
524 let init_params = serde_json::json!({
525 "protocolVersion": "2025-03-26",
526 "capabilities": {
527 "tools": {}
528 },
529 "clientInfo": {
530 "name": "mistral.rs",
531 "version": "0.6.0"
532 }
533 });
534
535 self.transport
536 .send_request("initialize", init_params)
537 .await?;
538 self.transport.send_initialization_notification().await?;
539 Ok(())
540 }
541}
542
543#[async_trait::async_trait]
544impl McpServerConnection for ProcessMcpConnection {
545 fn server_id(&self) -> &str {
546 &self.server_id
547 }
548
549 fn server_name(&self) -> &str {
550 &self.server_name
551 }
552
553 async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
554 let result = self
555 .transport
556 .send_request("tools/list", Value::Null)
557 .await?;
558
559 let tools = result
560 .get("tools")
561 .and_then(|t| t.as_array())
562 .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
563
564 let mut tool_infos = Vec::new();
565 for tool in tools {
566 let name = tool
567 .get("name")
568 .and_then(|n| n.as_str())
569 .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
570 .to_string();
571
572 let description = tool
573 .get("description")
574 .and_then(|d| d.as_str())
575 .map(|s| s.to_string());
576
577 let input_schema = tool
578 .get("inputSchema")
579 .cloned()
580 .unwrap_or(Value::Object(serde_json::Map::new()));
581
582 tool_infos.push(McpToolInfo {
583 name,
584 description,
585 input_schema,
586 server_id: self.server_id.clone(),
587 server_name: self.server_name.clone(),
588 });
589 }
590
591 Ok(tool_infos)
592 }
593
594 async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
595 let params = serde_json::json!({
596 "name": name,
597 "arguments": arguments
598 });
599
600 let result = self.transport.send_request("tools/call", params).await?;
601
602 let tool_result: McpToolResult = serde_json::from_value(result)?;
604
605 if tool_result.is_error.unwrap_or(false) {
607 return Err(anyhow::anyhow!(
608 "Tool execution failed: {}",
609 tool_result.to_string()
610 ));
611 }
612
613 Ok(tool_result.to_string())
614 }
615
616 async fn list_resources(&self) -> Result<Vec<Resource>> {
617 let result = self
618 .transport
619 .send_request("resources/list", Value::Null)
620 .await?;
621
622 let resources = result
623 .get("resources")
624 .and_then(|r| r.as_array())
625 .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
626
627 let mut resource_list = Vec::new();
628 for resource in resources {
629 let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
630 resource_list.push(mcp_resource);
631 }
632
633 Ok(resource_list)
634 }
635
636 async fn read_resource(&self, uri: &str) -> Result<String> {
637 let params = serde_json::json!({ "uri": uri });
638 let result = self
639 .transport
640 .send_request("resources/read", params)
641 .await?;
642
643 if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
645 if let Some(first_content) = contents.first() {
646 if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
647 return Ok(text.to_string());
648 }
649 }
650 }
651
652 Err(anyhow::anyhow!("No readable content found in resource"))
653 }
654
655 async fn ping(&self) -> Result<()> {
656 self.transport.send_request("ping", Value::Null).await?;
658 Ok(())
659 }
660}
661
662pub struct WebSocketMcpConnection {
664 server_id: String,
665 server_name: String,
666 transport: Arc<dyn McpTransport>,
667}
668
669impl WebSocketMcpConnection {
670 pub async fn new(
671 server_id: String,
672 server_name: String,
673 url: String,
674 timeout_secs: Option<u64>,
675 headers: Option<HashMap<String, String>>,
676 ) -> Result<Self> {
677 let transport = WebSocketTransport::new(url, timeout_secs, headers).await?;
678
679 let connection = Self {
680 server_id,
681 server_name,
682 transport: Arc::new(transport),
683 };
684
685 connection.initialize().await?;
687
688 Ok(connection)
689 }
690
691 async fn initialize(&self) -> Result<()> {
692 let init_params = serde_json::json!({
693 "protocolVersion": "2025-03-26",
694 "capabilities": {
695 "tools": {}
696 },
697 "clientInfo": {
698 "name": "mistral.rs",
699 "version": "0.6.0"
700 }
701 });
702
703 self.transport
704 .send_request("initialize", init_params)
705 .await?;
706 self.transport.send_initialization_notification().await?;
707 Ok(())
708 }
709}
710
711#[async_trait::async_trait]
712impl McpServerConnection for WebSocketMcpConnection {
713 fn server_id(&self) -> &str {
714 &self.server_id
715 }
716
717 fn server_name(&self) -> &str {
718 &self.server_name
719 }
720
721 async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
722 let result = self
723 .transport
724 .send_request("tools/list", Value::Null)
725 .await?;
726
727 let tools = result
728 .get("tools")
729 .and_then(|t| t.as_array())
730 .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
731
732 let mut tool_infos = Vec::new();
733 for tool in tools {
734 let name = tool
735 .get("name")
736 .and_then(|n| n.as_str())
737 .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
738 .to_string();
739
740 let description = tool
741 .get("description")
742 .and_then(|d| d.as_str())
743 .map(|s| s.to_string());
744
745 let input_schema = tool
746 .get("inputSchema")
747 .cloned()
748 .unwrap_or(Value::Object(serde_json::Map::new()));
749
750 tool_infos.push(McpToolInfo {
751 name,
752 description,
753 input_schema,
754 server_id: self.server_id.clone(),
755 server_name: self.server_name.clone(),
756 });
757 }
758
759 Ok(tool_infos)
760 }
761
762 async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
763 let params = serde_json::json!({
764 "name": name,
765 "arguments": arguments
766 });
767
768 let result = self.transport.send_request("tools/call", params).await?;
769
770 let tool_result: McpToolResult = serde_json::from_value(result)?;
772
773 if tool_result.is_error.unwrap_or(false) {
775 return Err(anyhow::anyhow!(
776 "Tool execution failed: {}",
777 tool_result.to_string()
778 ));
779 }
780
781 Ok(tool_result.to_string())
782 }
783
784 async fn list_resources(&self) -> Result<Vec<Resource>> {
785 let result = self
786 .transport
787 .send_request("resources/list", Value::Null)
788 .await?;
789
790 let resources = result
791 .get("resources")
792 .and_then(|r| r.as_array())
793 .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
794
795 let mut resource_list = Vec::new();
796 for resource in resources {
797 let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
798 resource_list.push(mcp_resource);
799 }
800
801 Ok(resource_list)
802 }
803
804 async fn read_resource(&self, uri: &str) -> Result<String> {
805 let params = serde_json::json!({ "uri": uri });
806 let result = self
807 .transport
808 .send_request("resources/read", params)
809 .await?;
810
811 if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
813 if let Some(first_content) = contents.first() {
814 if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
815 return Ok(text.to_string());
816 }
817 }
818 }
819
820 Err(anyhow::anyhow!("No readable content found in resource"))
821 }
822
823 async fn ping(&self) -> Result<()> {
824 self.transport.send_request("ping", Value::Null).await?;
826 Ok(())
827 }
828}