mistralrs_mcp/
transport.rs

1use anyhow::Result;
2use futures_util::stream::{SplitSink, SplitStream};
3use futures_util::{SinkExt, StreamExt};
4use http::{Request, Uri};
5use serde_json::Value;
6use std::collections::HashMap;
7use std::time::Duration;
8use tokio::net::TcpStream;
9use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
10
/// Transport layer for MCP communication
///
/// Abstracts how JSON-RPC messages travel between this client and an MCP server.
/// This file provides HTTP, child-process (stdin/stdout) and WebSocket implementations.
/// The `Send + Sync` supertraits allow a transport to be shared between tasks/threads.
#[async_trait::async_trait]
pub trait McpTransport: Send + Sync {
    /// Send a JSON-RPC request and receive a response
    ///
    /// `method` is the JSON-RPC method name (e.g. `"tools/list"`) and `params` its
    /// parameters. The implementations in this file map a `null` `params` to `{}` and
    /// return the `result` member of the server's reply, turning a JSON-RPC `error`
    /// member into an `Err`.
    async fn send_request(&self, method: &str, params: Value) -> Result<Value>;

    /// Check if the transport connection is healthy
    ///
    /// The implementations in this file do this by sending an MCP `"ping"` request.
    async fn ping(&self) -> Result<()>;

