mistralrs_mcp/
client.rs

1use crate::tools::{Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType};
2use crate::transport::{HttpTransport, McpTransport, ProcessTransport, WebSocketTransport};
3use crate::types::McpToolResult;
4use crate::{McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo};
5use anyhow::Result;
6use rust_mcp_schema::Resource;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::Semaphore;
12
13/// Trait for MCP server connections
14#[async_trait::async_trait]
15pub trait McpServerConnection: Send + Sync {
16    /// Get the server ID
17    fn server_id(&self) -> &str;
18
19    /// Get the server name
20    fn server_name(&self) -> &str;
21
22    /// List available tools from this server
23    async fn list_tools(&self) -> Result<Vec<McpToolInfo>>;
24
25    /// Call a tool on this server
26    async fn call_tool(&self, name: &str, arguments: serde_json::Value) -> Result<String>;
27
28    /// List available resources from this server
29    async fn list_resources(&self) -> Result<Vec<Resource>>;
30
31    /// Read a resource from this server
32    async fn read_resource(&self, uri: &str) -> Result<String>;
33
34    /// Check if the connection is healthy
35    async fn ping(&self) -> Result<()>;
36}
37
38/// MCP client that manages connections to multiple MCP servers
39///
40/// The main interface for interacting with Model Context Protocol servers.
41/// Handles connection lifecycle, tool discovery, and provides integration
42/// with tool calling systems.
43///
44/// # Features
45///
46/// - **Multi-server Management**: Connects to and manages multiple MCP servers simultaneously
47/// - **Automatic Tool Discovery**: Discovers available tools from connected servers
48/// - **Tool Registration**: Converts MCP tools to internal Tool format for seamless integration
49/// - **Connection Pooling**: Maintains persistent connections for efficient tool execution
50/// - **Error Handling**: Robust error handling with proper cleanup and reconnection logic
51///
52/// # Example
53///
54/// ```rust,no_run
55/// use mistralrs_mcp::{McpClient, McpClientConfig};
56///
57/// #[tokio::main]
58/// async fn main() -> anyhow::Result<()> {
59///     let config = McpClientConfig::default();
60///     let mut client = McpClient::new(config);
61///     
62///     // Initialize all configured server connections
63///     client.initialize().await?;
64///     
65///     // Get tool callbacks for model integration
66///     let callbacks = client.get_tool_callbacks_with_tools();
67///     
68///     Ok(())
69/// }
70/// ```
71pub struct McpClient {
72    /// Configuration for the client including server list and policies
73    config: McpClientConfig,
74    /// Active connections to MCP servers, indexed by server ID
75    servers: HashMap<String, Arc<dyn McpServerConnection>>,
76    /// Registry of discovered tools from all connected servers
77    tools: HashMap<String, McpToolInfo>,
78    /// Legacy tool callbacks for backward compatibility
79    tool_callbacks: HashMap<String, Arc<ToolCallback>>,
80    /// Tool callbacks with associated Tool definitions for automatic tool calling
81    tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
82    /// Semaphore to control maximum concurrent tool calls
83    concurrency_semaphore: Arc<Semaphore>,
84}
85
86impl McpClient {
87    /// Create a new MCP client with the given configuration
88    pub fn new(config: McpClientConfig) -> Self {
89        let max_concurrent = config.max_concurrent_calls.unwrap_or(10);
90        Self {
91            config,
92            servers: HashMap::new(),
93            tools: HashMap::new(),
94            tool_callbacks: HashMap::new(),
95            tool_callbacks_with_tools: HashMap::new(),
96            concurrency_semaphore: Arc::new(Semaphore::new(max_concurrent)),
97        }
98    }
99
100    /// Initialize connections to all configured servers
101    pub async fn initialize(&mut self) -> Result<()> {
102        for server_config in &self.config.servers {
103            if server_config.enabled {
104                let connection = self.create_connection(server_config).await?;
105                self.servers.insert(server_config.id.clone(), connection);
106            }
107        }
108
109        if self.config.auto_register_tools {
110            self.discover_and_register_tools().await?;
111        }
112
113        Ok(())
114    }
115
116    /// Get tool callbacks that can be used with the existing tool calling system
117    pub fn get_tool_callbacks(&self) -> &HashMap<String, Arc<ToolCallback>> {
118        &self.tool_callbacks
119    }
120
121    /// Get tool callbacks with their associated Tool definitions
122    pub fn get_tool_callbacks_with_tools(&self) -> &HashMap<String, ToolCallbackWithTool> {
123        &self.tool_callbacks_with_tools
124    }
125
126    /// Get discovered tools information
127    pub fn get_tools(&self) -> &HashMap<String, McpToolInfo> {
128        &self.tools
129    }
130
131    /// Create connection based on server source type
132    async fn create_connection(
133        &self,
134        config: &McpServerConfig,
135    ) -> Result<Arc<dyn McpServerConnection>> {
136        match &config.source {
137            McpServerSource::Http {
138                url,
139                timeout_secs,
140                headers,
141            } => {
142                // Merge Bearer token with existing headers if provided
143                let mut merged_headers = headers.clone().unwrap_or_default();
144                if let Some(token) = &config.bearer_token {
145                    merged_headers.insert("Authorization".to_string(), format!("Bearer {token}"));
146                }
147
148                let connection = HttpMcpConnection::new(
149                    config.id.clone(),
150                    config.name.clone(),
151                    url.clone(),
152                    *timeout_secs,
153                    Some(merged_headers),
154                )
155                .await?;
156                Ok(Arc::new(connection))
157            }
158            McpServerSource::Process {
159                command,
160                args,
161                work_dir,
162                env,
163            } => {
164                let connection = ProcessMcpConnection::new(
165                    config.id.clone(),
166                    config.name.clone(),
167                    command.clone(),
168                    args.clone(),
169                    work_dir.clone(),
170                    env.clone(),
171                )
172                .await?;
173                Ok(Arc::new(connection))
174            }
175            McpServerSource::WebSocket {
176                url,
177                timeout_secs,
178                headers,
179            } => {
180                // Merge Bearer token with existing headers if provided
181                let mut merged_headers = headers.clone().unwrap_or_default();
182                if let Some(token) = &config.bearer_token {
183                    merged_headers.insert("Authorization".to_string(), format!("Bearer {token}"));
184                }
185
186                let connection = WebSocketMcpConnection::new(
187                    config.id.clone(),
188                    config.name.clone(),
189                    url.clone(),
190                    *timeout_secs,
191                    Some(merged_headers),
192                )
193                .await?;
194                Ok(Arc::new(connection))
195            }
196        }
197    }
198
199    /// Discover tools from all connected servers and register them
200    async fn discover_and_register_tools(&mut self) -> Result<()> {
201        for (server_id, connection) in &self.servers {
202            let tools = connection.list_tools().await?;
203            let server_config = self
204                .config
205                .servers
206                .iter()
207                .find(|s| &s.id == server_id)
208                .ok_or_else(|| anyhow::anyhow!("Server config not found for {}", server_id))?;
209
210            for tool in tools {
211                let tool_name = if let Some(prefix) = &server_config.tool_prefix {
212                    format!("{}_{}", prefix, tool.name)
213                } else {
214                    tool.name.clone()
215                };
216
217                // Create tool callback that calls the MCP server with timeout and concurrency controls
218                let connection_clone = Arc::clone(connection);
219                let original_tool_name = tool.name.clone();
220                let semaphore_clone = Arc::clone(&self.concurrency_semaphore);
221                let timeout_duration =
222                    Duration::from_secs(self.config.tool_timeout_secs.unwrap_or(30));
223
224                let callback: Arc<ToolCallback> = Arc::new(move |called_function| {
225                    let connection = Arc::clone(&connection_clone);
226                    let tool_name = original_tool_name.clone();
227                    let semaphore = Arc::clone(&semaphore_clone);
228                    let arguments: serde_json::Value =
229                        serde_json::from_str(&called_function.arguments)?;
230
231                    // Use tokio::task::spawn_blocking to handle the async-to-sync bridge
232                    let rt = tokio::runtime::Handle::current();
233                    std::thread::spawn(move || {
234                        rt.block_on(async move {
235                            // Acquire semaphore permit for concurrency control
236                            let _permit = semaphore.acquire().await.map_err(|_| {
237                                anyhow::anyhow!("Failed to acquire concurrency permit")
238                            })?;
239
240                            // Execute tool call with timeout
241                            match tokio::time::timeout(
242                                timeout_duration,
243                                connection.call_tool(&tool_name, arguments),
244                            )
245                            .await
246                            {
247                                Ok(result) => result,
248                                Err(_) => Err(anyhow::anyhow!(
249                                    "Tool call timed out after {} seconds",
250                                    timeout_duration.as_secs()
251                                )),
252                            }
253                        })
254                    })
255                    .join()
256                    .map_err(|_| anyhow::anyhow!("Tool call thread panicked"))?
257                });
258
259                // Convert MCP tool schema to Tool definition
260                let function_def = Function {
261                    name: tool_name.clone(),
262                    description: tool.description.clone(),
263                    parameters: Self::convert_mcp_schema_to_parameters(&tool.input_schema),
264                };
265
266                let tool_def = Tool {
267                    tp: ToolType::Function,
268                    function: function_def,
269                };
270
271                // Store in both collections for backward compatibility
272                self.tool_callbacks
273                    .insert(tool_name.clone(), callback.clone());
274                self.tool_callbacks_with_tools.insert(
275                    tool_name.clone(),
276                    ToolCallbackWithTool {
277                        callback,
278                        tool: tool_def,
279                    },
280                );
281                self.tools.insert(tool_name, tool);
282            }
283        }
284
285        Ok(())
286    }
287
288    /// Convert MCP tool input schema to Tool parameters format
289    fn convert_mcp_schema_to_parameters(
290        schema: &serde_json::Value,
291    ) -> Option<HashMap<String, serde_json::Value>> {
292        // MCP tools can have various schema formats, we'll try to convert common ones
293        match schema {
294            serde_json::Value::Object(obj) => {
295                let mut params = HashMap::new();
296
297                // If it's a JSON schema object, extract properties
298                if let Some(properties) = obj.get("properties") {
299                    if let serde_json::Value::Object(props) = properties {
300                        for (key, value) in props {
301                            params.insert(key.clone(), value.clone());
302                        }
303                    }
304                } else {
305                    // If it's just a direct object, use it as-is
306                    for (key, value) in obj {
307                        params.insert(key.clone(), value.clone());
308                    }
309                }
310
311                if params.is_empty() {
312                    None
313                } else {
314                    Some(params)
315                }
316            }
317            _ => {
318                // For non-object schemas, we can't easily convert to parameters
319                None
320            }
321        }
322    }
323}
324
325/// HTTP-based MCP server connection
326pub struct HttpMcpConnection {
327    server_id: String,
328    server_name: String,
329    transport: Arc<dyn McpTransport>,
330}
331
332impl HttpMcpConnection {
333    pub async fn new(
334        server_id: String,
335        server_name: String,
336        url: String,
337        timeout_secs: Option<u64>,
338        headers: Option<HashMap<String, String>>,
339    ) -> Result<Self> {
340        let transport = HttpTransport::new(url, timeout_secs, headers)?;
341
342        let connection = Self {
343            server_id,
344            server_name,
345            transport: Arc::new(transport),
346        };
347
348        // Initialize the connection
349        connection.initialize().await?;
350
351        Ok(connection)
352    }
353
354    async fn initialize(&self) -> Result<()> {
355        let init_params = serde_json::json!({
356            "protocolVersion": "2025-03-26",
357            "capabilities": {
358                "tools": {}
359            },
360            "clientInfo": {
361                "name": "mistral.rs",
362                "version": "0.6.0"
363            }
364        });
365
366        self.transport
367            .send_request("initialize", init_params)
368            .await?;
369        Ok(())
370    }
371}
372
373#[async_trait::async_trait]
374impl McpServerConnection for HttpMcpConnection {
375    fn server_id(&self) -> &str {
376        &self.server_id
377    }
378
379    fn server_name(&self) -> &str {
380        &self.server_name
381    }
382
383    async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
384        let result = self
385            .transport
386            .send_request("tools/list", Value::Null)
387            .await?;
388
389        let tools = result
390            .get("tools")
391            .and_then(|t| t.as_array())
392            .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
393
394        let mut tool_infos = Vec::new();
395        for tool in tools {
396            let name = tool
397                .get("name")
398                .and_then(|n| n.as_str())
399                .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
400                .to_string();
401
402            let description = tool
403                .get("description")
404                .and_then(|d| d.as_str())
405                .map(|s| s.to_string());
406
407            let input_schema = tool
408                .get("inputSchema")
409                .cloned()
410                .unwrap_or(Value::Object(serde_json::Map::new()));
411
412            tool_infos.push(McpToolInfo {
413                name,
414                description,
415                input_schema,
416                server_id: self.server_id.clone(),
417                server_name: self.server_name.clone(),
418            });
419        }
420
421        Ok(tool_infos)
422    }
423
424    async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
425        let params = serde_json::json!({
426            "name": name,
427            "arguments": arguments
428        });
429
430        let result = self.transport.send_request("tools/call", params).await?;
431
432        // Parse the MCP tool result
433        let tool_result: McpToolResult = serde_json::from_value(result)?;
434
435        // Check if the result indicates an error
436        if tool_result.is_error.unwrap_or(false) {
437            return Err(anyhow::anyhow!(
438                "Tool execution failed: {}",
439                tool_result.to_string()
440            ));
441        }
442
443        Ok(tool_result.to_string())
444    }
445
446    async fn list_resources(&self) -> Result<Vec<Resource>> {
447        let result = self
448            .transport
449            .send_request("resources/list", Value::Null)
450            .await?;
451
452        let resources = result
453            .get("resources")
454            .and_then(|r| r.as_array())
455            .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
456
457        let mut resource_list = Vec::new();
458        for resource in resources {
459            let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
460            resource_list.push(mcp_resource);
461        }
462
463        Ok(resource_list)
464    }
465
466    async fn read_resource(&self, uri: &str) -> Result<String> {
467        let params = serde_json::json!({ "uri": uri });
468        let result = self
469            .transport
470            .send_request("resources/read", params)
471            .await?;
472
473        // Extract content from the response
474        if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
475            if let Some(first_content) = contents.first() {
476                if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
477                    return Ok(text.to_string());
478                }
479            }
480        }
481
482        Err(anyhow::anyhow!("No readable content found in resource"))
483    }
484
485    async fn ping(&self) -> Result<()> {
486        // Send a simple ping to check if the server is responsive
487        self.transport.send_request("ping", Value::Null).await?;
488        Ok(())
489    }
490}
491
492/// Process-based MCP server connection
493pub struct ProcessMcpConnection {
494    server_id: String,
495    server_name: String,
496    transport: Arc<dyn McpTransport>,
497}
498
499impl ProcessMcpConnection {
500    pub async fn new(
501        server_id: String,
502        server_name: String,
503        command: String,
504        args: Vec<String>,
505        work_dir: Option<String>,
506        env: Option<HashMap<String, String>>,
507    ) -> Result<Self> {
508        let transport = ProcessTransport::new(command, args, work_dir, env).await?;
509
510        let connection = Self {
511            server_id,
512            server_name,
513            transport: Arc::new(transport),
514        };
515
516        // Initialize the connection
517        connection.initialize().await?;
518
519        Ok(connection)
520    }
521
522    async fn initialize(&self) -> Result<()> {
523        let init_params = serde_json::json!({
524            "protocolVersion": "2025-03-26",
525            "capabilities": {
526                "tools": {}
527            },
528            "clientInfo": {
529                "name": "mistral.rs",
530                "version": "0.6.0"
531            }
532        });
533
534        self.transport
535            .send_request("initialize", init_params)
536            .await?;
537        Ok(())
538    }
539}
540
541#[async_trait::async_trait]
542impl McpServerConnection for ProcessMcpConnection {
543    fn server_id(&self) -> &str {
544        &self.server_id
545    }
546
547    fn server_name(&self) -> &str {
548        &self.server_name
549    }
550
551    async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
552        let result = self
553            .transport
554            .send_request("tools/list", Value::Null)
555            .await?;
556
557        let tools = result
558            .get("tools")
559            .and_then(|t| t.as_array())
560            .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
561
562        let mut tool_infos = Vec::new();
563        for tool in tools {
564            let name = tool
565                .get("name")
566                .and_then(|n| n.as_str())
567                .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
568                .to_string();
569
570            let description = tool
571                .get("description")
572                .and_then(|d| d.as_str())
573                .map(|s| s.to_string());
574
575            let input_schema = tool
576                .get("inputSchema")
577                .cloned()
578                .unwrap_or(Value::Object(serde_json::Map::new()));
579
580            tool_infos.push(McpToolInfo {
581                name,
582                description,
583                input_schema,
584                server_id: self.server_id.clone(),
585                server_name: self.server_name.clone(),
586            });
587        }
588
589        Ok(tool_infos)
590    }
591
592    async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
593        let params = serde_json::json!({
594            "name": name,
595            "arguments": arguments
596        });
597
598        let result = self.transport.send_request("tools/call", params).await?;
599
600        // Parse the MCP tool result
601        let tool_result: McpToolResult = serde_json::from_value(result)?;
602
603        // Check if the result indicates an error
604        if tool_result.is_error.unwrap_or(false) {
605            return Err(anyhow::anyhow!(
606                "Tool execution failed: {}",
607                tool_result.to_string()
608            ));
609        }
610
611        Ok(tool_result.to_string())
612    }
613
614    async fn list_resources(&self) -> Result<Vec<Resource>> {
615        let result = self
616            .transport
617            .send_request("resources/list", Value::Null)
618            .await?;
619
620        let resources = result
621            .get("resources")
622            .and_then(|r| r.as_array())
623            .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
624
625        let mut resource_list = Vec::new();
626        for resource in resources {
627            let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
628            resource_list.push(mcp_resource);
629        }
630
631        Ok(resource_list)
632    }
633
634    async fn read_resource(&self, uri: &str) -> Result<String> {
635        let params = serde_json::json!({ "uri": uri });
636        let result = self
637            .transport
638            .send_request("resources/read", params)
639            .await?;
640
641        // Extract content from the response
642        if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
643            if let Some(first_content) = contents.first() {
644                if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
645                    return Ok(text.to_string());
646                }
647            }
648        }
649
650        Err(anyhow::anyhow!("No readable content found in resource"))
651    }
652
653    async fn ping(&self) -> Result<()> {
654        // Send a simple ping to check if the server is responsive
655        self.transport.send_request("ping", Value::Null).await?;
656        Ok(())
657    }
658}
659
660/// WebSocket-based MCP server connection
661pub struct WebSocketMcpConnection {
662    server_id: String,
663    server_name: String,
664    transport: Arc<dyn McpTransport>,
665}
666
667impl WebSocketMcpConnection {
668    pub async fn new(
669        server_id: String,
670        server_name: String,
671        url: String,
672        timeout_secs: Option<u64>,
673        headers: Option<HashMap<String, String>>,
674    ) -> Result<Self> {
675        let transport = WebSocketTransport::new(url, timeout_secs, headers).await?;
676
677        let connection = Self {
678            server_id,
679            server_name,
680            transport: Arc::new(transport),
681        };
682
683        // Initialize the connection
684        connection.initialize().await?;
685
686        Ok(connection)
687    }
688
689    async fn initialize(&self) -> Result<()> {
690        let init_params = serde_json::json!({
691            "protocolVersion": "2025-03-26",
692            "capabilities": {
693                "tools": {}
694            },
695            "clientInfo": {
696                "name": "mistral.rs",
697                "version": "0.6.0"
698            }
699        });
700
701        self.transport
702            .send_request("initialize", init_params)
703            .await?;
704        Ok(())
705    }
706}
707
708#[async_trait::async_trait]
709impl McpServerConnection for WebSocketMcpConnection {
710    fn server_id(&self) -> &str {
711        &self.server_id
712    }
713
714    fn server_name(&self) -> &str {
715        &self.server_name
716    }
717
718    async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
719        let result = self
720            .transport
721            .send_request("tools/list", Value::Null)
722            .await?;
723
724        let tools = result
725            .get("tools")
726            .and_then(|t| t.as_array())
727            .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
728
729        let mut tool_infos = Vec::new();
730        for tool in tools {
731            let name = tool
732                .get("name")
733                .and_then(|n| n.as_str())
734                .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
735                .to_string();
736
737            let description = tool
738                .get("description")
739                .and_then(|d| d.as_str())
740                .map(|s| s.to_string());
741
742            let input_schema = tool
743                .get("inputSchema")
744                .cloned()
745                .unwrap_or(Value::Object(serde_json::Map::new()));
746
747            tool_infos.push(McpToolInfo {
748                name,
749                description,
750                input_schema,
751                server_id: self.server_id.clone(),
752                server_name: self.server_name.clone(),
753            });
754        }
755
756        Ok(tool_infos)
757    }
758
759    async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
760        let params = serde_json::json!({
761            "name": name,
762            "arguments": arguments
763        });
764
765        let result = self.transport.send_request("tools/call", params).await?;
766
767        // Parse the MCP tool result
768        let tool_result: McpToolResult = serde_json::from_value(result)?;
769
770        // Check if the result indicates an error
771        if tool_result.is_error.unwrap_or(false) {
772            return Err(anyhow::anyhow!(
773                "Tool execution failed: {}",
774                tool_result.to_string()
775            ));
776        }
777
778        Ok(tool_result.to_string())
779    }
780
781    async fn list_resources(&self) -> Result<Vec<Resource>> {
782        let result = self
783            .transport
784            .send_request("resources/list", Value::Null)
785            .await?;
786
787        let resources = result
788            .get("resources")
789            .and_then(|r| r.as_array())
790            .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
791
792        let mut resource_list = Vec::new();
793        for resource in resources {
794            let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
795            resource_list.push(mcp_resource);
796        }
797
798        Ok(resource_list)
799    }
800
801    async fn read_resource(&self, uri: &str) -> Result<String> {
802        let params = serde_json::json!({ "uri": uri });
803        let result = self
804            .transport
805            .send_request("resources/read", params)
806            .await?;
807
808        // Extract content from the response
809        if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
810            if let Some(first_content) = contents.first() {
811                if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
812                    return Ok(text.to_string());
813                }
814            }
815        }
816
817        Err(anyhow::anyhow!("No readable content found in resource"))
818    }
819
820    async fn ping(&self) -> Result<()> {
821        // Send a simple ping to check if the server is responsive
822        self.transport.send_request("ping", Value::Null).await?;
823        Ok(())
824    }
825}