mistralrs_mcp/
client.rs

1use crate::tools::{Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType};
2use crate::transport::{HttpTransport, McpTransport, ProcessTransport, WebSocketTransport};
3use crate::types::McpToolResult;
4use crate::{McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo};
5use anyhow::Result;
6use rust_mcp_schema::Resource;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::Semaphore;
12
13/// Trait for MCP server connections
14#[async_trait::async_trait]
15pub trait McpServerConnection: Send + Sync {
16    /// Get the server ID
17    fn server_id(&self) -> &str;
18
19    /// Get the server name
20    fn server_name(&self) -> &str;
21
22    /// List available tools from this server
23    async fn list_tools(&self) -> Result<Vec<McpToolInfo>>;
24
25    /// Call a tool on this server
26    async fn call_tool(&self, name: &str, arguments: serde_json::Value) -> Result<String>;
27
28    /// List available resources from this server
29    async fn list_resources(&self) -> Result<Vec<Resource>>;
30
31    /// Read a resource from this server
32    async fn read_resource(&self, uri: &str) -> Result<String>;
33
34    /// Check if the connection is healthy
35    async fn ping(&self) -> Result<()>;
36}
37
38/// MCP client that manages connections to multiple MCP servers
39///
40/// The main interface for interacting with Model Context Protocol servers.
41/// Handles connection lifecycle, tool discovery, and provides integration
42/// with tool calling systems.
43///
44/// # Features
45///
46/// - **Multi-server Management**: Connects to and manages multiple MCP servers simultaneously
47/// - **Automatic Tool Discovery**: Discovers available tools from connected servers
48/// - **Tool Registration**: Converts MCP tools to internal Tool format for seamless integration
49/// - **Connection Pooling**: Maintains persistent connections for efficient tool execution
50/// - **Error Handling**: Robust error handling with proper cleanup and reconnection logic
51///
52/// # Example
53///
54/// ```rust,no_run
55/// use mistralrs_mcp::{McpClient, McpClientConfig};
56///
57/// #[tokio::main]
58/// async fn main() -> anyhow::Result<()> {
59///     let config = McpClientConfig::default();
60///     let mut client = McpClient::new(config);
61///     
62///     // Initialize all configured server connections
63///     client.initialize().await?;
64///     
65///     // Get tool callbacks for model integration
66///     let callbacks = client.get_tool_callbacks_with_tools();
67///     
68///     Ok(())
69/// }
70/// ```
71pub struct McpClient {
72    /// Configuration for the client including server list and policies
73    config: McpClientConfig,
74    /// Active connections to MCP servers, indexed by server ID
75    servers: HashMap<String, Arc<dyn McpServerConnection>>,
76    /// Registry of discovered tools from all connected servers
77    tools: HashMap<String, McpToolInfo>,
78    /// Legacy tool callbacks for backward compatibility
79    tool_callbacks: HashMap<String, Arc<ToolCallback>>,
80    /// Tool callbacks with associated Tool definitions for automatic tool calling
81    tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
82    /// Semaphore to control maximum concurrent tool calls
83    concurrency_semaphore: Arc<Semaphore>,
84}
85
86impl McpClient {
87    /// Create a new MCP client with the given configuration
88    pub fn new(config: McpClientConfig) -> Self {
89        let max_concurrent = config.max_concurrent_calls.unwrap_or(10);
90        Self {
91            config,
92            servers: HashMap::new(),
93            tools: HashMap::new(),
94            tool_callbacks: HashMap::new(),
95            tool_callbacks_with_tools: HashMap::new(),
96            concurrency_semaphore: Arc::new(Semaphore::new(max_concurrent)),
97        }
98    }
99
100    /// Initialize connections to all configured servers
101    pub async fn initialize(&mut self) -> Result<()> {
102        for server_config in &self.config.servers {
103            if server_config.enabled {
104                let connection = self.create_connection(server_config).await?;
105                self.servers.insert(server_config.id.clone(), connection);
106            }
107        }
108
109        if self.config.auto_register_tools {
110            self.discover_and_register_tools().await?;
111        }
112
113        Ok(())
114    }
115
116    /// Get tool callbacks that can be used with the existing tool calling system
117    pub fn get_tool_callbacks(&self) -> &HashMap<String, Arc<ToolCallback>> {
118        &self.tool_callbacks
119    }
120
121    /// Get tool callbacks with their associated Tool definitions
122    pub fn get_tool_callbacks_with_tools(&self) -> &HashMap<String, ToolCallbackWithTool> {
123        &self.tool_callbacks_with_tools
124    }
125
126    /// Get discovered tools information
127    pub fn get_tools(&self) -> &HashMap<String, McpToolInfo> {
128        &self.tools
129    }
130
131    /// Create connection based on server source type
132    async fn create_connection(
133        &self,
134        config: &McpServerConfig,
135    ) -> Result<Arc<dyn McpServerConnection>> {
136        match &config.source {
137            McpServerSource::Http {
138                url,
139                timeout_secs,
140                headers,
141            } => {
142                // Merge Bearer token with existing headers if provided
143                let mut merged_headers = headers.clone().unwrap_or_default();
144                if let Some(token) = &config.bearer_token {
145                    merged_headers.insert("Authorization".to_string(), format!("Bearer {token}"));
146                }
147
148                let connection = HttpMcpConnection::new(
149                    config.id.clone(),
150                    config.name.clone(),
151                    url.clone(),
152                    *timeout_secs,
153                    Some(merged_headers),
154                )
155                .await?;
156                Ok(Arc::new(connection))
157            }
158            McpServerSource::Process {
159                command,
160                args,
161                work_dir,
162                env,
163            } => {
164                let connection = ProcessMcpConnection::new(
165                    config.id.clone(),
166                    config.name.clone(),
167                    command.clone(),
168                    args.clone(),
169                    work_dir.clone(),
170                    env.clone(),
171                )
172                .await?;
173                Ok(Arc::new(connection))
174            }
175            McpServerSource::WebSocket {
176                url,
177                timeout_secs,
178                headers,
179            } => {
180                // Merge Bearer token with existing headers if provided
181                let mut merged_headers = headers.clone().unwrap_or_default();
182                if let Some(token) = &config.bearer_token {
183                    merged_headers.insert("Authorization".to_string(), format!("Bearer {token}"));
184                }
185
186                let connection = WebSocketMcpConnection::new(
187                    config.id.clone(),
188                    config.name.clone(),
189                    url.clone(),
190                    *timeout_secs,
191                    Some(merged_headers),
192                )
193                .await?;
194                Ok(Arc::new(connection))
195            }
196        }
197    }
198
199    /// Discover tools from all connected servers and register them
200    async fn discover_and_register_tools(&mut self) -> Result<()> {
201        for (server_id, connection) in &self.servers {
202            let tools = connection.list_tools().await?;
203            let server_config = self
204                .config
205                .servers
206                .iter()
207                .find(|s| &s.id == server_id)
208                .ok_or_else(|| anyhow::anyhow!("Server config not found for {}", server_id))?;
209
210            for tool in tools {
211                let tool_name = if let Some(prefix) = &server_config.tool_prefix {
212                    format!("{}_{}", prefix, tool.name)
213                } else {
214                    tool.name.clone()
215                };
216
217                // Create tool callback that calls the MCP server with timeout and concurrency controls
218                let connection_clone = Arc::clone(connection);
219                let original_tool_name = tool.name.clone();
220                let semaphore_clone = Arc::clone(&self.concurrency_semaphore);
221                let timeout_duration =
222                    Duration::from_secs(self.config.tool_timeout_secs.unwrap_or(30));
223
224                let callback: Arc<ToolCallback> = Arc::new(move |called_function| {
225                    let connection = Arc::clone(&connection_clone);
226                    let tool_name = original_tool_name.clone();
227                    let semaphore = Arc::clone(&semaphore_clone);
228                    let arguments: serde_json::Value =
229                        serde_json::from_str(&called_function.arguments)?;
230
231                    // Use tokio::task::spawn_blocking to handle the async-to-sync bridge
232                    let rt = tokio::runtime::Handle::current();
233                    std::thread::spawn(move || {
234                        rt.block_on(async move {
235                            // Acquire semaphore permit for concurrency control
236                            let _permit = semaphore.acquire().await.map_err(|_| {
237                                anyhow::anyhow!("Failed to acquire concurrency permit")
238                            })?;
239
240                            // Execute tool call with timeout
241                            match tokio::time::timeout(
242                                timeout_duration,
243                                connection.call_tool(&tool_name, arguments),
244                            )
245                            .await
246                            {
247                                Ok(result) => result,
248                                Err(_) => Err(anyhow::anyhow!(
249                                    "Tool call timed out after {} seconds",
250                                    timeout_duration.as_secs()
251                                )),
252                            }
253                        })
254                    })
255                    .join()
256                    .map_err(|_| anyhow::anyhow!("Tool call thread panicked"))?
257                });
258
259                // Convert MCP tool schema to Tool definition
260                let function_def = Function {
261                    name: tool_name.clone(),
262                    description: tool.description.clone(),
263                    parameters: Self::convert_mcp_schema_to_parameters(&tool.input_schema),
264                };
265
266                let tool_def = Tool {
267                    tp: ToolType::Function,
268                    function: function_def,
269                };
270
271                // Store in both collections for backward compatibility
272                self.tool_callbacks
273                    .insert(tool_name.clone(), callback.clone());
274                self.tool_callbacks_with_tools.insert(
275                    tool_name.clone(),
276                    ToolCallbackWithTool {
277                        callback,
278                        tool: tool_def,
279                    },
280                );
281                self.tools.insert(tool_name, tool);
282            }
283        }
284
285        Ok(())
286    }
287
288    /// Convert MCP tool input schema to Tool parameters format
289    fn convert_mcp_schema_to_parameters(
290        schema: &serde_json::Value,
291    ) -> Option<HashMap<String, serde_json::Value>> {
292        // MCP tools can have various schema formats, we'll try to convert common ones
293        match schema {
294            serde_json::Value::Object(obj) => {
295                let mut params = HashMap::new();
296
297                // If it's a JSON schema object, extract properties
298                if let Some(properties) = obj.get("properties") {
299                    if let serde_json::Value::Object(props) = properties {
300                        for (key, value) in props {
301                            params.insert(key.clone(), value.clone());
302                        }
303                    }
304                } else {
305                    // If it's just a direct object, use it as-is
306                    for (key, value) in obj {
307                        params.insert(key.clone(), value.clone());
308                    }
309                }
310
311                if params.is_empty() {
312                    None
313                } else {
314                    Some(params)
315                }
316            }
317            _ => {
318                // For non-object schemas, we can't easily convert to parameters
319                None
320            }
321        }
322    }
323}
324
325/// HTTP-based MCP server connection
326pub struct HttpMcpConnection {
327    server_id: String,
328    server_name: String,
329    transport: Arc<dyn McpTransport>,
330}
331
332impl HttpMcpConnection {
333    pub async fn new(
334        server_id: String,
335        server_name: String,
336        url: String,
337        timeout_secs: Option<u64>,
338        headers: Option<HashMap<String, String>>,
339    ) -> Result<Self> {
340        let transport = HttpTransport::new(url, timeout_secs, headers)?;
341
342        let connection = Self {
343            server_id,
344            server_name,
345            transport: Arc::new(transport),
346        };
347
348        // Initialize the connection
349        connection.initialize().await?;
350
351        Ok(connection)
352    }
353
354    async fn initialize(&self) -> Result<()> {
355        let init_params = serde_json::json!({
356            "protocolVersion": "2025-03-26",
357            "capabilities": {
358                "tools": {}
359            },
360            "clientInfo": {
361                "name": "mistral.rs",
362                "version": "0.6.0"
363            }
364        });
365
366        self.transport
367            .send_request("initialize", init_params)
368            .await?;
369        self.transport.send_initialization_notification().await?;
370        Ok(())
371    }
372}
373
374#[async_trait::async_trait]
375impl McpServerConnection for HttpMcpConnection {
376    fn server_id(&self) -> &str {
377        &self.server_id
378    }
379
380    fn server_name(&self) -> &str {
381        &self.server_name
382    }
383
384    async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
385        let result = self
386            .transport
387            .send_request("tools/list", Value::Null)
388            .await?;
389
390        let tools = result
391            .get("tools")
392            .and_then(|t| t.as_array())
393            .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
394
395        let mut tool_infos = Vec::new();
396        for tool in tools {
397            let name = tool
398                .get("name")
399                .and_then(|n| n.as_str())
400                .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
401                .to_string();
402
403            let description = tool
404                .get("description")
405                .and_then(|d| d.as_str())
406                .map(|s| s.to_string());
407
408            let input_schema = tool
409                .get("inputSchema")
410                .cloned()
411                .unwrap_or(Value::Object(serde_json::Map::new()));
412
413            tool_infos.push(McpToolInfo {
414                name,
415                description,
416                input_schema,
417                server_id: self.server_id.clone(),
418                server_name: self.server_name.clone(),
419            });
420        }
421
422        Ok(tool_infos)
423    }
424
425    async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
426        let params = serde_json::json!({
427            "name": name,
428            "arguments": arguments
429        });
430
431        let result = self.transport.send_request("tools/call", params).await?;
432
433        // Parse the MCP tool result
434        let tool_result: McpToolResult = serde_json::from_value(result)?;
435
436        // Check if the result indicates an error
437        if tool_result.is_error.unwrap_or(false) {
438            return Err(anyhow::anyhow!("Tool execution failed: {tool_result}"));
439        }
440
441        Ok(tool_result.to_string())
442    }
443
444    async fn list_resources(&self) -> Result<Vec<Resource>> {
445        let result = self
446            .transport
447            .send_request("resources/list", Value::Null)
448            .await?;
449
450        let resources = result
451            .get("resources")
452            .and_then(|r| r.as_array())
453            .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
454
455        let mut resource_list = Vec::new();
456        for resource in resources {
457            let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
458            resource_list.push(mcp_resource);
459        }
460
461        Ok(resource_list)
462    }
463
464    async fn read_resource(&self, uri: &str) -> Result<String> {
465        let params = serde_json::json!({ "uri": uri });
466        let result = self
467            .transport
468            .send_request("resources/read", params)
469            .await?;
470
471        // Extract content from the response
472        if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
473            if let Some(first_content) = contents.first() {
474                if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
475                    return Ok(text.to_string());
476                }
477            }
478        }
479
480        Err(anyhow::anyhow!("No readable content found in resource"))
481    }
482
483    async fn ping(&self) -> Result<()> {
484        // Send a simple ping to check if the server is responsive
485        self.transport.send_request("ping", Value::Null).await?;
486        Ok(())
487    }
488}
489
490/// Process-based MCP server connection
491pub struct ProcessMcpConnection {
492    server_id: String,
493    server_name: String,
494    transport: Arc<dyn McpTransport>,
495}
496
497impl ProcessMcpConnection {
498    pub async fn new(
499        server_id: String,
500        server_name: String,
501        command: String,
502        args: Vec<String>,
503        work_dir: Option<String>,
504        env: Option<HashMap<String, String>>,
505    ) -> Result<Self> {
506        let transport = ProcessTransport::new(command, args, work_dir, env).await?;
507
508        let connection = Self {
509            server_id,
510            server_name,
511            transport: Arc::new(transport),
512        };
513
514        // Initialize the connection
515        connection.initialize().await?;
516
517        Ok(connection)
518    }
519
520    async fn initialize(&self) -> Result<()> {
521        let init_params = serde_json::json!({
522            "protocolVersion": "2025-03-26",
523            "capabilities": {
524                "tools": {}
525            },
526            "clientInfo": {
527                "name": "mistral.rs",
528                "version": "0.6.0"
529            }
530        });
531
532        self.transport
533            .send_request("initialize", init_params)
534            .await?;
535        self.transport.send_initialization_notification().await?;
536        Ok(())
537    }
538}
539
540#[async_trait::async_trait]
541impl McpServerConnection for ProcessMcpConnection {
542    fn server_id(&self) -> &str {
543        &self.server_id
544    }
545
546    fn server_name(&self) -> &str {
547        &self.server_name
548    }
549
550    async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
551        let result = self
552            .transport
553            .send_request("tools/list", Value::Null)
554            .await?;
555
556        let tools = result
557            .get("tools")
558            .and_then(|t| t.as_array())
559            .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
560
561        let mut tool_infos = Vec::new();
562        for tool in tools {
563            let name = tool
564                .get("name")
565                .and_then(|n| n.as_str())
566                .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
567                .to_string();
568
569            let description = tool
570                .get("description")
571                .and_then(|d| d.as_str())
572                .map(|s| s.to_string());
573
574            let input_schema = tool
575                .get("inputSchema")
576                .cloned()
577                .unwrap_or(Value::Object(serde_json::Map::new()));
578
579            tool_infos.push(McpToolInfo {
580                name,
581                description,
582                input_schema,
583                server_id: self.server_id.clone(),
584                server_name: self.server_name.clone(),
585            });
586        }
587
588        Ok(tool_infos)
589    }
590
591    async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
592        let params = serde_json::json!({
593            "name": name,
594            "arguments": arguments
595        });
596
597        let result = self.transport.send_request("tools/call", params).await?;
598
599        // Parse the MCP tool result
600        let tool_result: McpToolResult = serde_json::from_value(result)?;
601
602        // Check if the result indicates an error
603        if tool_result.is_error.unwrap_or(false) {
604            return Err(anyhow::anyhow!("Tool execution failed: {tool_result}"));
605        }
606
607        Ok(tool_result.to_string())
608    }
609
610    async fn list_resources(&self) -> Result<Vec<Resource>> {
611        let result = self
612            .transport
613            .send_request("resources/list", Value::Null)
614            .await?;
615
616        let resources = result
617            .get("resources")
618            .and_then(|r| r.as_array())
619            .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
620
621        let mut resource_list = Vec::new();
622        for resource in resources {
623            let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
624            resource_list.push(mcp_resource);
625        }
626
627        Ok(resource_list)
628    }
629
630    async fn read_resource(&self, uri: &str) -> Result<String> {
631        let params = serde_json::json!({ "uri": uri });
632        let result = self
633            .transport
634            .send_request("resources/read", params)
635            .await?;
636
637        // Extract content from the response
638        if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
639            if let Some(first_content) = contents.first() {
640                if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
641                    return Ok(text.to_string());
642                }
643            }
644        }
645
646        Err(anyhow::anyhow!("No readable content found in resource"))
647    }
648
649    async fn ping(&self) -> Result<()> {
650        // Send a simple ping to check if the server is responsive
651        self.transport.send_request("ping", Value::Null).await?;
652        Ok(())
653    }
654}
655
656/// WebSocket-based MCP server connection
657pub struct WebSocketMcpConnection {
658    server_id: String,
659    server_name: String,
660    transport: Arc<dyn McpTransport>,
661}
662
663impl WebSocketMcpConnection {
664    pub async fn new(
665        server_id: String,
666        server_name: String,
667        url: String,
668        timeout_secs: Option<u64>,
669        headers: Option<HashMap<String, String>>,
670    ) -> Result<Self> {
671        let transport = WebSocketTransport::new(url, timeout_secs, headers).await?;
672
673        let connection = Self {
674            server_id,
675            server_name,
676            transport: Arc::new(transport),
677        };
678
679        // Initialize the connection
680        connection.initialize().await?;
681
682        Ok(connection)
683    }
684
685    async fn initialize(&self) -> Result<()> {
686        let init_params = serde_json::json!({
687            "protocolVersion": "2025-03-26",
688            "capabilities": {
689                "tools": {}
690            },
691            "clientInfo": {
692                "name": "mistral.rs",
693                "version": "0.6.0"
694            }
695        });
696
697        self.transport
698            .send_request("initialize", init_params)
699            .await?;
700        self.transport.send_initialization_notification().await?;
701        Ok(())
702    }
703}
704
705#[async_trait::async_trait]
706impl McpServerConnection for WebSocketMcpConnection {
707    fn server_id(&self) -> &str {
708        &self.server_id
709    }
710
711    fn server_name(&self) -> &str {
712        &self.server_name
713    }
714
715    async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
716        let result = self
717            .transport
718            .send_request("tools/list", Value::Null)
719            .await?;
720
721        let tools = result
722            .get("tools")
723            .and_then(|t| t.as_array())
724            .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
725
726        let mut tool_infos = Vec::new();
727        for tool in tools {
728            let name = tool
729                .get("name")
730                .and_then(|n| n.as_str())
731                .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
732                .to_string();
733
734            let description = tool
735                .get("description")
736                .and_then(|d| d.as_str())
737                .map(|s| s.to_string());
738
739            let input_schema = tool
740                .get("inputSchema")
741                .cloned()
742                .unwrap_or(Value::Object(serde_json::Map::new()));
743
744            tool_infos.push(McpToolInfo {
745                name,
746                description,
747                input_schema,
748                server_id: self.server_id.clone(),
749                server_name: self.server_name.clone(),
750            });
751        }
752
753        Ok(tool_infos)
754    }
755
756    async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
757        let params = serde_json::json!({
758            "name": name,
759            "arguments": arguments
760        });
761
762        let result = self.transport.send_request("tools/call", params).await?;
763
764        // Parse the MCP tool result
765        let tool_result: McpToolResult = serde_json::from_value(result)?;
766
767        // Check if the result indicates an error
768        if tool_result.is_error.unwrap_or(false) {
769            return Err(anyhow::anyhow!("Tool execution failed: {tool_result}"));
770        }
771
772        Ok(tool_result.to_string())
773    }
774
775    async fn list_resources(&self) -> Result<Vec<Resource>> {
776        let result = self
777            .transport
778            .send_request("resources/list", Value::Null)
779            .await?;
780
781        let resources = result
782            .get("resources")
783            .and_then(|r| r.as_array())
784            .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
785
786        let mut resource_list = Vec::new();
787        for resource in resources {
788            let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
789            resource_list.push(mcp_resource);
790        }
791
792        Ok(resource_list)
793    }
794
795    async fn read_resource(&self, uri: &str) -> Result<String> {
796        let params = serde_json::json!({ "uri": uri });
797        let result = self
798            .transport
799            .send_request("resources/read", params)
800            .await?;
801
802        // Extract content from the response
803        if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
804            if let Some(first_content) = contents.first() {
805                if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
806                    return Ok(text.to_string());
807                }
808            }
809        }
810
811        Err(anyhow::anyhow!("No readable content found in resource"))
812    }
813
814    async fn ping(&self) -> Result<()> {
815        // Send a simple ping to check if the server is responsive
816        self.transport.send_request("ping", Value::Null).await?;
817        Ok(())
818    }
819}