1use crate::tools::{Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType};
2use crate::transport::{HttpTransport, McpTransport, ProcessTransport, WebSocketTransport};
3use crate::types::McpToolResult;
4use crate::{McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo};
5use anyhow::Result;
6use rust_mcp_schema::Resource;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::Semaphore;
12
13#[async_trait::async_trait]
15pub trait McpServerConnection: Send + Sync {
16 fn server_id(&self) -> &str;
18
19 fn server_name(&self) -> &str;
21
22 async fn list_tools(&self) -> Result<Vec<McpToolInfo>>;
24
25 async fn call_tool(&self, name: &str, arguments: serde_json::Value) -> Result<String>;
27
28 async fn list_resources(&self) -> Result<Vec<Resource>>;
30
31 async fn read_resource(&self, uri: &str) -> Result<String>;
33
34 async fn ping(&self) -> Result<()>;
36}
37
38pub struct McpClient {
72 config: McpClientConfig,
74 servers: HashMap<String, Arc<dyn McpServerConnection>>,
76 tools: HashMap<String, McpToolInfo>,
78 tool_callbacks: HashMap<String, Arc<ToolCallback>>,
80 tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
82 concurrency_semaphore: Arc<Semaphore>,
84}
85
86impl McpClient {
87 pub fn new(config: McpClientConfig) -> Self {
89 let max_concurrent = config.max_concurrent_calls.unwrap_or(10);
90 Self {
91 config,
92 servers: HashMap::new(),
93 tools: HashMap::new(),
94 tool_callbacks: HashMap::new(),
95 tool_callbacks_with_tools: HashMap::new(),
96 concurrency_semaphore: Arc::new(Semaphore::new(max_concurrent)),
97 }
98 }
99
100 pub async fn initialize(&mut self) -> Result<()> {
102 for server_config in &self.config.servers {
103 if server_config.enabled {
104 let connection = self.create_connection(server_config).await?;
105 self.servers.insert(server_config.id.clone(), connection);
106 }
107 }
108
109 if self.config.auto_register_tools {
110 self.discover_and_register_tools().await?;
111 }
112
113 Ok(())
114 }
115
116 pub fn get_tool_callbacks(&self) -> &HashMap<String, Arc<ToolCallback>> {
118 &self.tool_callbacks
119 }
120
121 pub fn get_tool_callbacks_with_tools(&self) -> &HashMap<String, ToolCallbackWithTool> {
123 &self.tool_callbacks_with_tools
124 }
125
126 pub fn get_tools(&self) -> &HashMap<String, McpToolInfo> {
128 &self.tools
129 }
130
131 async fn create_connection(
133 &self,
134 config: &McpServerConfig,
135 ) -> Result<Arc<dyn McpServerConnection>> {
136 match &config.source {
137 McpServerSource::Http {
138 url,
139 timeout_secs,
140 headers,
141 } => {
142 let mut merged_headers = headers.clone().unwrap_or_default();
144 if let Some(token) = &config.bearer_token {
145 merged_headers.insert("Authorization".to_string(), format!("Bearer {token}"));
146 }
147
148 let connection = HttpMcpConnection::new(
149 config.id.clone(),
150 config.name.clone(),
151 url.clone(),
152 *timeout_secs,
153 Some(merged_headers),
154 )
155 .await?;
156 Ok(Arc::new(connection))
157 }
158 McpServerSource::Process {
159 command,
160 args,
161 work_dir,
162 env,
163 } => {
164 let connection = ProcessMcpConnection::new(
165 config.id.clone(),
166 config.name.clone(),
167 command.clone(),
168 args.clone(),
169 work_dir.clone(),
170 env.clone(),
171 )
172 .await?;
173 Ok(Arc::new(connection))
174 }
175 McpServerSource::WebSocket {
176 url,
177 timeout_secs,
178 headers,
179 } => {
180 let mut merged_headers = headers.clone().unwrap_or_default();
182 if let Some(token) = &config.bearer_token {
183 merged_headers.insert("Authorization".to_string(), format!("Bearer {token}"));
184 }
185
186 let connection = WebSocketMcpConnection::new(
187 config.id.clone(),
188 config.name.clone(),
189 url.clone(),
190 *timeout_secs,
191 Some(merged_headers),
192 )
193 .await?;
194 Ok(Arc::new(connection))
195 }
196 }
197 }
198
199 async fn discover_and_register_tools(&mut self) -> Result<()> {
201 for (server_id, connection) in &self.servers {
202 let tools = connection.list_tools().await?;
203 let server_config = self
204 .config
205 .servers
206 .iter()
207 .find(|s| &s.id == server_id)
208 .ok_or_else(|| anyhow::anyhow!("Server config not found for {}", server_id))?;
209
210 for tool in tools {
211 let tool_name = if let Some(prefix) = &server_config.tool_prefix {
212 format!("{}_{}", prefix, tool.name)
213 } else {
214 tool.name.clone()
215 };
216
217 let connection_clone = Arc::clone(connection);
219 let original_tool_name = tool.name.clone();
220 let semaphore_clone = Arc::clone(&self.concurrency_semaphore);
221 let timeout_duration =
222 Duration::from_secs(self.config.tool_timeout_secs.unwrap_or(30));
223
224 let callback: Arc<ToolCallback> = Arc::new(move |called_function| {
225 let connection = Arc::clone(&connection_clone);
226 let tool_name = original_tool_name.clone();
227 let semaphore = Arc::clone(&semaphore_clone);
228 let arguments: serde_json::Value =
229 serde_json::from_str(&called_function.arguments)?;
230
231 let rt = tokio::runtime::Handle::current();
233 std::thread::spawn(move || {
234 rt.block_on(async move {
235 let _permit = semaphore.acquire().await.map_err(|_| {
237 anyhow::anyhow!("Failed to acquire concurrency permit")
238 })?;
239
240 match tokio::time::timeout(
242 timeout_duration,
243 connection.call_tool(&tool_name, arguments),
244 )
245 .await
246 {
247 Ok(result) => result,
248 Err(_) => Err(anyhow::anyhow!(
249 "Tool call timed out after {} seconds",
250 timeout_duration.as_secs()
251 )),
252 }
253 })
254 })
255 .join()
256 .map_err(|_| anyhow::anyhow!("Tool call thread panicked"))?
257 });
258
259 let function_def = Function {
261 name: tool_name.clone(),
262 description: tool.description.clone(),
263 parameters: Self::convert_mcp_schema_to_parameters(&tool.input_schema),
264 };
265
266 let tool_def = Tool {
267 tp: ToolType::Function,
268 function: function_def,
269 };
270
271 self.tool_callbacks
273 .insert(tool_name.clone(), callback.clone());
274 self.tool_callbacks_with_tools.insert(
275 tool_name.clone(),
276 ToolCallbackWithTool {
277 callback,
278 tool: tool_def,
279 },
280 );
281 self.tools.insert(tool_name, tool);
282 }
283 }
284
285 Ok(())
286 }
287
288 fn convert_mcp_schema_to_parameters(
290 schema: &serde_json::Value,
291 ) -> Option<HashMap<String, serde_json::Value>> {
292 match schema {
294 serde_json::Value::Object(obj) => {
295 let mut params = HashMap::new();
296
297 if let Some(properties) = obj.get("properties") {
299 if let serde_json::Value::Object(props) = properties {
300 for (key, value) in props {
301 params.insert(key.clone(), value.clone());
302 }
303 }
304 } else {
305 for (key, value) in obj {
307 params.insert(key.clone(), value.clone());
308 }
309 }
310
311 if params.is_empty() {
312 None
313 } else {
314 Some(params)
315 }
316 }
317 _ => {
318 None
320 }
321 }
322 }
323}
324
325pub struct HttpMcpConnection {
327 server_id: String,
328 server_name: String,
329 transport: Arc<dyn McpTransport>,
330}
331
332impl HttpMcpConnection {
333 pub async fn new(
334 server_id: String,
335 server_name: String,
336 url: String,
337 timeout_secs: Option<u64>,
338 headers: Option<HashMap<String, String>>,
339 ) -> Result<Self> {
340 let transport = HttpTransport::new(url, timeout_secs, headers)?;
341
342 let connection = Self {
343 server_id,
344 server_name,
345 transport: Arc::new(transport),
346 };
347
348 connection.initialize().await?;
350
351 Ok(connection)
352 }
353
354 async fn initialize(&self) -> Result<()> {
355 let init_params = serde_json::json!({
356 "protocolVersion": "2025-03-26",
357 "capabilities": {
358 "tools": {}
359 },
360 "clientInfo": {
361 "name": "mistral.rs",
362 "version": "0.6.0"
363 }
364 });
365
366 self.transport
367 .send_request("initialize", init_params)
368 .await?;
369 Ok(())
370 }
371}
372
373#[async_trait::async_trait]
374impl McpServerConnection for HttpMcpConnection {
375 fn server_id(&self) -> &str {
376 &self.server_id
377 }
378
379 fn server_name(&self) -> &str {
380 &self.server_name
381 }
382
383 async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
384 let result = self
385 .transport
386 .send_request("tools/list", Value::Null)
387 .await?;
388
389 let tools = result
390 .get("tools")
391 .and_then(|t| t.as_array())
392 .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
393
394 let mut tool_infos = Vec::new();
395 for tool in tools {
396 let name = tool
397 .get("name")
398 .and_then(|n| n.as_str())
399 .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
400 .to_string();
401
402 let description = tool
403 .get("description")
404 .and_then(|d| d.as_str())
405 .map(|s| s.to_string());
406
407 let input_schema = tool
408 .get("inputSchema")
409 .cloned()
410 .unwrap_or(Value::Object(serde_json::Map::new()));
411
412 tool_infos.push(McpToolInfo {
413 name,
414 description,
415 input_schema,
416 server_id: self.server_id.clone(),
417 server_name: self.server_name.clone(),
418 });
419 }
420
421 Ok(tool_infos)
422 }
423
424 async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
425 let params = serde_json::json!({
426 "name": name,
427 "arguments": arguments
428 });
429
430 let result = self.transport.send_request("tools/call", params).await?;
431
432 let tool_result: McpToolResult = serde_json::from_value(result)?;
434
435 if tool_result.is_error.unwrap_or(false) {
437 return Err(anyhow::anyhow!(
438 "Tool execution failed: {}",
439 tool_result.to_string()
440 ));
441 }
442
443 Ok(tool_result.to_string())
444 }
445
446 async fn list_resources(&self) -> Result<Vec<Resource>> {
447 let result = self
448 .transport
449 .send_request("resources/list", Value::Null)
450 .await?;
451
452 let resources = result
453 .get("resources")
454 .and_then(|r| r.as_array())
455 .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
456
457 let mut resource_list = Vec::new();
458 for resource in resources {
459 let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
460 resource_list.push(mcp_resource);
461 }
462
463 Ok(resource_list)
464 }
465
466 async fn read_resource(&self, uri: &str) -> Result<String> {
467 let params = serde_json::json!({ "uri": uri });
468 let result = self
469 .transport
470 .send_request("resources/read", params)
471 .await?;
472
473 if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
475 if let Some(first_content) = contents.first() {
476 if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
477 return Ok(text.to_string());
478 }
479 }
480 }
481
482 Err(anyhow::anyhow!("No readable content found in resource"))
483 }
484
485 async fn ping(&self) -> Result<()> {
486 self.transport.send_request("ping", Value::Null).await?;
488 Ok(())
489 }
490}
491
492pub struct ProcessMcpConnection {
494 server_id: String,
495 server_name: String,
496 transport: Arc<dyn McpTransport>,
497}
498
499impl ProcessMcpConnection {
500 pub async fn new(
501 server_id: String,
502 server_name: String,
503 command: String,
504 args: Vec<String>,
505 work_dir: Option<String>,
506 env: Option<HashMap<String, String>>,
507 ) -> Result<Self> {
508 let transport = ProcessTransport::new(command, args, work_dir, env).await?;
509
510 let connection = Self {
511 server_id,
512 server_name,
513 transport: Arc::new(transport),
514 };
515
516 connection.initialize().await?;
518
519 Ok(connection)
520 }
521
522 async fn initialize(&self) -> Result<()> {
523 let init_params = serde_json::json!({
524 "protocolVersion": "2025-03-26",
525 "capabilities": {
526 "tools": {}
527 },
528 "clientInfo": {
529 "name": "mistral.rs",
530 "version": "0.6.0"
531 }
532 });
533
534 self.transport
535 .send_request("initialize", init_params)
536 .await?;
537 Ok(())
538 }
539}
540
541#[async_trait::async_trait]
542impl McpServerConnection for ProcessMcpConnection {
543 fn server_id(&self) -> &str {
544 &self.server_id
545 }
546
547 fn server_name(&self) -> &str {
548 &self.server_name
549 }
550
551 async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
552 let result = self
553 .transport
554 .send_request("tools/list", Value::Null)
555 .await?;
556
557 let tools = result
558 .get("tools")
559 .and_then(|t| t.as_array())
560 .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
561
562 let mut tool_infos = Vec::new();
563 for tool in tools {
564 let name = tool
565 .get("name")
566 .and_then(|n| n.as_str())
567 .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
568 .to_string();
569
570 let description = tool
571 .get("description")
572 .and_then(|d| d.as_str())
573 .map(|s| s.to_string());
574
575 let input_schema = tool
576 .get("inputSchema")
577 .cloned()
578 .unwrap_or(Value::Object(serde_json::Map::new()));
579
580 tool_infos.push(McpToolInfo {
581 name,
582 description,
583 input_schema,
584 server_id: self.server_id.clone(),
585 server_name: self.server_name.clone(),
586 });
587 }
588
589 Ok(tool_infos)
590 }
591
592 async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
593 let params = serde_json::json!({
594 "name": name,
595 "arguments": arguments
596 });
597
598 let result = self.transport.send_request("tools/call", params).await?;
599
600 let tool_result: McpToolResult = serde_json::from_value(result)?;
602
603 if tool_result.is_error.unwrap_or(false) {
605 return Err(anyhow::anyhow!(
606 "Tool execution failed: {}",
607 tool_result.to_string()
608 ));
609 }
610
611 Ok(tool_result.to_string())
612 }
613
614 async fn list_resources(&self) -> Result<Vec<Resource>> {
615 let result = self
616 .transport
617 .send_request("resources/list", Value::Null)
618 .await?;
619
620 let resources = result
621 .get("resources")
622 .and_then(|r| r.as_array())
623 .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
624
625 let mut resource_list = Vec::new();
626 for resource in resources {
627 let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
628 resource_list.push(mcp_resource);
629 }
630
631 Ok(resource_list)
632 }
633
634 async fn read_resource(&self, uri: &str) -> Result<String> {
635 let params = serde_json::json!({ "uri": uri });
636 let result = self
637 .transport
638 .send_request("resources/read", params)
639 .await?;
640
641 if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
643 if let Some(first_content) = contents.first() {
644 if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
645 return Ok(text.to_string());
646 }
647 }
648 }
649
650 Err(anyhow::anyhow!("No readable content found in resource"))
651 }
652
653 async fn ping(&self) -> Result<()> {
654 self.transport.send_request("ping", Value::Null).await?;
656 Ok(())
657 }
658}
659
660pub struct WebSocketMcpConnection {
662 server_id: String,
663 server_name: String,
664 transport: Arc<dyn McpTransport>,
665}
666
667impl WebSocketMcpConnection {
668 pub async fn new(
669 server_id: String,
670 server_name: String,
671 url: String,
672 timeout_secs: Option<u64>,
673 headers: Option<HashMap<String, String>>,
674 ) -> Result<Self> {
675 let transport = WebSocketTransport::new(url, timeout_secs, headers).await?;
676
677 let connection = Self {
678 server_id,
679 server_name,
680 transport: Arc::new(transport),
681 };
682
683 connection.initialize().await?;
685
686 Ok(connection)
687 }
688
689 async fn initialize(&self) -> Result<()> {
690 let init_params = serde_json::json!({
691 "protocolVersion": "2025-03-26",
692 "capabilities": {
693 "tools": {}
694 },
695 "clientInfo": {
696 "name": "mistral.rs",
697 "version": "0.6.0"
698 }
699 });
700
701 self.transport
702 .send_request("initialize", init_params)
703 .await?;
704 Ok(())
705 }
706}
707
708#[async_trait::async_trait]
709impl McpServerConnection for WebSocketMcpConnection {
710 fn server_id(&self) -> &str {
711 &self.server_id
712 }
713
714 fn server_name(&self) -> &str {
715 &self.server_name
716 }
717
718 async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
719 let result = self
720 .transport
721 .send_request("tools/list", Value::Null)
722 .await?;
723
724 let tools = result
725 .get("tools")
726 .and_then(|t| t.as_array())
727 .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
728
729 let mut tool_infos = Vec::new();
730 for tool in tools {
731 let name = tool
732 .get("name")
733 .and_then(|n| n.as_str())
734 .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
735 .to_string();
736
737 let description = tool
738 .get("description")
739 .and_then(|d| d.as_str())
740 .map(|s| s.to_string());
741
742 let input_schema = tool
743 .get("inputSchema")
744 .cloned()
745 .unwrap_or(Value::Object(serde_json::Map::new()));
746
747 tool_infos.push(McpToolInfo {
748 name,
749 description,
750 input_schema,
751 server_id: self.server_id.clone(),
752 server_name: self.server_name.clone(),
753 });
754 }
755
756 Ok(tool_infos)
757 }
758
759 async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
760 let params = serde_json::json!({
761 "name": name,
762 "arguments": arguments
763 });
764
765 let result = self.transport.send_request("tools/call", params).await?;
766
767 let tool_result: McpToolResult = serde_json::from_value(result)?;
769
770 if tool_result.is_error.unwrap_or(false) {
772 return Err(anyhow::anyhow!(
773 "Tool execution failed: {}",
774 tool_result.to_string()
775 ));
776 }
777
778 Ok(tool_result.to_string())
779 }
780
781 async fn list_resources(&self) -> Result<Vec<Resource>> {
782 let result = self
783 .transport
784 .send_request("resources/list", Value::Null)
785 .await?;
786
787 let resources = result
788 .get("resources")
789 .and_then(|r| r.as_array())
790 .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
791
792 let mut resource_list = Vec::new();
793 for resource in resources {
794 let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
795 resource_list.push(mcp_resource);
796 }
797
798 Ok(resource_list)
799 }
800
801 async fn read_resource(&self, uri: &str) -> Result<String> {
802 let params = serde_json::json!({ "uri": uri });
803 let result = self
804 .transport
805 .send_request("resources/read", params)
806 .await?;
807
808 if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
810 if let Some(first_content) = contents.first() {
811 if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
812 return Ok(text.to_string());
813 }
814 }
815 }
816
817 Err(anyhow::anyhow!("No readable content found in resource"))
818 }
819
820 async fn ping(&self) -> Result<()> {
821 self.transport.send_request("ping", Value::Null).await?;
823 Ok(())
824 }
825}