mistralrs_mcp/
transport.rs

1use anyhow::Result;
2use futures_util::stream::{SplitSink, SplitStream};
3use futures_util::{SinkExt, StreamExt};
4use http::{Request, Uri};
5use serde_json::Value;
6use std::collections::HashMap;
7use std::time::Duration;
8use tokio::net::TcpStream;
9use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
10
11/// Transport layer for MCP communication
12#[async_trait::async_trait]
13pub trait McpTransport: Send + Sync {
14    /// Send a JSON-RPC request and receive a response
15    async fn send_request(&self, method: &str, params: Value) -> Result<Value>;
16
17    /// Check if the transport connection is healthy
18    async fn ping(&self) -> Result<()>;
19
20    /// Close the transport connection
21    async fn close(&self) -> Result<()>;
22
23    /// Send initialization noification
24    async fn send_initialization_notification(&self) -> Result<()>;
25}
26
27/// HTTP-based MCP transport
28///
29/// Provides communication with MCP servers over HTTP using JSON-RPC 2.0 protocol.
30/// This transport is ideal for RESTful MCP services, public APIs, and servers
31/// behind load balancers. Supports both regular JSON responses and Server-Sent Events (SSE).
32///
33/// # Features
34///
35/// - **HTTP/HTTPS Support**: Secure communication with TLS encryption
36/// - **Server-Sent Events**: Handles streaming responses via SSE format
37/// - **Bearer Token Authentication**: Automatic Authorization header injection
38/// - **Custom Headers**: Support for additional headers (API keys, versioning, etc.)
39/// - **Configurable Timeouts**: Request-level timeout control
40/// - **Error Handling**: Comprehensive JSON-RPC and HTTP error handling
41///
42/// # Use Cases
43///
44/// - **Public MCP APIs**: Connect to hosted MCP services
45/// - **RESTful Services**: Integration with REST-based tool providers
46/// - **Load-Balanced Servers**: Works well behind HTTP load balancers
47/// - **Development/Testing**: Easy debugging with standard HTTP tools
48///
49/// # Example Usage
50///
51/// ```rust,no_run
52/// use mistralrs_mcp::transport::{HttpTransport, McpTransport};
53/// use std::collections::HashMap;
54///
55/// #[tokio::main]
56/// async fn main() -> anyhow::Result<()> {
57///     // Create headers with Bearer token and API version
58///     let mut headers = HashMap::new();
59///     headers.insert("Authorization".to_string(), "Bearer your-api-token".to_string());
60///     headers.insert("X-API-Version".to_string(), "v1".to_string());
61///
62///     // Connect to HTTP MCP server
63///     let transport = HttpTransport::new(
64///         "https://api.example.com/mcp".to_string(),
65///         Some(30), // 30 second timeout
66///         Some(headers)
67///     )?;
68///
69///     // Use the transport for MCP communication
70///     let result = transport.send_request("tools/list", serde_json::Value::Null).await?;
71///     println!("Available tools: {}", result);
72///
73///     Ok(())
74/// }
75/// ```
76///
77/// # Protocol Support
78///
79/// This transport implements JSON-RPC 2.0 over HTTP with support for:
80/// - Standard JSON responses
81/// - Server-Sent Events (SSE) for streaming data
82/// - Bearer token authentication
83/// - Custom HTTP headers
84pub struct HttpTransport {
85    client: reqwest::Client,
86    base_url: String,
87    headers: HashMap<String, String>,
88    request_id: std::sync::Arc<std::sync::atomic::AtomicU64>,
89}
90
91impl HttpTransport {
92    /// Creates a new HTTP transport for MCP communication
93    ///
94    /// # Arguments
95    ///
96    /// * `base_url` - Base URL of the MCP server (http:// or https://)
97    /// * `timeout_secs` - Optional timeout for HTTP requests in seconds (defaults to 30s)
98    /// * `headers` - Optional custom headers to include in all requests
99    ///
100    /// # Returns
101    ///
102    /// A configured HttpTransport ready for MCP communication
103    ///
104    /// # Errors
105    ///
106    /// - Invalid URL format
107    /// - HTTP client configuration errors
108    /// - TLS/SSL setup failures
109    ///
110    /// # Example
111    ///
112    /// ```rust,no_run
113    /// use mistralrs_mcp::transport::HttpTransport;
114    /// use std::collections::HashMap;
115    ///
116    /// // Basic HTTP transport
117    /// let transport = HttpTransport::new(
118    ///     "https://api.example.com/mcp".to_string(),
119    ///     Some(60), // 1 minute timeout
120    ///     None
121    /// )?;
122    ///
123    /// // With custom headers and authentication
124    /// let mut headers = HashMap::new();
125    /// headers.insert("Authorization".to_string(), "Bearer token123".to_string());
126    /// headers.insert("X-Client-Version".to_string(), "1.0.0".to_string());
127    ///
128    /// let transport = HttpTransport::new(
129    ///     "https://secure-api.example.com/mcp".to_string(),
130    ///     Some(30),
131    ///     Some(headers)
132    /// )?;
133    /// # Ok::<(), anyhow::Error>(())
134    /// ```
135    pub fn new(
136        base_url: String,
137        timeout_secs: Option<u64>,
138        headers: Option<HashMap<String, String>>,
139    ) -> Result<Self> {
140        let timeout = timeout_secs
141            .map(Duration::from_secs)
142            .unwrap_or(Duration::from_secs(30));
143        let client = reqwest::Client::builder().timeout(timeout).build()?;
144
145        Ok(Self {
146            client,
147            base_url,
148            headers: headers.unwrap_or_default(),
149            request_id: std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1)),
150        })
151    }
152
153    /// Parse Server-Sent Events response to extract JSON-RPC message
154    ///
155    /// Handles SSE format used by some MCP servers for streaming responses.
156    /// SSE format: `data: <json>\n\n` or `event: <type>\ndata: <json>\n\n`
157    ///
158    /// # Arguments
159    ///
160    /// * `sse_text` - Raw SSE response text from the server
161    ///
162    /// # Returns
163    ///
164    /// Parsed JSON value from the SSE data field
165    ///
166    /// # Errors
167    ///
168    /// - No valid JSON data found in SSE response
169    /// - Malformed SSE format
170    /// - JSON parsing errors
171    fn parse_sse_response(sse_text: &str) -> Result<Value> {
172        // SSE format: data: <json>\n\n or event: <type>\ndata: <json>\n\n
173        let mut json_data = None;
174
175        for line in sse_text.lines() {
176            let line = line.trim();
177
178            // Skip empty lines and comments
179            if line.is_empty() || line.starts_with(':') {
180                continue;
181            }
182
183            // Parse SSE field
184            if let Some((field, value)) = line.split_once(':') {
185                let field = field.trim();
186                let value = value.trim();
187
188                match field {
189                    "data" => {
190                        // Try to parse the JSON data
191                        if let Ok(parsed) = serde_json::from_str::<Value>(value) {
192                            json_data = Some(parsed);
193                            break;
194                        }
195                    }
196                    "event" => {
197                        // Handle different event types if needed
198                        continue;
199                    }
200                    _ => {
201                        // Ignore other SSE fields like id, retry, etc.
202                        continue;
203                    }
204                }
205            }
206        }
207
208        json_data.ok_or_else(|| anyhow::anyhow!("No valid JSON data found in SSE response"))
209    }
210}
211
212#[async_trait::async_trait]
213impl McpTransport for HttpTransport {
214    /// Sends an MCP request over HTTP and returns the response
215    ///
216    /// This method implements JSON-RPC 2.0 over HTTP with support for both
217    /// standard JSON responses and Server-Sent Events (SSE). It handles
218    /// authentication, custom headers, and comprehensive error reporting.
219    ///
220    /// # Arguments
221    ///
222    /// * `method` - The MCP method name (e.g., "tools/list", "tools/call", "resources/read")
223    /// * `params` - JSON parameters for the method call
224    ///
225    /// # Returns
226    ///
227    /// The result portion of the JSON-RPC response
228    ///
229    /// # Errors
230    ///
231    /// - HTTP connection errors (network issues, DNS resolution)
232    /// - HTTP status errors (4xx, 5xx responses)
233    /// - JSON serialization/deserialization errors
234    /// - MCP server errors (returned in JSON-RPC error field)
235    /// - SSE parsing errors for streaming responses
236    ///
237    /// # Example
238    ///
239    /// ```rust,no_run
240    /// use mistralrs_mcp::transport::{HttpTransport, McpTransport};
241    /// use serde_json::json;
242    ///
243    /// #[tokio::main]
244    /// async fn main() -> anyhow::Result<()> {
245    ///     let transport = HttpTransport::new(
246    ///         "https://api.example.com/mcp".to_string(),
247    ///         None,
248    ///         None
249    ///     )?;
250    ///     
251    ///     // List available tools
252    ///     let tools = transport.send_request("tools/list", serde_json::Value::Null).await?;
253    ///     
254    ///     // Call a specific tool
255    ///     let params = json!({
256    ///         "name": "search",
257    ///         "arguments": {"query": "example search"}
258    ///     });
259    ///     let result = transport.send_request("tools/call", params).await?;
260    ///     
261    ///     Ok(())
262    /// }
263    /// ```
264    async fn send_request(&self, method: &str, params: Value) -> Result<Value> {
265        // Ensure params is an object, not null
266        let params = if params.is_null() {
267            serde_json::json!({})
268        } else {
269            params
270        };
271
272        let id = self
273            .request_id
274            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
275        let request_body = serde_json::json!({
276            "jsonrpc": "2.0",
277            "id": id,
278            "method": method,
279            "params": params
280        });
281
282        let mut request_builder = self
283            .client
284            .post(&self.base_url)
285            .json(&request_body)
286            .header("Accept", "application/json, text/event-stream");
287
288        // Add custom headers
289        for (key, value) in &self.headers {
290            request_builder = request_builder.header(key, value);
291        }
292
293        let response = request_builder.send().await?;
294
295        // Check content type and handle accordingly
296        let content_type = response
297            .headers()
298            .get("content-type")
299            .and_then(|v| v.to_str().ok())
300            .unwrap_or("");
301
302        let response_body: Value = if content_type.contains("text/event-stream") {
303            // Handle Server-Sent Events
304            let response_text = response.text().await?;
305            Self::parse_sse_response(&response_text)?
306        } else {
307            // Handle regular JSON response
308            response.json().await?
309        };
310
311        // Check for JSON-RPC errors
312        if let Some(error) = response_body.get("error") {
313            return Err(anyhow::anyhow!("MCP server error: {}", error));
314        }
315
316        response_body
317            .get("result")
318            .cloned()
319            .ok_or_else(|| anyhow::anyhow!("No result in MCP response"))
320    }
321
322    /// Tests the HTTP connection by sending a ping request
323    ///
324    /// Sends a "ping" method call to verify that the MCP server is responsive
325    /// and the HTTP connection is working properly.
326    ///
327    /// # Returns
328    ///
329    /// Ok(()) if the ping was successful
330    ///
331    /// # Errors
332    ///
333    /// - HTTP connection errors
334    /// - Server unavailable or unresponsive
335    /// - Authentication failures
336    async fn ping(&self) -> Result<()> {
337        self.send_request("ping", Value::Null).await?;
338        Ok(())
339    }
340
341    /// Closes the HTTP transport connection
342    ///
343    /// HTTP connections are stateless and managed by the underlying HTTP client,
344    /// so this method is a no-op but provided for interface compatibility.
345    ///
346    /// # Returns
347    ///
348    /// Always returns Ok(()) as HTTP connections don't require explicit cleanup
349    async fn close(&self) -> Result<()> {
350        // HTTP connections don't need explicit closing
351        Ok(())
352    }
353
354    /// Sends the server a initialization notification to let it know we are done initializing
355    async fn send_initialization_notification(&self) -> Result<()> {
356        let request_body = serde_json::json!({
357            "jsonrpc": "2.0",
358            "method": "notifications/initialized",
359        });
360
361        let mut request_builder = self
362            .client
363            .post(&self.base_url)
364            .json(&request_body)
365            .header("Accept", "application/json, text/event-stream");
366
367        // Add custom headers
368        for (key, value) in &self.headers {
369            request_builder = request_builder.header(key, value);
370        }
371
372        request_builder.send().await?;
373        Ok(())
374    }
375}
376
377/// Process-based MCP transport using stdin/stdout communication
378///
379/// Provides communication with local MCP servers running as separate processes
380/// using JSON-RPC 2.0 over stdin/stdout pipes. This transport is ideal for
381/// local tools, development servers, and sandboxed environments where you need
382/// process isolation and direct control over the MCP server lifecycle.
383///
384/// # Features
385///
386/// - **Process Isolation**: Each MCP server runs in its own process for security
387/// - **No Network Overhead**: Direct pipe communication for maximum performance
388/// - **Environment Control**: Full control over working directory and environment variables
389/// - **Resource Management**: Automatic process cleanup and lifecycle management
390/// - **Synchronous Communication**: Request/response correlation over stdin/stdout
391/// - **Error Handling**: Comprehensive process and communication error handling
392///
393/// # Use Cases
394///
395/// - **Local Development**: Running MCP servers during development and testing
396/// - **Filesystem Tools**: Local file operations and system utilities
397/// - **Sandboxed Execution**: Isolated execution environments for security
398/// - **Custom Tools**: Private or proprietary MCP servers
399/// - **CI/CD Integration**: Running MCP servers in automated environments
400///
401/// # Example Usage
402///
403/// ```rust,no_run
404/// use mistralrs_mcp::transport::{ProcessTransport, McpTransport};
405/// use std::collections::HashMap;
406///
407/// #[tokio::main]
408/// async fn main() -> anyhow::Result<()> {
409///     // Basic process transport
410///     let transport = ProcessTransport::new(
411///         "mcp-server-filesystem".to_string(),
412///         vec!["--root".to_string(), "/tmp".to_string()],
413///         None,
414///         None
415///     ).await?;
416///
417///     // With custom working directory and environment
418///     let mut env = HashMap::new();
419///     env.insert("MCP_LOG_LEVEL".to_string(), "debug".to_string());
420///     env.insert("MCP_TIMEOUT".to_string(), "30".to_string());
421///
422///     let transport = ProcessTransport::new(
423///         "/usr/local/bin/my-mcp-server".to_string(),
424///         vec!["--config".to_string(), "production.json".to_string()],
425///         Some("/opt/mcp-server".to_string()), // Working directory
426///         Some(env) // Environment variables
427///     ).await?;
428///
429///     // Use the transport for MCP communication
430///     let result = transport.send_request("tools/list", serde_json::Value::Null).await?;
431///     println!("Available tools: {}", result);
432///
433///     Ok(())
434/// }
435/// ```
436///
437/// # Process Management
438///
439/// The transport automatically manages the child process lifecycle:
440/// - Spawns the process with configured arguments and environment
441/// - Sets up stdin/stdout pipes for JSON-RPC communication
442/// - Monitors process health and handles crashes
443/// - Cleans up resources when the transport is dropped or closed
444///
445/// # Communication Protocol
446///
447/// Uses JSON-RPC 2.0 over stdin/stdout with line-delimited messages:
448/// - Each request is a single line of JSON sent to stdin
449/// - Each response is a single line of JSON read from stdout
450/// - Stderr is captured for debugging and error reporting
451pub struct ProcessTransport {
452    child: std::sync::Arc<tokio::sync::Mutex<tokio::process::Child>>,
453    request_id: std::sync::Arc<std::sync::atomic::AtomicU64>,
454    stdin: std::sync::Arc<tokio::sync::Mutex<tokio::process::ChildStdin>>,
455    stdout_reader:
456        std::sync::Arc<tokio::sync::Mutex<tokio::io::BufReader<tokio::process::ChildStdout>>>,
457}
458
459impl ProcessTransport {
460    /// Creates a new process transport by spawning an MCP server process
461    ///
462    /// This constructor spawns a new process with the specified command, arguments,
463    /// and environment, then sets up stdin/stdout pipes for JSON-RPC communication.
464    /// The process is ready to receive MCP requests immediately after creation.
465    ///
466    /// # Arguments
467    ///
468    /// * `command` - The command to execute (e.g., "mcp-server-filesystem", "/usr/bin/python")
469    /// * `args` - Command-line arguments to pass to the process
470    /// * `work_dir` - Optional working directory for the process (defaults to current directory)
471    /// * `env` - Optional environment variables to set for the process
472    ///
473    /// # Returns
474    ///
475    /// A configured ProcessTransport with the spawned process ready for communication
476    ///
477    /// # Errors
478    ///
479    /// - Command not found or not executable
480    /// - Permission denied errors
481    /// - Process spawn failures
482    /// - Pipe setup errors (stdin/stdout/stderr)
483    /// - Working directory access errors
484    ///
485    /// # Example
486    ///
487    /// ```rust,no_run
488    /// use mistralrs_mcp::transport::ProcessTransport;
489    /// use std::collections::HashMap;
490    ///
491    /// #[tokio::main]
492    /// async fn main() -> anyhow::Result<()> {
493    ///     // Simple filesystem server
494    ///     let transport = ProcessTransport::new(
495    ///         "mcp-server-filesystem".to_string(),
496    ///         vec!["--root".to_string(), "/home/user/documents".to_string()],
497    ///         None,
498    ///         None
499    ///     ).await?;
500    ///
501    ///     // Python-based MCP server with custom environment
502    ///     let mut env = HashMap::new();
503    ///     env.insert("PYTHONPATH".to_string(), "/opt/mcp-servers".to_string());
504    ///     env.insert("MCP_DEBUG".to_string(), "1".to_string());
505    ///
506    ///     let transport = ProcessTransport::new(
507    ///         "python".to_string(),
508    ///         vec!["-m".to_string(), "my_mcp_server".to_string(), "--port".to_string(), "8080".to_string()],
509    ///         Some("/opt/mcp-servers".to_string()),
510    ///         Some(env)
511    ///     ).await?;
512    ///
513    ///     Ok(())
514    /// }
515    /// ```
516    pub async fn new(
517        command: String,
518        args: Vec<String>,
519        work_dir: Option<String>,
520        env: Option<HashMap<String, String>>,
521    ) -> Result<Self> {
522        use tokio::io::BufReader;
523        use tokio::process::Command;
524
525        let mut cmd = Command::new(command);
526        cmd.args(args)
527            .stdin(std::process::Stdio::piped())
528            .stdout(std::process::Stdio::piped())
529            .stderr(std::process::Stdio::piped());
530
531        if let Some(dir) = work_dir {
532            cmd.current_dir(dir);
533        }
534
535        if let Some(env_vars) = env {
536            for (key, value) in env_vars {
537                cmd.env(key, value);
538            }
539        }
540
541        let mut child = cmd.spawn()?;
542        let stdin = child
543            .stdin
544            .take()
545            .ok_or_else(|| anyhow::anyhow!("Failed to get stdin handle"))?;
546        let stdout = child
547            .stdout
548            .take()
549            .ok_or_else(|| anyhow::anyhow!("Failed to get stdout handle"))?;
550        let stdout_reader = BufReader::new(stdout);
551
552        Ok(Self {
553            child: std::sync::Arc::new(tokio::sync::Mutex::new(child)),
554            request_id: std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1)),
555            stdin: std::sync::Arc::new(tokio::sync::Mutex::new(stdin)),
556            stdout_reader: std::sync::Arc::new(tokio::sync::Mutex::new(stdout_reader)),
557        })
558    }
559}
560
561#[async_trait::async_trait]
562impl McpTransport for ProcessTransport {
563    /// Sends an MCP request to the child process and returns the response
564    ///
565    /// This method implements JSON-RPC 2.0 over stdin/stdout pipes. It sends
566    /// a line-delimited JSON request to the process stdin and reads the
567    /// corresponding response from stdout. Communication is synchronous with
568    /// proper request/response correlation.
569    ///
570    /// # Arguments
571    ///
572    /// * `method` - The MCP method name (e.g., "tools/list", "tools/call", "resources/read")
573    /// * `params` - JSON parameters for the method call
574    ///
575    /// # Returns
576    ///
577    /// The result portion of the JSON-RPC response
578    ///
579    /// # Errors
580    ///
581    /// - Process communication errors (broken pipes)
582    /// - Process crashes or unexpected termination
583    /// - JSON serialization/deserialization errors
584    /// - MCP server errors (returned in JSON-RPC error field)
585    /// - I/O errors on stdin/stdout
586    ///
587    /// # Example
588    ///
589    /// ```rust,no_run
590    /// use mistralrs_mcp::transport::{ProcessTransport, McpTransport};
591    /// use serde_json::json;
592    ///
593    /// #[tokio::main]
594    /// async fn main() -> anyhow::Result<()> {
595    ///     let transport = ProcessTransport::new(
596    ///         "mcp-server-filesystem".to_string(),
597    ///         vec!["--root".to_string(), "/tmp".to_string()],
598    ///         None,
599    ///         None
600    ///     ).await?;
601    ///     
602    ///     // List available tools
603    ///     let tools = transport.send_request("tools/list", serde_json::Value::Null).await?;
604    ///     
605    ///     // Call a specific tool
606    ///     let params = json!({
607    ///         "name": "read_file",
608    ///         "arguments": {"path": "/tmp/example.txt"}
609    ///     });
610    ///     let result = transport.send_request("tools/call", params).await?;
611    ///     
612    ///     Ok(())
613    /// }
614    /// ```
615    async fn send_request(&self, method: &str, params: Value) -> Result<Value> {
616        use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
617
618        // Ensure params is an object, not null
619        let params = if params.is_null() {
620            serde_json::json!({})
621        } else {
622            params
623        };
624        // Generate unique request ID
625        let id = self
626            .request_id
627            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
628        let request_body = serde_json::json!({
629            "jsonrpc": "2.0",
630            "id": id,
631            "method": method,
632            "params": params
633        });
634
635        // Send request via stdin
636        let request_line = serde_json::to_string(&request_body)? + "\n";
637        let mut stdin = self.stdin.lock().await;
638        stdin.write_all(request_line.as_bytes()).await?;
639        stdin.flush().await?;
640        drop(stdin);
641
642        // Read response from stdout
643        let mut stdout_reader = self.stdout_reader.lock().await;
644        let mut response_line = String::new();
645        stdout_reader.read_line(&mut response_line).await?;
646        drop(stdout_reader);
647
648        let response_body: Value = serde_json::from_str(response_line.trim())?;
649
650        // Check for JSON-RPC errors
651        if let Some(error) = response_body.get("error") {
652            return Err(anyhow::anyhow!("MCP server error: {}", error));
653        }
654
655        response_body
656            .get("result")
657            .cloned()
658            .ok_or_else(|| anyhow::anyhow!("No result in MCP response"))
659    }
660
661    /// Tests the process connection by sending a ping request
662    ///
663    /// Sends a "ping" method call to verify that the MCP server process is
664    /// responsive and the stdin/stdout communication is working properly.
665    ///
666    /// # Returns
667    ///
668    /// Ok(()) if the ping was successful
669    ///
670    /// # Errors
671    ///
672    /// - Process communication errors
673    /// - Process crashed or terminated
674    /// - Broken stdin/stdout pipes
675    async fn ping(&self) -> Result<()> {
676        self.send_request("ping", Value::Null).await?;
677        Ok(())
678    }
679
680    /// Terminates the child process and cleans up resources
681    ///
682    /// This method forcefully terminates the MCP server process and closes
683    /// all associated pipes. Any pending requests will fail after this call.
684    /// The transport cannot be used after closing.
685    ///
686    /// # Returns
687    ///
688    /// Ok(()) if the process was terminated successfully
689    ///
690    /// # Errors
691    ///
692    /// - Process termination errors
693    /// - Resource cleanup failures
694    ///
695    /// # Note
696    ///
697    /// This method sends SIGKILL to the process, which may not allow for
698    /// graceful cleanup. Consider implementing graceful shutdown through
699    /// MCP protocol methods before calling this method.
700    async fn close(&self) -> Result<()> {
701        let mut child = self.child.lock().await;
702        child.kill().await?;
703        Ok(())
704    }
705
706    /// Sends the server a initialization notification to let it know we are done initializing
707    async fn send_initialization_notification(&self) -> Result<()> {
708        use tokio::io::AsyncWriteExt;
709        let mut stdin = self.stdin.lock().await;
710        stdin
711            .write_all(
712                format!(
713                    "{}\n",
714                    serde_json::json!({"jsonrpc": "2.0", "method": "notifications/initialized"})
715                )
716                .as_bytes(),
717            )
718            .await?;
719        stdin.flush().await?;
720        drop(stdin);
721        Ok(())
722    }
723}
724
725/// WebSocket-based MCP transport
726///
727/// Provides real-time bidirectional communication with MCP servers over WebSocket connections.
728/// This transport supports secure connections (WSS), Bearer token authentication, and concurrent
729/// request/response handling with proper JSON-RPC 2.0 message correlation.
730///
731/// # Features
732///
733/// - **Async WebSocket Communication**: Built on tokio-tungstenite for high-performance async I/O
734/// - **Request/Response Matching**: Automatic correlation of responses using atomic request IDs
735/// - **Bearer Token Support**: Authentication via Authorization header during handshake
736/// - **Connection Management**: Proper ping/pong and connection lifecycle handling
737/// - **Concurrent Operations**: Split stream architecture allows simultaneous read/write operations
738///
739/// # Architecture
740///
741/// The transport uses a split-stream design where the WebSocket connection is divided into
742/// separate read and write halves, each protected by async mutexes. This allows concurrent
743/// operations while maintaining thread safety. Request IDs are generated atomically to ensure
744/// unique identification of requests and proper response correlation.
745///
746/// # Example Usage
747///
748/// ```rust,no_run
749/// use mistralrs_mcp::transport::{WebSocketTransport, McpTransport};
750/// use std::collections::HashMap;
751///
752/// #[tokio::main]
753/// async fn main() -> anyhow::Result<()> {
754///     // Create headers with Bearer token
755///     let mut headers = HashMap::new();
756///     headers.insert("Authorization".to_string(), "Bearer your-token".to_string());
757///
758///     // Connect to WebSocket MCP server
759///     let transport = WebSocketTransport::new(
760///         "wss://api.example.com/mcp".to_string(),
761///         Some(30), // 30 second timeout
762///         Some(headers)
763///     ).await?;
764///
765///     // Use the transport for MCP communication
766///     let result = transport.send_request("tools/list", serde_json::Value::Null).await?;
767///     println!("Available tools: {}", result);
768///
769///     Ok(())
770/// }
771/// ```
772///
773/// # Protocol Compliance
774///
775/// This transport implements the Model Context Protocol (MCP) specification over WebSocket,
776/// adhering to JSON-RPC 2.0 message format with proper error handling and response correlation.
777pub struct WebSocketTransport {
778    write: std::sync::Arc<
779        tokio::sync::Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
780    >,
781    read:
782        std::sync::Arc<tokio::sync::Mutex<SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>>,
783    request_id: std::sync::Arc<std::sync::atomic::AtomicU64>,
784}
785
786impl WebSocketTransport {
787    /// Creates a new WebSocket transport connection to an MCP server
788    ///
789    /// # Arguments
790    ///
791    /// * `url` - WebSocket URL (ws:// or wss://)
792    /// * `_timeout_secs` - Connection timeout (currently unused, reserved for future use)
793    /// * `headers` - Optional HTTP headers for WebSocket handshake (e.g., Bearer tokens)
794    ///
795    /// # Returns
796    ///
797    /// A configured WebSocketTransport ready for MCP communication
798    ///
799    /// # Errors
800    ///
801    /// - Invalid URL format
802    /// - WebSocket connection failure
803    /// - Header parsing errors
804    /// - Network connectivity issues
805    ///
806    /// # Example
807    ///
808    /// ```rust,no_run
809    /// use mistralrs_mcp::transport::WebSocketTransport;
810    /// use std::collections::HashMap;
811    ///
812    /// #[tokio::main]
813    /// async fn main() -> anyhow::Result<()> {
814    ///     let mut headers = HashMap::new();
815    ///     headers.insert("Authorization".to_string(), "Bearer token123".to_string());
816    ///     
817    ///     let transport = WebSocketTransport::new(
818    ///         "wss://mcp.example.com/api".to_string(),
819    ///         Some(30),
820    ///         Some(headers)
821    ///     ).await?;
822    ///     
823    ///     Ok(())
824    /// }
825    /// ```
826    pub async fn new(
827        url: String,
828        _timeout_secs: Option<u64>,
829        headers: Option<HashMap<String, String>>,
830    ) -> Result<Self> {
831        // Create request with headers
832        let uri: Uri = url
833            .parse()
834            .map_err(|e| anyhow::anyhow!("Invalid WebSocket URL: {}", e))?;
835        let mut request = Request::builder()
836            .uri(uri)
837            .body(())
838            .map_err(|e| anyhow::anyhow!("Failed to create WebSocket request: {}", e))?;
839
840        // Add headers if provided
841        if let Some(headers) = headers {
842            let req_headers = request.headers_mut();
843            for (key, value) in headers {
844                let header_name = key
845                    .parse::<http::header::HeaderName>()
846                    .map_err(|e| anyhow::anyhow!("Invalid header key: {}", e))?;
847                let header_value = value
848                    .parse::<http::header::HeaderValue>()
849                    .map_err(|e| anyhow::anyhow!("Invalid header value: {}", e))?;
850                req_headers.insert(header_name, header_value);
851            }
852        }
853
854        // Connect to WebSocket
855        let (ws_stream, _) = connect_async(request)
856            .await
857            .map_err(|e| anyhow::anyhow!("WebSocket connection failed: {}", e))?;
858
859        // Split the stream
860        let (write, read) = ws_stream.split();
861
862        Ok(Self {
863            write: std::sync::Arc::new(tokio::sync::Mutex::new(write)),
864            read: std::sync::Arc::new(tokio::sync::Mutex::new(read)),
865            request_id: std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1)),
866        })
867    }
868}
869
870#[async_trait::async_trait]
871impl McpTransport for WebSocketTransport {
872    /// Sends an MCP request over WebSocket and waits for the corresponding response
873    ///
874    /// This method implements the JSON-RPC 2.0 protocol over WebSocket, handling:
875    /// - Unique request ID generation for response correlation
876    /// - Concurrent request processing with proper message ordering
877    /// - Error handling for both transport and protocol errors
878    /// - Message filtering to match responses with requests
879    ///
880    /// # Arguments
881    ///
882    /// * `method` - The MCP method name (e.g., "tools/list", "tools/call")
883    /// * `params` - JSON parameters for the method call
884    ///
885    /// # Returns
886    ///
887    /// The result portion of the JSON-RPC response
888    ///
889    /// # Errors
890    ///
891    /// - WebSocket connection errors
892    /// - JSON serialization/deserialization errors  
893    /// - MCP server errors (returned in JSON-RPC error field)
894    /// - Timeout or connection closure
895    ///
896    /// # Example
897    ///
898    /// ```rust,no_run
899    /// use mistralrs_mcp::transport::{WebSocketTransport, McpTransport};
900    /// use serde_json::json;
901    ///
902    /// #[tokio::main]
903    /// async fn main() -> anyhow::Result<()> {
904    ///     let transport = WebSocketTransport::new(
905    ///         "wss://api.example.com/mcp".to_string(),
906    ///         None,
907    ///         None
908    ///     ).await?;
909    ///     
910    ///     // List available tools
911    ///     let tools = transport.send_request("tools/list", serde_json::Value::Null).await?;
912    ///     
913    ///     // Call a specific tool
914    ///     let params = json!({"query": "example search"});
915    ///     let result = transport.send_request("tools/call", params).await?;
916    ///     
917    ///     Ok(())
918    /// }
919    /// ```
920    async fn send_request(&self, method: &str, params: Value) -> Result<Value> {
921        // Ensure params is an object, not null
922        let params = if params.is_null() {
923            serde_json::json!({})
924        } else {
925            params
926        };
927
928        // Generate unique request ID
929        let id = self
930            .request_id
931            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
932
933        let request_body = serde_json::json!({
934            "jsonrpc": "2.0",
935            "id": id,
936            "method": method,
937            "params": params
938        });
939
940        // Send request
941        let message = Message::Text(serde_json::to_string(&request_body)?);
942
943        {
944            let mut write = self.write.lock().await;
945            write
946                .send(message)
947                .await
948                .map_err(|e| anyhow::anyhow!("Failed to send WebSocket message: {}", e))?;
949        }
950
951        // Read response
952        loop {
953            let mut read = self.read.lock().await;
954            let msg = read
955                .next()
956                .await
957                .ok_or_else(|| anyhow::anyhow!("WebSocket connection closed"))?
958                .map_err(|e| anyhow::anyhow!("WebSocket read error: {}", e))?;
959            drop(read);
960
961            match msg {
962                Message::Text(text) => {
963                    let response_body: Value = serde_json::from_str(&text)?;
964
965                    // Check if this is the response to our request
966                    if let Some(response_id) = response_body.get("id").and_then(|v| v.as_u64()) {
967                        if response_id == id {
968                            // Check for JSON-RPC errors
969                            if let Some(error) = response_body.get("error") {
970                                return Err(anyhow::anyhow!("MCP server error: {}", error));
971                            }
972
973                            return response_body
974                                .get("result")
975                                .cloned()
976                                .ok_or_else(|| anyhow::anyhow!("No result in MCP response"));
977                        }
978                    }
979                    // If it's not our response, continue reading
980                }
981                Message::Binary(_) => {
982                    // Handle binary messages if needed, for now skip
983                    continue;
984                }
985                Message::Close(_) => {
986                    return Err(anyhow::anyhow!("WebSocket connection closed by server"));
987                }
988                Message::Ping(_) | Message::Pong(_) => {
989                    // Handle ping/pong frames, continue reading
990                    continue;
991                }
992                Message::Frame(_) => {
993                    // Raw frames, continue reading
994                    continue;
995                }
996            }
997        }
998    }
999
1000    /// Sends a WebSocket ping frame to test connection health
1001    ///
1002    /// This method sends a ping frame to the server and expects a pong response,
1003    /// which helps verify that the WebSocket connection is still active and responsive.
1004    ///
1005    /// # Returns
1006    ///
1007    /// Ok(()) if the ping was sent successfully
1008    ///
1009    /// # Errors
1010    ///
1011    /// - WebSocket send errors
1012    /// - Connection closure
1013    async fn ping(&self) -> Result<()> {
1014        let ping_message = Message::Ping(vec![]);
1015        let mut write = self.write.lock().await;
1016        write
1017            .send(ping_message)
1018            .await
1019            .map_err(|e| anyhow::anyhow!("Failed to send ping: {}", e))?;
1020        Ok(())
1021    }
1022
1023    /// Gracefully closes the WebSocket connection
1024    ///
1025    /// Sends a close frame to the server to properly terminate the connection
1026    /// according to the WebSocket protocol. The server should respond with its
1027    /// own close frame to complete the closing handshake.
1028    ///
1029    /// # Returns
1030    ///
1031    /// Ok(()) if the close frame was sent successfully
1032    ///
1033    /// # Errors
1034    ///
1035    /// - WebSocket send errors
1036    /// - Connection already closed
1037    async fn close(&self) -> Result<()> {
1038        let close_message = Message::Close(None);
1039        let mut write = self.write.lock().await;
1040        write
1041            .send(close_message)
1042            .await
1043            .map_err(|e| anyhow::anyhow!("Failed to send close message: {}", e))?;
1044        Ok(())
1045    }
1046
1047    /// Sends the server a initialization notification to let it know we are done initializing
1048    async fn send_initialization_notification(&self) -> Result<()> {
1049        let request_body = serde_json::json!({
1050            "jsonrpc": "2.0",
1051            "method": "notifications/initialized",
1052        });
1053
1054        // Send request
1055        let message = Message::Text(serde_json::to_string(&request_body)?);
1056
1057        {
1058            let mut write = self.write.lock().await;
1059            write
1060                .send(message)
1061                .await
1062                .map_err(|e| anyhow::anyhow!("Failed to send WebSocket message: {}", e))?;
1063        }
1064        Ok(())
1065    }
1066}