mistralrs_mcp/
client.rs

1use crate::tools::{Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType};
2use crate::transport::{HttpTransport, McpTransport, ProcessTransport, WebSocketTransport};
3use crate::types::McpToolResult;
4use crate::{McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo};
5use anyhow::Result;
6use rust_mcp_schema::Resource;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::Semaphore;
12
13/// Trait for MCP server connections
14#[async_trait::async_trait]
15pub trait McpServerConnection: Send + Sync {
16    /// Get the server ID
17    fn server_id(&self) -> &str;
18
19    /// Get the server name
20    fn server_name(&self) -> &str;
21
22    /// List available tools from this server
23    async fn list_tools(&self) -> Result<Vec<McpToolInfo>>;
24
25    /// Call a tool on this server
26    async fn call_tool(&self, name: &str, arguments: serde_json::Value) -> Result<String>;
27
28    /// List available resources from this server
29    async fn list_resources(&self) -> Result<Vec<Resource>>;
30
31    /// Read a resource from this server
32    async fn read_resource(&self, uri: &str) -> Result<String>;
33
34    /// Check if the connection is healthy
35    async fn ping(&self) -> Result<()>;
36}
37
38/// MCP client that manages connections to multiple MCP servers
39///
40/// The main interface for interacting with Model Context Protocol servers.
41/// Handles connection lifecycle, tool discovery, and provides integration
42/// with tool calling systems.
43///
44/// # Features
45///
46/// - **Multi-server Management**: Connects to and manages multiple MCP servers simultaneously
47/// - **Automatic Tool Discovery**: Discovers available tools from connected servers
48/// - **Tool Registration**: Converts MCP tools to internal Tool format for seamless integration
49/// - **Connection Pooling**: Maintains persistent connections for efficient tool execution
50/// - **Error Handling**: Robust error handling with proper cleanup and reconnection logic
51///
52/// # Example
53///
54/// ```rust,no_run
55/// use mistralrs_mcp::{McpClient, McpClientConfig};
56///
57/// #[tokio::main]
58/// async fn main() -> anyhow::Result<()> {
59///     let config = McpClientConfig::default();
60///     let mut client = McpClient::new(config);
61///     
62///     // Initialize all configured server connections
63///     client.initialize().await?;
64///     
65///     // Get tool callbacks for model integration
66///     let callbacks = client.get_tool_callbacks_with_tools();
67///     
68///     Ok(())
69/// }
70/// ```
71pub struct McpClient {
72    /// Configuration for the client including server list and policies
73    config: McpClientConfig,
74    /// Active connections to MCP servers, indexed by server ID
75    servers: HashMap<String, Arc<dyn McpServerConnection>>,
76    /// Registry of discovered tools from all connected servers
77    tools: HashMap<String, McpToolInfo>,
78    /// Legacy tool callbacks for backward compatibility
79    tool_callbacks: HashMap<String, Arc<ToolCallback>>,
80    /// Tool callbacks with associated Tool definitions for automatic tool calling
81    tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
82    /// Semaphore to control maximum concurrent tool calls
83    concurrency_semaphore: Arc<Semaphore>,
84}
85
86impl McpClient {
87    /// Create a new MCP client with the given configuration
88    pub fn new(config: McpClientConfig) -> Self {
89        let max_concurrent = config.max_concurrent_calls.unwrap_or(10);
90        Self {
91            config,
92            servers: HashMap::new(),
93            tools: HashMap::new(),
94            tool_callbacks: HashMap::new(),
95            tool_callbacks_with_tools: HashMap::new(),
96            concurrency_semaphore: Arc::new(Semaphore::new(max_concurrent)),
97        }
98    }
99
100    /// Initialize connections to all configured servers
101    pub async fn initialize(&mut self) -> Result<()> {
102        for server_config in &self.config.servers {
103            if server_config.enabled {
104                let connection = self.create_connection(server_config).await?;
105                self.servers.insert(server_config.id.clone(), connection);
106            }
107        }
108
109        if self.config.auto_register_tools {
110            self.discover_and_register_tools().await?;
111        }
112
113        Ok(())
114    }
115
116    /// Get tool callbacks that can be used with the existing tool calling system
117    pub fn get_tool_callbacks(&self) -> &HashMap<String, Arc<ToolCallback>> {
118        &self.tool_callbacks
119    }
120
121    /// Get tool callbacks with their associated Tool definitions
122    pub fn get_tool_callbacks_with_tools(&self) -> &HashMap<String, ToolCallbackWithTool> {
123        &self.tool_callbacks_with_tools
124    }
125
126    /// Get discovered tools information
127    pub fn get_tools(&self) -> &HashMap<String, McpToolInfo> {
128        &self.tools
129    }
130
131    /// Create connection based on server source type
132    async fn create_connection(
133        &self,
134        config: &McpServerConfig,
135    ) -> Result<Arc<dyn McpServerConnection>> {
136        match &config.source {
137            McpServerSource::Http {
138                url,
139                timeout_secs,
140                headers,
141            } => {
142                // Merge Bearer token with existing headers if provided
143                let mut merged_headers = headers.clone().unwrap_or_default();
144                if let Some(token) = &config.bearer_token {
145                    merged_headers.insert("Authorization".to_string(), format!("Bearer {token}"));
146                }
147
148                let connection = HttpMcpConnection::new(
149                    config.id.clone(),
150                    config.name.clone(),
151                    url.clone(),
152                    *timeout_secs,
153                    Some(merged_headers),
154                )
155                .await?;
156                Ok(Arc::new(connection))
157            }
158            McpServerSource::Process {
159                command,
160                args,
161                work_dir,
162                env,
163            } => {
164                let connection = ProcessMcpConnection::new(
165                    config.id.clone(),
166                    config.name.clone(),
167                    command.clone(),
168                    args.clone(),
169                    work_dir.clone(),
170                    env.clone(),
171                )
172                .await?;
173                Ok(Arc::new(connection))
174            }
175            McpServerSource::WebSocket {
176                url,
177                timeout_secs,
178                headers,
179            } => {
180                // Merge Bearer token with existing headers if provided
181                let mut merged_headers = headers.clone().unwrap_or_default();
182                if let Some(token) = &config.bearer_token {
183                    merged_headers.insert("Authorization".to_string(), format!("Bearer {token}"));
184                }
185
186                let connection = WebSocketMcpConnection::new(
187                    config.id.clone(),
188                    config.name.clone(),
189                    url.clone(),
190                    *timeout_secs,
191                    Some(merged_headers),
192                )
193                .await?;
194                Ok(Arc::new(connection))
195            }
196        }
197    }
198
199    /// Discover tools from all connected servers and register them
200    async fn discover_and_register_tools(&mut self) -> Result<()> {
201        for (server_id, connection) in &self.servers {
202            let tools = connection.list_tools().await?;
203            let server_config = self
204                .config
205                .servers
206                .iter()
207                .find(|s| &s.id == server_id)
208                .ok_or_else(|| anyhow::anyhow!("Server config not found for {}", server_id))?;
209
210            for tool in tools {
211                let tool_name = if let Some(prefix) = &server_config.tool_prefix {
212                    format!("{}_{}", prefix, tool.name)
213                } else {
214                    tool.name.clone()
215                };
216
217                // Create tool callback that calls the MCP server with timeout and concurrency controls
218                let connection_clone = Arc::clone(connection);
219                let original_tool_name = tool.name.clone();
220                let semaphore_clone = Arc::clone(&self.concurrency_semaphore);
221                let timeout_duration =
222                    Duration::from_secs(self.config.tool_timeout_secs.unwrap_or(30));
223
224                let callback: Arc<ToolCallback> = Arc::new(move |called_function| {
225                    let connection = Arc::clone(&connection_clone);
226                    let tool_name = original_tool_name.clone();
227                    let semaphore = Arc::clone(&semaphore_clone);
228                    let arguments: serde_json::Value =
229                        serde_json::from_str(&called_function.arguments)?;
230
231                    // Use tokio::task::spawn_blocking to handle the async-to-sync bridge
232                    let rt = tokio::runtime::Handle::current();
233                    std::thread::spawn(move || {
234                        rt.block_on(async move {
235                            // Acquire semaphore permit for concurrency control
236                            let _permit = semaphore.acquire().await.map_err(|_| {
237                                anyhow::anyhow!("Failed to acquire concurrency permit")
238                            })?;
239
240                            // Execute tool call with timeout
241                            match tokio::time::timeout(
242                                timeout_duration,
243                                connection.call_tool(&tool_name, arguments),
244                            )
245                            .await
246                            {
247                                Ok(result) => result,
248                                Err(_) => Err(anyhow::anyhow!(
249                                    "Tool call timed out after {} seconds",
250                                    timeout_duration.as_secs()
251                                )),
252                            }
253                        })
254                    })
255                    .join()
256                    .map_err(|_| anyhow::anyhow!("Tool call thread panicked"))?
257                });
258
259                // Convert MCP tool schema to Tool definition
260                let function_def = Function {
261                    name: tool_name.clone(),
262                    description: tool.description.clone(),
263                    parameters: Self::convert_mcp_schema_to_parameters(&tool.input_schema),
264                };
265
266                let tool_def = Tool {
267                    tp: ToolType::Function,
268                    function: function_def,
269                };
270
271                // Store in both collections for backward compatibility
272                self.tool_callbacks
273                    .insert(tool_name.clone(), callback.clone());
274                self.tool_callbacks_with_tools.insert(
275                    tool_name.clone(),
276                    ToolCallbackWithTool {
277                        callback,
278                        tool: tool_def,
279                    },
280                );
281                self.tools.insert(tool_name, tool);
282            }
283        }
284
285        Ok(())
286    }
287
288    /// Convert MCP tool input schema to Tool parameters format
289    fn convert_mcp_schema_to_parameters(
290        schema: &serde_json::Value,
291    ) -> Option<HashMap<String, serde_json::Value>> {
292        // MCP tools can have various schema formats, we'll try to convert common ones
293        match schema {
294            serde_json::Value::Object(obj) => {
295                let mut params = HashMap::new();
296
297                // If it's a JSON schema object, extract properties
298                if let Some(properties) = obj.get("properties") {
299                    if let serde_json::Value::Object(props) = properties {
300                        for (key, value) in props {
301                            params.insert(key.clone(), value.clone());
302                        }
303                    }
304                } else {
305                    // If it's just a direct object, use it as-is
306                    for (key, value) in obj {
307                        params.insert(key.clone(), value.clone());
308                    }
309                }
310
311                if params.is_empty() {
312                    None
313                } else {
314                    Some(params)
315                }
316            }
317            _ => {
318                // For non-object schemas, we can't easily convert to parameters
319                None
320            }
321        }
322    }
323}
324
325/// HTTP-based MCP server connection
326pub struct HttpMcpConnection {
327    server_id: String,
328    server_name: String,
329    transport: Arc<dyn McpTransport>,
330}
331
332impl HttpMcpConnection {
333    pub async fn new(
334        server_id: String,
335        server_name: String,
336        url: String,
337        timeout_secs: Option<u64>,
338        headers: Option<HashMap<String, String>>,
339    ) -> Result<Self> {
340        let transport = HttpTransport::new(url, timeout_secs, headers)?;
341
342        let connection = Self {
343            server_id,
344            server_name,
345            transport: Arc::new(transport),
346        };
347
348        // Initialize the connection
349        connection.initialize().await?;
350
351        Ok(connection)
352    }
353
354    async fn initialize(&self) -> Result<()> {
355        let init_params = serde_json::json!({
356            "protocolVersion": "2025-03-26",
357            "capabilities": {
358                "tools": {}
359            },
360            "clientInfo": {
361                "name": "mistral.rs",
362                "version": "0.6.0"
363            }
364        });
365
366        self.transport
367            .send_request("initialize", init_params)
368            .await?;
369        self.transport.send_initialization_notification().await?;
370        Ok(())
371    }
372}
373
374#[async_trait::async_trait]
375impl McpServerConnection for HttpMcpConnection {
376    fn server_id(&self) -> &str {
377        &self.server_id
378    }
379
380    fn server_name(&self) -> &str {
381        &self.server_name
382    }
383
384    async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
385        let result = self
386            .transport
387            .send_request("tools/list", Value::Null)
388            .await?;
389
390        let tools = result
391            .get("tools")
392            .and_then(|t| t.as_array())
393            .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
394
395        let mut tool_infos = Vec::new();
396        for tool in tools {
397            let name = tool
398                .get("name")
399                .and_then(|n| n.as_str())
400                .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
401                .to_string();
402
403            let description = tool
404                .get("description")
405                .and_then(|d| d.as_str())
406                .map(|s| s.to_string());
407
408            let input_schema = tool
409                .get("inputSchema")
410                .cloned()
411                .unwrap_or(Value::Object(serde_json::Map::new()));
412
413            tool_infos.push(McpToolInfo {
414                name,
415                description,
416                input_schema,
417                server_id: self.server_id.clone(),
418                server_name: self.server_name.clone(),
419            });
420        }
421
422        Ok(tool_infos)
423    }
424
425    async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
426        let params = serde_json::json!({
427            "name": name,
428            "arguments": arguments
429        });
430
431        let result = self.transport.send_request("tools/call", params).await?;
432
433        // Parse the MCP tool result
434        let tool_result: McpToolResult = serde_json::from_value(result)?;
435
436        // Check if the result indicates an error
437        if tool_result.is_error.unwrap_or(false) {
438            return Err(anyhow::anyhow!(
439                "Tool execution failed: {}",
440                tool_result.to_string()
441            ));
442        }
443
444        Ok(tool_result.to_string())
445    }
446
447    async fn list_resources(&self) -> Result<Vec<Resource>> {
448        let result = self
449            .transport
450            .send_request("resources/list", Value::Null)
451            .await?;
452
453        let resources = result
454            .get("resources")
455            .and_then(|r| r.as_array())
456            .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
457
458        let mut resource_list = Vec::new();
459        for resource in resources {
460            let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
461            resource_list.push(mcp_resource);
462        }
463
464        Ok(resource_list)
465    }
466
467    async fn read_resource(&self, uri: &str) -> Result<String> {
468        let params = serde_json::json!({ "uri": uri });
469        let result = self
470            .transport
471            .send_request("resources/read", params)
472            .await?;
473
474        // Extract content from the response
475        if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
476            if let Some(first_content) = contents.first() {
477                if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
478                    return Ok(text.to_string());
479                }
480            }
481        }
482
483        Err(anyhow::anyhow!("No readable content found in resource"))
484    }
485
486    async fn ping(&self) -> Result<()> {
487        // Send a simple ping to check if the server is responsive
488        self.transport.send_request("ping", Value::Null).await?;
489        Ok(())
490    }
491}
492
493/// Process-based MCP server connection
494pub struct ProcessMcpConnection {
495    server_id: String,
496    server_name: String,
497    transport: Arc<dyn McpTransport>,
498}
499
500impl ProcessMcpConnection {
501    pub async fn new(
502        server_id: String,
503        server_name: String,
504        command: String,
505        args: Vec<String>,
506        work_dir: Option<String>,
507        env: Option<HashMap<String, String>>,
508    ) -> Result<Self> {
509        let transport = ProcessTransport::new(command, args, work_dir, env).await?;
510
511        let connection = Self {
512            server_id,
513            server_name,
514            transport: Arc::new(transport),
515        };
516
517        // Initialize the connection
518        connection.initialize().await?;
519
520        Ok(connection)
521    }
522
523    async fn initialize(&self) -> Result<()> {
524        let init_params = serde_json::json!({
525            "protocolVersion": "2025-03-26",
526            "capabilities": {
527                "tools": {}
528            },
529            "clientInfo": {
530                "name": "mistral.rs",
531                "version": "0.6.0"
532            }
533        });
534
535        self.transport
536            .send_request("initialize", init_params)
537            .await?;
538        self.transport.send_initialization_notification().await?;
539        Ok(())
540    }
541}
542
543#[async_trait::async_trait]
544impl McpServerConnection for ProcessMcpConnection {
545    fn server_id(&self) -> &str {
546        &self.server_id
547    }
548
549    fn server_name(&self) -> &str {
550        &self.server_name
551    }
552
553    async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
554        let result = self
555            .transport
556            .send_request("tools/list", Value::Null)
557            .await?;
558
559        let tools = result
560            .get("tools")
561            .and_then(|t| t.as_array())
562            .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
563
564        let mut tool_infos = Vec::new();
565        for tool in tools {
566            let name = tool
567                .get("name")
568                .and_then(|n| n.as_str())
569                .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
570                .to_string();
571
572            let description = tool
573                .get("description")
574                .and_then(|d| d.as_str())
575                .map(|s| s.to_string());
576
577            let input_schema = tool
578                .get("inputSchema")
579                .cloned()
580                .unwrap_or(Value::Object(serde_json::Map::new()));
581
582            tool_infos.push(McpToolInfo {
583                name,
584                description,
585                input_schema,
586                server_id: self.server_id.clone(),
587                server_name: self.server_name.clone(),
588            });
589        }
590
591        Ok(tool_infos)
592    }
593
594    async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
595        let params = serde_json::json!({
596            "name": name,
597            "arguments": arguments
598        });
599
600        let result = self.transport.send_request("tools/call", params).await?;
601
602        // Parse the MCP tool result
603        let tool_result: McpToolResult = serde_json::from_value(result)?;
604
605        // Check if the result indicates an error
606        if tool_result.is_error.unwrap_or(false) {
607            return Err(anyhow::anyhow!(
608                "Tool execution failed: {}",
609                tool_result.to_string()
610            ));
611        }
612
613        Ok(tool_result.to_string())
614    }
615
616    async fn list_resources(&self) -> Result<Vec<Resource>> {
617        let result = self
618            .transport
619            .send_request("resources/list", Value::Null)
620            .await?;
621
622        let resources = result
623            .get("resources")
624            .and_then(|r| r.as_array())
625            .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
626
627        let mut resource_list = Vec::new();
628        for resource in resources {
629            let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
630            resource_list.push(mcp_resource);
631        }
632
633        Ok(resource_list)
634    }
635
636    async fn read_resource(&self, uri: &str) -> Result<String> {
637        let params = serde_json::json!({ "uri": uri });
638        let result = self
639            .transport
640            .send_request("resources/read", params)
641            .await?;
642
643        // Extract content from the response
644        if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
645            if let Some(first_content) = contents.first() {
646                if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
647                    return Ok(text.to_string());
648                }
649            }
650        }
651
652        Err(anyhow::anyhow!("No readable content found in resource"))
653    }
654
655    async fn ping(&self) -> Result<()> {
656        // Send a simple ping to check if the server is responsive
657        self.transport.send_request("ping", Value::Null).await?;
658        Ok(())
659    }
660}
661
662/// WebSocket-based MCP server connection
663pub struct WebSocketMcpConnection {
664    server_id: String,
665    server_name: String,
666    transport: Arc<dyn McpTransport>,
667}
668
669impl WebSocketMcpConnection {
670    pub async fn new(
671        server_id: String,
672        server_name: String,
673        url: String,
674        timeout_secs: Option<u64>,
675        headers: Option<HashMap<String, String>>,
676    ) -> Result<Self> {
677        let transport = WebSocketTransport::new(url, timeout_secs, headers).await?;
678
679        let connection = Self {
680            server_id,
681            server_name,
682            transport: Arc::new(transport),
683        };
684
685        // Initialize the connection
686        connection.initialize().await?;
687
688        Ok(connection)
689    }
690
691    async fn initialize(&self) -> Result<()> {
692        let init_params = serde_json::json!({
693            "protocolVersion": "2025-03-26",
694            "capabilities": {
695                "tools": {}
696            },
697            "clientInfo": {
698                "name": "mistral.rs",
699                "version": "0.6.0"
700            }
701        });
702
703        self.transport
704            .send_request("initialize", init_params)
705            .await?;
706        self.transport.send_initialization_notification().await?;
707        Ok(())
708    }
709}
710
711#[async_trait::async_trait]
712impl McpServerConnection for WebSocketMcpConnection {
713    fn server_id(&self) -> &str {
714        &self.server_id
715    }
716
717    fn server_name(&self) -> &str {
718        &self.server_name
719    }
720
721    async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
722        let result = self
723            .transport
724            .send_request("tools/list", Value::Null)
725            .await?;
726
727        let tools = result
728            .get("tools")
729            .and_then(|t| t.as_array())
730            .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
731
732        let mut tool_infos = Vec::new();
733        for tool in tools {
734            let name = tool
735                .get("name")
736                .and_then(|n| n.as_str())
737                .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
738                .to_string();
739
740            let description = tool
741                .get("description")
742                .and_then(|d| d.as_str())
743                .map(|s| s.to_string());
744
745            let input_schema = tool
746                .get("inputSchema")
747                .cloned()
748                .unwrap_or(Value::Object(serde_json::Map::new()));
749
750            tool_infos.push(McpToolInfo {
751                name,
752                description,
753                input_schema,
754                server_id: self.server_id.clone(),
755                server_name: self.server_name.clone(),
756            });
757        }
758
759        Ok(tool_infos)
760    }
761
762    async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
763        let params = serde_json::json!({
764            "name": name,
765            "arguments": arguments
766        });
767
768        let result = self.transport.send_request("tools/call", params).await?;
769
770        // Parse the MCP tool result
771        let tool_result: McpToolResult = serde_json::from_value(result)?;
772
773        // Check if the result indicates an error
774        if tool_result.is_error.unwrap_or(false) {
775            return Err(anyhow::anyhow!(
776                "Tool execution failed: {}",
777                tool_result.to_string()
778            ));
779        }
780
781        Ok(tool_result.to_string())
782    }
783
784    async fn list_resources(&self) -> Result<Vec<Resource>> {
785        let result = self
786            .transport
787            .send_request("resources/list", Value::Null)
788            .await?;
789
790        let resources = result
791            .get("resources")
792            .and_then(|r| r.as_array())
793            .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
794
795        let mut resource_list = Vec::new();
796        for resource in resources {
797            let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
798            resource_list.push(mcp_resource);
799        }
800
801        Ok(resource_list)
802    }
803
804    async fn read_resource(&self, uri: &str) -> Result<String> {
805        let params = serde_json::json!({ "uri": uri });
806        let result = self
807            .transport
808            .send_request("resources/read", params)
809            .await?;
810
811        // Extract content from the response
812        if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
813            if let Some(first_content) = contents.first() {
814                if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
815                    return Ok(text.to_string());
816                }
817            }
818        }
819
820        Err(anyhow::anyhow!("No readable content found in resource"))
821    }
822
823    async fn ping(&self) -> Result<()> {
824        // Send a simple ping to check if the server is responsive
825        self.transport.send_request("ping", Value::Null).await?;
826        Ok(())
827    }
828}