    /// Close the transport connection
    async fn close(&self) -> Result<()>;
}
23
24/// HTTP-based MCP transport
25///
26/// Provides communication with MCP servers over HTTP using JSON-RPC 2.0 protocol.
27/// This transport is ideal for RESTful MCP services, public APIs, and servers
28/// behind load balancers. Supports both regular JSON responses and Server-Sent Events (SSE).
29///
30/// # Features
31///
32/// - **HTTP/HTTPS Support**: Secure communication with TLS encryption
33/// - **Server-Sent Events**: Handles streaming responses via SSE format
34/// - **Bearer Token Authentication**: Automatic Authorization header injection
35/// - **Custom Headers**: Support for additional headers (API keys, versioning, etc.)
36/// - **Configurable Timeouts**: Request-level timeout control
37/// - **Error Handling**: Comprehensive JSON-RPC and HTTP error handling
38///
39/// # Use Cases
40///
41/// - **Public MCP APIs**: Connect to hosted MCP services
42/// - **RESTful Services**: Integration with REST-based tool providers
43/// - **Load-Balanced Servers**: Works well behind HTTP load balancers
44/// - **Development/Testing**: Easy debugging with standard HTTP tools
45///
46/// # Example Usage
47///
48/// ```rust,no_run
49/// use mistralrs_mcp::transport::{HttpTransport, McpTransport};
50/// use std::collections::HashMap;
51///
52/// #[tokio::main]
53/// async fn main() -> anyhow::Result<()> {
54///     // Create headers with Bearer token and API version
55///     let mut headers = HashMap::new();
56///     headers.insert("Authorization".to_string(), "Bearer your-api-token".to_string());
57///     headers.insert("X-API-Version".to_string(), "v1".to_string());
58///
59///     // Connect to HTTP MCP server
60///     let transport = HttpTransport::new(
61///         "https://api.example.com/mcp".to_string(),
62///         Some(30), // 30 second timeout
63///         Some(headers)
64///     )?;
65///
66///     // Use the transport for MCP communication
67///     let result = transport.send_request("tools/list", serde_json::Value::Null).await?;
68///     println!("Available tools: {}", result);
69///
70///     Ok(())
71/// }
72/// ```
73///
74/// # Protocol Support
75///
76/// This transport implements JSON-RPC 2.0 over HTTP with support for:
77/// - Standard JSON responses
78/// - Server-Sent Events (SSE) for streaming data
79/// - Bearer token authentication
80/// - Custom HTTP headers
81pub struct HttpTransport {
82    client: reqwest::Client,
83    base_url: String,
84    headers: HashMap<String, String>,
85}
86
87impl HttpTransport {
88    /// Creates a new HTTP transport for MCP communication
89    ///
90    /// # Arguments
91    ///
92    /// * `base_url` - Base URL of the MCP server (http:// or https://)
93    /// * `timeout_secs` - Optional timeout for HTTP requests in seconds (defaults to 30s)
94    /// * `headers` - Optional custom headers to include in all requests
95    ///
96    /// # Returns
97    ///
98    /// A configured HttpTransport ready for MCP communication
99    ///
100    /// # Errors
101    ///
102    /// - Invalid URL format
103    /// - HTTP client configuration errors
104    /// - TLS/SSL setup failures
105    ///
106    /// # Example
107    ///
108    /// ```rust,no_run
109    /// use mistralrs_mcp::transport::HttpTransport;
110    /// use std::collections::HashMap;
111    ///
112    /// // Basic HTTP transport
113    /// let transport = HttpTransport::new(
114    ///     "https://api.example.com/mcp".to_string(),
115    ///     Some(60), // 1 minute timeout
116    ///     None
117    /// )?;
118    ///
119    /// // With custom headers and authentication
120    /// let mut headers = HashMap::new();
121    /// headers.insert("Authorization".to_string(), "Bearer token123".to_string());
122    /// headers.insert("X-Client-Version".to_string(), "1.0.0".to_string());
123    ///
124    /// let transport = HttpTransport::new(
125    ///     "https://secure-api.example.com/mcp".to_string(),
126    ///     Some(30),
127    ///     Some(headers)
128    /// )?;
129    /// # Ok::<(), anyhow::Error>(())
130    /// ```
131    pub fn new(
132        base_url: String,
133        timeout_secs: Option<u64>,
134        headers: Option<HashMap<String, String>>,
135    ) -> Result<Self> {
136        let timeout = timeout_secs
137            .map(Duration::from_secs)
138            .unwrap_or(Duration::from_secs(30));
139        let client = reqwest::Client::builder().timeout(timeout).build()?;
140
141        Ok(Self {
142            client,
143            base_url,
144            headers: headers.unwrap_or_default(),
145        })
146    }
147
148    /// Parse Server-Sent Events response to extract JSON-RPC message
149    ///
150    /// Handles SSE format used by some MCP servers for streaming responses.
151    /// SSE format: `data: <json>\n\n` or `event: <type>\ndata: <json>\n\n`
152    ///
153    /// # Arguments
154    ///
155    /// * `sse_text` - Raw SSE response text from the server
156    ///
157    /// # Returns
158    ///
159    /// Parsed JSON value from the SSE data field
160    ///
161    /// # Errors
162    ///
163    /// - No valid JSON data found in SSE response
164    /// - Malformed SSE format
165    /// - JSON parsing errors
166    fn parse_sse_response(sse_text: &str) -> Result<Value> {
167        // SSE format: data: <json>\n\n or event: <type>\ndata: <json>\n\n
168        let mut json_data = None;
169
170        for line in sse_text.lines() {
171            let line = line.trim();
172
173            // Skip empty lines and comments
174            if line.is_empty() || line.starts_with(':') {
175                continue;
176            }
177
178            // Parse SSE field
179            if let Some((field, value)) = line.split_once(':') {
180                let field = field.trim();
181                let value = value.trim();
182
183                match field {
184                    "data" => {
185                        // Try to parse the JSON data
186                        if let Ok(parsed) = serde_json::from_str::<Value>(value) {
187                            json_data = Some(parsed);
188                            break;
189                        }
190                    }
191                    "event" => {
192                        // Handle different event types if needed
193                        continue;
194                    }
195                    _ => {
196                        // Ignore other SSE fields like id, retry, etc.
197                        continue;
198                    }
199                }
200            }
201        }
202
203        json_data.ok_or_else(|| anyhow::anyhow!("No valid JSON data found in SSE response"))
204    }
205}
206
207#[async_trait::async_trait]
208impl McpTransport for HttpTransport {
209    /// Sends an MCP request over HTTP and returns the response
210    ///
211    /// This method implements JSON-RPC 2.0 over HTTP with support for both
212    /// standard JSON responses and Server-Sent Events (SSE). It handles
213    /// authentication, custom headers, and comprehensive error reporting.
214    ///
215    /// # Arguments
216    ///
217    /// * `method` - The MCP method name (e.g., "tools/list", "tools/call", "resources/read")
218    /// * `params` - JSON parameters for the method call
219    ///
220    /// # Returns
221    ///
222    /// The result portion of the JSON-RPC response
223    ///
224    /// # Errors
225    ///
226    /// - HTTP connection errors (network issues, DNS resolution)
227    /// - HTTP status errors (4xx, 5xx responses)
228    /// - JSON serialization/deserialization errors
229    /// - MCP server errors (returned in JSON-RPC error field)
230    /// - SSE parsing errors for streaming responses
231    ///
232    /// # Example
233    ///
234    /// ```rust,no_run
235    /// use mistralrs_mcp::transport::{HttpTransport, McpTransport};
236    /// use serde_json::json;
237    ///
238    /// #[tokio::main]
239    /// async fn main() -> anyhow::Result<()> {
240    ///     let transport = HttpTransport::new(
241    ///         "https://api.example.com/mcp".to_string(),
242    ///         None,
243    ///         None
244    ///     )?;
245    ///     
246    ///     // List available tools
247    ///     let tools = transport.send_request("tools/list", serde_json::Value::Null).await?;
248    ///     
249    ///     // Call a specific tool
250    ///     let params = json!({
251    ///         "name": "search",
252    ///         "arguments": {"query": "example search"}
253    ///     });
254    ///     let result = transport.send_request("tools/call", params).await?;
255    ///     
256    ///     Ok(())
257    /// }
258    /// ```
259    async fn send_request(&self, method: &str, params: Value) -> Result<Value> {
260        // Ensure params is an object, not null
261        let params = if params.is_null() {
262            serde_json::json!({})
263        } else {
264            params
265        };
266
267        let request_body = serde_json::json!({
268            "jsonrpc": "2.0",
269            "id": 1,
270            "method": method,
271            "params": params
272        });
273
274        let mut request_builder = self
275            .client
276            .post(&self.base_url)
277            .json(&request_body)
278            .header("Accept", "application/json, text/event-stream");
279
280        // Add custom headers
281        for (key, value) in &self.headers {
282            request_builder = request_builder.header(key, value);
283        }
284
285        let response = request_builder.send().await?;
286
287        // Check content type and handle accordingly
288        let content_type = response
289            .headers()
290            .get("content-type")
291            .and_then(|v| v.to_str().ok())
292            .unwrap_or("");
293
294        let response_body: Value = if content_type.contains("text/event-stream") {
295            // Handle Server-Sent Events
296            let response_text = response.text().await?;
297            Self::parse_sse_response(&response_text)?
298        } else {
299            // Handle regular JSON response
300            response.json().await?
301        };
302
303        // Check for JSON-RPC errors
304        if let Some(error) = response_body.get("error") {
305            return Err(anyhow::anyhow!("MCP server error: {}", error));
306        }
307
308        response_body
309            .get("result")
310            .cloned()
311            .ok_or_else(|| anyhow::anyhow!("No result in MCP response"))
312    }
313
314    /// Tests the HTTP connection by sending a ping request
315    ///
316    /// Sends a "ping" method call to verify that the MCP server is responsive
317    /// and the HTTP connection is working properly.
318    ///
319    /// # Returns
320    ///
321    /// Ok(()) if the ping was successful
322    ///
323    /// # Errors
324    ///
325    /// - HTTP connection errors
326    /// - Server unavailable or unresponsive
327    /// - Authentication failures
328    async fn ping(&self) -> Result<()> {
329        self.send_request("ping", Value::Null).await?;
330        Ok(())
331    }
332
333    /// Closes the HTTP transport connection
334    ///
335    /// HTTP connections are stateless and managed by the underlying HTTP client,
336    /// so this method is a no-op but provided for interface compatibility.
337    ///
338    /// # Returns
339    ///
340    /// Always returns Ok(()) as HTTP connections don't require explicit cleanup
341    async fn close(&self) -> Result<()> {
342        // HTTP connections don't need explicit closing
343        Ok(())
344    }
345}
346
347/// Process-based MCP transport using stdin/stdout communication
348///
349/// Provides communication with local MCP servers running as separate processes
350/// using JSON-RPC 2.0 over stdin/stdout pipes. This transport is ideal for
351/// local tools, development servers, and sandboxed environments where you need
352/// process isolation and direct control over the MCP server lifecycle.
353///
354/// # Features
355///
356/// - **Process Isolation**: Each MCP server runs in its own process for security
357/// - **No Network Overhead**: Direct pipe communication for maximum performance
358/// - **Environment Control**: Full control over working directory and environment variables
359/// - **Resource Management**: Automatic process cleanup and lifecycle management
360/// - **Synchronous Communication**: Request/response correlation over stdin/stdout
361/// - **Error Handling**: Comprehensive process and communication error handling
362///
363/// # Use Cases
364///
365/// - **Local Development**: Running MCP servers during development and testing
366/// - **Filesystem Tools**: Local file operations and system utilities
367/// - **Sandboxed Execution**: Isolated execution environments for security
368/// - **Custom Tools**: Private or proprietary MCP servers
369/// - **CI/CD Integration**: Running MCP servers in automated environments
370///
371/// # Example Usage
372///
373/// ```rust,no_run
374/// use mistralrs_mcp::transport::{ProcessTransport, McpTransport};
375/// use std::collections::HashMap;
376///
377/// #[tokio::main]
378/// async fn main() -> anyhow::Result<()> {
379///     // Basic process transport
380///     let transport = ProcessTransport::new(
381///         "mcp-server-filesystem".to_string(),
382///         vec!["--root".to_string(), "/tmp".to_string()],
383///         None,
384///         None
385///     ).await?;
386///
387///     // With custom working directory and environment
388///     let mut env = HashMap::new();
389///     env.insert("MCP_LOG_LEVEL".to_string(), "debug".to_string());
390///     env.insert("MCP_TIMEOUT".to_string(), "30".to_string());
391///
392///     let transport = ProcessTransport::new(
393///         "/usr/local/bin/my-mcp-server".to_string(),
394///         vec!["--config".to_string(), "production.json".to_string()],
395///         Some("/opt/mcp-server".to_string()), // Working directory
396///         Some(env) // Environment variables
397///     ).await?;
398///
399///     // Use the transport for MCP communication
400///     let result = transport.send_request("tools/list", serde_json::Value::Null).await?;
401///     println!("Available tools: {}", result);
402///
403///     Ok(())
404/// }
405/// ```
406///
407/// # Process Management
408///
409/// The transport automatically manages the child process lifecycle:
410/// - Spawns the process with configured arguments and environment
411/// - Sets up stdin/stdout pipes for JSON-RPC communication
412/// - Monitors process health and handles crashes
413/// - Cleans up resources when the transport is dropped or closed
414///
415/// # Communication Protocol
416///
417/// Uses JSON-RPC 2.0 over stdin/stdout with line-delimited messages:
418/// - Each request is a single line of JSON sent to stdin
419/// - Each response is a single line of JSON read from stdout
420/// - Stderr is captured for debugging and error reporting
421pub struct ProcessTransport {
422    child: std::sync::Arc<tokio::sync::Mutex<tokio::process::Child>>,
423    stdin: std::sync::Arc<tokio::sync::Mutex<tokio::process::ChildStdin>>,
424    stdout_reader:
425        std::sync::Arc<tokio::sync::Mutex<tokio::io::BufReader<tokio::process::ChildStdout>>>,
426}
427
428impl ProcessTransport {
429    /// Creates a new process transport by spawning an MCP server process
430    ///
431    /// This constructor spawns a new process with the specified command, arguments,
432    /// and environment, then sets up stdin/stdout pipes for JSON-RPC communication.
433    /// The process is ready to receive MCP requests immediately after creation.
434    ///
435    /// # Arguments
436    ///
437    /// * `command` - The command to execute (e.g., "mcp-server-filesystem", "/usr/bin/python")
438    /// * `args` - Command-line arguments to pass to the process
439    /// * `work_dir` - Optional working directory for the process (defaults to current directory)
440    /// * `env` - Optional environment variables to set for the process
441    ///
442    /// # Returns
443    ///
444    /// A configured ProcessTransport with the spawned process ready for communication
445    ///
446    /// # Errors
447    ///
448    /// - Command not found or not executable
449    /// - Permission denied errors
450    /// - Process spawn failures
451    /// - Pipe setup errors (stdin/stdout/stderr)
452    /// - Working directory access errors
453    ///
454    /// # Example
455    ///
456    /// ```rust,no_run
457    /// use mistralrs_mcp::transport::ProcessTransport;
458    /// use std::collections::HashMap;
459    ///
460    /// #[tokio::main]
461    /// async fn main() -> anyhow::Result<()> {
462    ///     // Simple filesystem server
463    ///     let transport = ProcessTransport::new(
464    ///         "mcp-server-filesystem".to_string(),
465    ///         vec!["--root".to_string(), "/home/user/documents".to_string()],
466    ///         None,
467    ///         None
468    ///     ).await?;
469    ///
470    ///     // Python-based MCP server with custom environment
471    ///     let mut env = HashMap::new();
472    ///     env.insert("PYTHONPATH".to_string(), "/opt/mcp-servers".to_string());
473    ///     env.insert("MCP_DEBUG".to_string(), "1".to_string());
474    ///
475    ///     let transport = ProcessTransport::new(
476    ///         "python".to_string(),
477    ///         vec!["-m".to_string(), "my_mcp_server".to_string(), "--port".to_string(), "8080".to_string()],
478    ///         Some("/opt/mcp-servers".to_string()),
479    ///         Some(env)
480    ///     ).await?;
481    ///
482    ///     Ok(())
483    /// }
484    /// ```
485    pub async fn new(
486        command: String,
487        args: Vec<String>,
488        work_dir: Option<String>,
489        env: Option<HashMap<String, String>>,
490    ) -> Result<Self> {
491        use tokio::io::BufReader;
492        use tokio::process::Command;
493
494        let mut cmd = Command::new(command);
495        cmd.args(args)
496            .stdin(std::process::Stdio::piped())
497            .stdout(std::process::Stdio::piped())
498            .stderr(std::process::Stdio::piped());
499
500        if let Some(dir) = work_dir {
501            cmd.current_dir(dir);
502        }
503
504        if let Some(env_vars) = env {
505            for (key, value) in env_vars {
506                cmd.env(key, value);
507            }
508        }
509
510        let mut child = cmd.spawn()?;
511        let stdin = child
512            .stdin
513            .take()
514            .ok_or_else(|| anyhow::anyhow!("Failed to get stdin handle"))?;
515        let stdout = child
516            .stdout
517            .take()
518            .ok_or_else(|| anyhow::anyhow!("Failed to get stdout handle"))?;
519        let stdout_reader = BufReader::new(stdout);
520
521        Ok(Self {
522            child: std::sync::Arc::new(tokio::sync::Mutex::new(child)),
523            stdin: std::sync::Arc::new(tokio::sync::Mutex::new(stdin)),
524            stdout_reader: std::sync::Arc::new(tokio::sync::Mutex::new(stdout_reader)),
525        })
526    }
527}
528
529#[async_trait::async_trait]
530impl McpTransport for ProcessTransport {
531    /// Sends an MCP request to the child process and returns the response
532    ///
533    /// This method implements JSON-RPC 2.0 over stdin/stdout pipes. It sends
534    /// a line-delimited JSON request to the process stdin and reads the
535    /// corresponding response from stdout. Communication is synchronous with
536    /// proper request/response correlation.
537    ///
538    /// # Arguments
539    ///
540    /// * `method` - The MCP method name (e.g., "tools/list", "tools/call", "resources/read")
541    /// * `params` - JSON parameters for the method call
542    ///
543    /// # Returns
544    ///
545    /// The result portion of the JSON-RPC response
546    ///
547    /// # Errors
548    ///
549    /// - Process communication errors (broken pipes)
550    /// - Process crashes or unexpected termination
551    /// - JSON serialization/deserialization errors
552    /// - MCP server errors (returned in JSON-RPC error field)
553    /// - I/O errors on stdin/stdout
554    ///
555    /// # Example
556    ///
557    /// ```rust,no_run
558    /// use mistralrs_mcp::transport::{ProcessTransport, McpTransport};
559    /// use serde_json::json;
560    ///
561    /// #[tokio::main]
562    /// async fn main() -> anyhow::Result<()> {
563    ///     let transport = ProcessTransport::new(
564    ///         "mcp-server-filesystem".to_string(),
565    ///         vec!["--root".to_string(), "/tmp".to_string()],
566    ///         None,
567    ///         None
568    ///     ).await?;
569    ///     
570    ///     // List available tools
571    ///     let tools = transport.send_request("tools/list", serde_json::Value::Null).await?;
572    ///     
573    ///     // Call a specific tool
574    ///     let params = json!({
575    ///         "name": "read_file",
576    ///         "arguments": {"path": "/tmp/example.txt"}
577    ///     });
578    ///     let result = transport.send_request("tools/call", params).await?;
579    ///     
580    ///     Ok(())
581    /// }
582    /// ```
583    async fn send_request(&self, method: &str, params: Value) -> Result<Value> {
584        use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
585
586        // Ensure params is an object, not null
587        let params = if params.is_null() {
588            serde_json::json!({})
589        } else {
590            params
591        };
592
593        let request_body = serde_json::json!({
594            "jsonrpc": "2.0",
595            "id": 1,
596            "method": method,
597            "params": params
598        });
599
600        // Send request via stdin
601        let request_line = serde_json::to_string(&request_body)? + "\n";
602
603        let mut stdin = self.stdin.lock().await;
604        stdin.write_all(request_line.as_bytes()).await?;
605        stdin.flush().await?;
606        drop(stdin);
607
608        // Read response from stdout
609        let mut stdout_reader = self.stdout_reader.lock().await;
610        let mut response_line = String::new();
611        stdout_reader.read_line(&mut response_line).await?;
612        drop(stdout_reader);
613
614        let response_body: Value = serde_json::from_str(response_line.trim())?;
615
616        // Check for JSON-RPC errors
617        if let Some(error) = response_body.get("error") {
618            return Err(anyhow::anyhow!("MCP server error: {}", error));
619        }
620
621        response_body
622            .get("result")
623            .cloned()
624            .ok_or_else(|| anyhow::anyhow!("No result in MCP response"))
625    }
626
627    /// Tests the process connection by sending a ping request
628    ///
629    /// Sends a "ping" method call to verify that the MCP server process is
630    /// responsive and the stdin/stdout communication is working properly.
631    ///
632    /// # Returns
633    ///
634    /// Ok(()) if the ping was successful
635    ///
636    /// # Errors
637    ///
638    /// - Process communication errors
639    /// - Process crashed or terminated
640    /// - Broken stdin/stdout pipes
641    async fn ping(&self) -> Result<()> {
642        self.send_request("ping", Value::Null).await?;
643        Ok(())
644    }
645
646    /// Terminates the child process and cleans up resources
647    ///
648    /// This method forcefully terminates the MCP server process and closes
649    /// all associated pipes. Any pending requests will fail after this call.
650    /// The transport cannot be used after closing.
651    ///
652    /// # Returns
653    ///
654    /// Ok(()) if the process was terminated successfully
655    ///
656    /// # Errors
657    ///
658    /// - Process termination errors
659    /// - Resource cleanup failures
660    ///
661    /// # Note
662    ///
663    /// This method sends SIGKILL to the process, which may not allow for
664    /// graceful cleanup. Consider implementing graceful shutdown through
665    /// MCP protocol methods before calling this method.
666    async fn close(&self) -> Result<()> {
667        let mut child = self.child.lock().await;
668        child.kill().await?;
669        Ok(())
670    }
671}
672
673/// WebSocket-based MCP transport
674///
675/// Provides real-time bidirectional communication with MCP servers over WebSocket connections.
676/// This transport supports secure connections (WSS), Bearer token authentication, and concurrent
677/// request/response handling with proper JSON-RPC 2.0 message correlation.
678///
679/// # Features
680///
681/// - **Async WebSocket Communication**: Built on tokio-tungstenite for high-performance async I/O
682/// - **Request/Response Matching**: Automatic correlation of responses using atomic request IDs
683/// - **Bearer Token Support**: Authentication via Authorization header during handshake
684/// - **Connection Management**: Proper ping/pong and connection lifecycle handling
685/// - **Concurrent Operations**: Split stream architecture allows simultaneous read/write operations
686///
687/// # Architecture
688///
689/// The transport uses a split-stream design where the WebSocket connection is divided into
690/// separate read and write halves, each protected by async mutexes. This allows concurrent
691/// operations while maintaining thread safety. Request IDs are generated atomically to ensure
692/// unique identification of requests and proper response correlation.
693///
694/// # Example Usage
695///
696/// ```rust,no_run
697/// use mistralrs_mcp::transport::{WebSocketTransport, McpTransport};
698/// use std::collections::HashMap;
699///
700/// #[tokio::main]
701/// async fn main() -> anyhow::Result<()> {
702///     // Create headers with Bearer token
703///     let mut headers = HashMap::new();
704///     headers.insert("Authorization".to_string(), "Bearer your-token".to_string());
705///
706///     // Connect to WebSocket MCP server
707///     let transport = WebSocketTransport::new(
708///         "wss://api.example.com/mcp".to_string(),
709///         Some(30), // 30 second timeout
710///         Some(headers)
711///     ).await?;
712///
713///     // Use the transport for MCP communication
714///     let result = transport.send_request("tools/list", serde_json::Value::Null).await?;
715///     println!("Available tools: {}", result);
716///
717///     Ok(())
718/// }
719/// ```
720///
721/// # Protocol Compliance
722///
723/// This transport implements the Model Context Protocol (MCP) specification over WebSocket,
724/// adhering to JSON-RPC 2.0 message format with proper error handling and response correlation.
725pub struct WebSocketTransport {
726    write: std::sync::Arc<
727        tokio::sync::Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
728    >,
729    read:
730        std::sync::Arc<tokio::sync::Mutex<SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>>,
731    request_id: std::sync::Arc<std::sync::atomic::AtomicU64>,
732}
733
734impl WebSocketTransport {
735    /// Creates a new WebSocket transport connection to an MCP server
736    ///
737    /// # Arguments
738    ///
739    /// * `url` - WebSocket URL (ws:// or wss://)
740    /// * `_timeout_secs` - Connection timeout (currently unused, reserved for future use)
741    /// * `headers` - Optional HTTP headers for WebSocket handshake (e.g., Bearer tokens)
742    ///
743    /// # Returns
744    ///
745    /// A configured WebSocketTransport ready for MCP communication
746    ///
747    /// # Errors
748    ///
749    /// - Invalid URL format
750    /// - WebSocket connection failure
751    /// - Header parsing errors
752    /// - Network connectivity issues
753    ///
754    /// # Example
755    ///
756    /// ```rust,no_run
757    /// use mistralrs_mcp::transport::WebSocketTransport;
758    /// use std::collections::HashMap;
759    ///
760    /// #[tokio::main]
761    /// async fn main() -> anyhow::Result<()> {
762    ///     let mut headers = HashMap::new();
763    ///     headers.insert("Authorization".to_string(), "Bearer token123".to_string());
764    ///     
765    ///     let transport = WebSocketTransport::new(
766    ///         "wss://mcp.example.com/api".to_string(),
767    ///         Some(30),
768    ///         Some(headers)
769    ///     ).await?;
770    ///     
771    ///     Ok(())
772    /// }
773    /// ```
774    pub async fn new(
775        url: String,
776        _timeout_secs: Option<u64>,
777        headers: Option<HashMap<String, String>>,
778    ) -> Result<Self> {
779        // Create request with headers
780        let uri: Uri = url
781            .parse()
782            .map_err(|e| anyhow::anyhow!("Invalid WebSocket URL: {}", e))?;
783        let mut request = Request::builder()
784            .uri(uri)
785            .body(())
786            .map_err(|e| anyhow::anyhow!("Failed to create WebSocket request: {}", e))?;
787
788        // Add headers if provided
789        if let Some(headers) = headers {
790            let req_headers = request.headers_mut();
791            for (key, value) in headers {
792                let header_name = key
793                    .parse::<http::header::HeaderName>()
794                    .map_err(|e| anyhow::anyhow!("Invalid header key: {}", e))?;
795                let header_value = value
796                    .parse::<http::header::HeaderValue>()
797                    .map_err(|e| anyhow::anyhow!("Invalid header value: {}", e))?;
798                req_headers.insert(header_name, header_value);
799            }
800        }
801
802        // Connect to WebSocket
803        let (ws_stream, _) = connect_async(request)
804            .await
805            .map_err(|e| anyhow::anyhow!("WebSocket connection failed: {}", e))?;
806
807        // Split the stream
808        let (write, read) = ws_stream.split();
809
810        Ok(Self {
811            write: std::sync::Arc::new(tokio::sync::Mutex::new(write)),
812            read: std::sync::Arc::new(tokio::sync::Mutex::new(read)),
813            request_id: std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1)),
814        })
815    }
816}
817
818#[async_trait::async_trait]
819impl McpTransport for WebSocketTransport {
820    /// Sends an MCP request over WebSocket and waits for the corresponding response
821    ///
822    /// This method implements the JSON-RPC 2.0 protocol over WebSocket, handling:
823    /// - Unique request ID generation for response correlation
824    /// - Concurrent request processing with proper message ordering
825    /// - Error handling for both transport and protocol errors
826    /// - Message filtering to match responses with requests
827    ///
828    /// # Arguments
829    ///
830    /// * `method` - The MCP method name (e.g., "tools/list", "tools/call")
831    /// * `params` - JSON parameters for the method call
832    ///
833    /// # Returns
834    ///
835    /// The result portion of the JSON-RPC response
836    ///
837    /// # Errors
838    ///
839    /// - WebSocket connection errors
840    /// - JSON serialization/deserialization errors  
841    /// - MCP server errors (returned in JSON-RPC error field)
842    /// - Timeout or connection closure
843    ///
844    /// # Example
845    ///
846    /// ```rust,no_run
847    /// use mistralrs_mcp::transport::{WebSocketTransport, McpTransport};
848    /// use serde_json::json;
849    ///
850    /// #[tokio::main]
851    /// async fn main() -> anyhow::Result<()> {
852    ///     let transport = WebSocketTransport::new(
853    ///         "wss://api.example.com/mcp".to_string(),
854    ///         None,
855    ///         None
856    ///     ).await?;
857    ///     
858    ///     // List available tools
859    ///     let tools = transport.send_request("tools/list", serde_json::Value::Null).await?;
860    ///     
861    ///     // Call a specific tool
862    ///     let params = json!({"query": "example search"});
863    ///     let result = transport.send_request("tools/call", params).await?;
864    ///     
865    ///     Ok(())
866    /// }
867    /// ```
868    async fn send_request(&self, method: &str, params: Value) -> Result<Value> {
869        // Ensure params is an object, not null
870        let params = if params.is_null() {
871            serde_json::json!({})
872        } else {
873            params
874        };
875
876        // Generate unique request ID
877        let id = self
878            .request_id
879            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
880
881        let request_body = serde_json::json!({
882            "jsonrpc": "2.0",
883            "id": id,
884            "method": method,
885            "params": params
886        });
887
888        // Send request
889        let message = Message::Text(serde_json::to_string(&request_body)?);
890
891        {
892            let mut write = self.write.lock().await;
893            write
894                .send(message)
895                .await
896                .map_err(|e| anyhow::anyhow!("Failed to send WebSocket message: {}", e))?;
897        }
898
899        // Read response
900        loop {
901            let mut read = self.read.lock().await;
902            let msg = read
903                .next()
904                .await
905                .ok_or_else(|| anyhow::anyhow!("WebSocket connection closed"))?
906                .map_err(|e| anyhow::anyhow!("WebSocket read error: {}", e))?;
907            drop(read);
908
909            match msg {
910                Message::Text(text) => {
911                    let response_body: Value = serde_json::from_str(&text)?;
912
913                    // Check if this is the response to our request
914                    if let Some(response_id) = response_body.get("id").and_then(|v| v.as_u64()) {
915                        if response_id == id {
916                            // Check for JSON-RPC errors
917                            if let Some(error) = response_body.get("error") {
918                                return Err(anyhow::anyhow!("MCP server error: {}", error));
919                            }
920
921                            return response_body
922                                .get("result")
923                                .cloned()
924                                .ok_or_else(|| anyhow::anyhow!("No result in MCP response"));
925                        }
926                    }
927                    // If it's not our response, continue reading
928                }
929                Message::Binary(_) => {
930                    // Handle binary messages if needed, for now skip
931                    continue;
932                }
933                Message::Close(_) => {
934                    return Err(anyhow::anyhow!("WebSocket connection closed by server"));
935                }
936                Message::Ping(_) | Message::Pong(_) => {
937                    // Handle ping/pong frames, continue reading
938                    continue;
939                }
940                Message::Frame(_) => {
941                    // Raw frames, continue reading
942                    continue;
943                }
944            }
945        }
946    }
947
948    /// Sends a WebSocket ping frame to test connection health
949    ///
950    /// This method sends a ping frame to the server and expects a pong response,
951    /// which helps verify that the WebSocket connection is still active and responsive.
952    ///
953    /// # Returns
954    ///
955    /// Ok(()) if the ping was sent successfully
956    ///
957    /// # Errors
958    ///
959    /// - WebSocket send errors
960    /// - Connection closure
961    async fn ping(&self) -> Result<()> {
962        let ping_message = Message::Ping(vec![]);
963        let mut write = self.write.lock().await;
964        write
965            .send(ping_message)
966            .await
967            .map_err(|e| anyhow::anyhow!("Failed to send ping: {}", e))?;
968        Ok(())
969    }
970
971    /// Gracefully closes the WebSocket connection
972    ///
973    /// Sends a close frame to the server to properly terminate the connection
974    /// according to the WebSocket protocol. The server should respond with its
975    /// own close frame to complete the closing handshake.
976    ///
977    /// # Returns
978    ///
979    /// Ok(()) if the close frame was sent successfully
980    ///
981    /// # Errors
982    ///
983    /// - WebSocket send errors
984    /// - Connection already closed
985    async fn close(&self) -> Result<()> {
986        let close_message = Message::Close(None);
987        let mut write = self.write.lock().await;
988        write
989            .send(close_message)
990            .await
991            .map_err(|e| anyhow::anyhow!("Failed to send close message: {}", e))?;
992        Ok(())
993    }
994}