mistralrs_mcp/transport.rs
1use anyhow::Result;
2use futures_util::stream::{SplitSink, SplitStream};
3use futures_util::{SinkExt, StreamExt};
4use http::{Request, Uri};
5use serde_json::Value;
6use std::collections::HashMap;
7use std::time::Duration;
8use tokio::net::TcpStream;
9use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
10
11/// Transport layer for MCP communication
12#[async_trait::async_trait]
13pub trait McpTransport: Send + Sync {
14 /// Send a JSON-RPC request and receive a response
15 async fn send_request(&self, method: &str, params: Value) -> Result<Value>;
16
17 /// Check if the transport connection is healthy
18 async fn ping(&self) -> Result<()>;
19
20 /// Close the transport connection
21 async fn close(&self) -> Result<()>;
22}
23
24/// HTTP-based MCP transport
25///
26/// Provides communication with MCP servers over HTTP using JSON-RPC 2.0 protocol.
27/// This transport is ideal for RESTful MCP services, public APIs, and servers
28/// behind load balancers. Supports both regular JSON responses and Server-Sent Events (SSE).
29///
30/// # Features
31///
32/// - **HTTP/HTTPS Support**: Secure communication with TLS encryption
33/// - **Server-Sent Events**: Handles streaming responses via SSE format
34/// - **Bearer Token Authentication**: Automatic Authorization header injection
35/// - **Custom Headers**: Support for additional headers (API keys, versioning, etc.)
36/// - **Configurable Timeouts**: Request-level timeout control
37/// - **Error Handling**: Comprehensive JSON-RPC and HTTP error handling
38///
39/// # Use Cases
40///
41/// - **Public MCP APIs**: Connect to hosted MCP services
42/// - **RESTful Services**: Integration with REST-based tool providers
43/// - **Load-Balanced Servers**: Works well behind HTTP load balancers
44/// - **Development/Testing**: Easy debugging with standard HTTP tools
45///
46/// # Example Usage
47///
48/// ```rust,no_run
49/// use mistralrs_mcp::transport::{HttpTransport, McpTransport};
50/// use std::collections::HashMap;
51///
52/// #[tokio::main]
53/// async fn main() -> anyhow::Result<()> {
54/// // Create headers with Bearer token and API version
55/// let mut headers = HashMap::new();
56/// headers.insert("Authorization".to_string(), "Bearer your-api-token".to_string());
57/// headers.insert("X-API-Version".to_string(), "v1".to_string());
58///
59/// // Connect to HTTP MCP server
60/// let transport = HttpTransport::new(
61/// "https://api.example.com/mcp".to_string(),
62/// Some(30), // 30 second timeout
63/// Some(headers)
64/// )?;
65///
66/// // Use the transport for MCP communication
67/// let result = transport.send_request("tools/list", serde_json::Value::Null).await?;
68/// println!("Available tools: {}", result);
69///
70/// Ok(())
71/// }
72/// ```
73///
74/// # Protocol Support
75///
76/// This transport implements JSON-RPC 2.0 over HTTP with support for:
77/// - Standard JSON responses
78/// - Server-Sent Events (SSE) for streaming data
79/// - Bearer token authentication
80/// - Custom HTTP headers
81pub struct HttpTransport {
82 client: reqwest::Client,
83 base_url: String,
84 headers: HashMap<String, String>,
85}
86
87impl HttpTransport {
88 /// Creates a new HTTP transport for MCP communication
89 ///
90 /// # Arguments
91 ///
92 /// * `base_url` - Base URL of the MCP server (http:// or https://)
93 /// * `timeout_secs` - Optional timeout for HTTP requests in seconds (defaults to 30s)
94 /// * `headers` - Optional custom headers to include in all requests
95 ///
96 /// # Returns
97 ///
98 /// A configured HttpTransport ready for MCP communication
99 ///
100 /// # Errors
101 ///
102 /// - Invalid URL format
103 /// - HTTP client configuration errors
104 /// - TLS/SSL setup failures
105 ///
106 /// # Example
107 ///
108 /// ```rust,no_run
109 /// use mistralrs_mcp::transport::HttpTransport;
110 /// use std::collections::HashMap;
111 ///
112 /// // Basic HTTP transport
113 /// let transport = HttpTransport::new(
114 /// "https://api.example.com/mcp".to_string(),
115 /// Some(60), // 1 minute timeout
116 /// None
117 /// )?;
118 ///
119 /// // With custom headers and authentication
120 /// let mut headers = HashMap::new();
121 /// headers.insert("Authorization".to_string(), "Bearer token123".to_string());
122 /// headers.insert("X-Client-Version".to_string(), "1.0.0".to_string());
123 ///
124 /// let transport = HttpTransport::new(
125 /// "https://secure-api.example.com/mcp".to_string(),
126 /// Some(30),
127 /// Some(headers)
128 /// )?;
129 /// # Ok::<(), anyhow::Error>(())
130 /// ```
131 pub fn new(
132 base_url: String,
133 timeout_secs: Option<u64>,
134 headers: Option<HashMap<String, String>>,
135 ) -> Result<Self> {
136 let timeout = timeout_secs
137 .map(Duration::from_secs)
138 .unwrap_or(Duration::from_secs(30));
139 let client = reqwest::Client::builder().timeout(timeout).build()?;
140
141 Ok(Self {
142 client,
143 base_url,
144 headers: headers.unwrap_or_default(),
145 })
146 }
147
148 /// Parse Server-Sent Events response to extract JSON-RPC message
149 ///
150 /// Handles SSE format used by some MCP servers for streaming responses.
151 /// SSE format: `data: <json>\n\n` or `event: <type>\ndata: <json>\n\n`
152 ///
153 /// # Arguments
154 ///
155 /// * `sse_text` - Raw SSE response text from the server
156 ///
157 /// # Returns
158 ///
159 /// Parsed JSON value from the SSE data field
160 ///
161 /// # Errors
162 ///
163 /// - No valid JSON data found in SSE response
164 /// - Malformed SSE format
165 /// - JSON parsing errors
166 fn parse_sse_response(sse_text: &str) -> Result<Value> {
167 // SSE format: data: <json>\n\n or event: <type>\ndata: <json>\n\n
168 let mut json_data = None;
169
170 for line in sse_text.lines() {
171 let line = line.trim();
172
173 // Skip empty lines and comments
174 if line.is_empty() || line.starts_with(':') {
175 continue;
176 }
177
178 // Parse SSE field
179 if let Some((field, value)) = line.split_once(':') {
180 let field = field.trim();
181 let value = value.trim();
182
183 match field {
184 "data" => {
185 // Try to parse the JSON data
186 if let Ok(parsed) = serde_json::from_str::<Value>(value) {
187 json_data = Some(parsed);
188 break;
189 }
190 }
191 "event" => {
192 // Handle different event types if needed
193 continue;
194 }
195 _ => {
196 // Ignore other SSE fields like id, retry, etc.
197 continue;
198 }
199 }
200 }
201 }
202
203 json_data.ok_or_else(|| anyhow::anyhow!("No valid JSON data found in SSE response"))
204 }
205}
206
207#[async_trait::async_trait]
208impl McpTransport for HttpTransport {
209 /// Sends an MCP request over HTTP and returns the response
210 ///
211 /// This method implements JSON-RPC 2.0 over HTTP with support for both
212 /// standard JSON responses and Server-Sent Events (SSE). It handles
213 /// authentication, custom headers, and comprehensive error reporting.
214 ///
215 /// # Arguments
216 ///
217 /// * `method` - The MCP method name (e.g., "tools/list", "tools/call", "resources/read")
218 /// * `params` - JSON parameters for the method call
219 ///
220 /// # Returns
221 ///
222 /// The result portion of the JSON-RPC response
223 ///
224 /// # Errors
225 ///
226 /// - HTTP connection errors (network issues, DNS resolution)
227 /// - HTTP status errors (4xx, 5xx responses)
228 /// - JSON serialization/deserialization errors
229 /// - MCP server errors (returned in JSON-RPC error field)
230 /// - SSE parsing errors for streaming responses
231 ///
232 /// # Example
233 ///
234 /// ```rust,no_run
235 /// use mistralrs_mcp::transport::{HttpTransport, McpTransport};
236 /// use serde_json::json;
237 ///
238 /// #[tokio::main]
239 /// async fn main() -> anyhow::Result<()> {
240 /// let transport = HttpTransport::new(
241 /// "https://api.example.com/mcp".to_string(),
242 /// None,
243 /// None
244 /// )?;
245 ///
246 /// // List available tools
247 /// let tools = transport.send_request("tools/list", serde_json::Value::Null).await?;
248 ///
249 /// // Call a specific tool
250 /// let params = json!({
251 /// "name": "search",
252 /// "arguments": {"query": "example search"}
253 /// });
254 /// let result = transport.send_request("tools/call", params).await?;
255 ///
256 /// Ok(())
257 /// }
258 /// ```
259 async fn send_request(&self, method: &str, params: Value) -> Result<Value> {
260 // Ensure params is an object, not null
261 let params = if params.is_null() {
262 serde_json::json!({})
263 } else {
264 params
265 };
266
267 let request_body = serde_json::json!({
268 "jsonrpc": "2.0",
269 "id": 1,
270 "method": method,
271 "params": params
272 });
273
274 let mut request_builder = self
275 .client
276 .post(&self.base_url)
277 .json(&request_body)
278 .header("Accept", "application/json, text/event-stream");
279
280 // Add custom headers
281 for (key, value) in &self.headers {
282 request_builder = request_builder.header(key, value);
283 }
284
285 let response = request_builder.send().await?;
286
287 // Check content type and handle accordingly
288 let content_type = response
289 .headers()
290 .get("content-type")
291 .and_then(|v| v.to_str().ok())
292 .unwrap_or("");
293
294 let response_body: Value = if content_type.contains("text/event-stream") {
295 // Handle Server-Sent Events
296 let response_text = response.text().await?;
297 Self::parse_sse_response(&response_text)?
298 } else {
299 // Handle regular JSON response
300 response.json().await?
301 };
302
303 // Check for JSON-RPC errors
304 if let Some(error) = response_body.get("error") {
305 return Err(anyhow::anyhow!("MCP server error: {}", error));
306 }
307
308 response_body
309 .get("result")
310 .cloned()
311 .ok_or_else(|| anyhow::anyhow!("No result in MCP response"))
312 }
313
314 /// Tests the HTTP connection by sending a ping request
315 ///
316 /// Sends a "ping" method call to verify that the MCP server is responsive
317 /// and the HTTP connection is working properly.
318 ///
319 /// # Returns
320 ///
321 /// Ok(()) if the ping was successful
322 ///
323 /// # Errors
324 ///
325 /// - HTTP connection errors
326 /// - Server unavailable or unresponsive
327 /// - Authentication failures
328 async fn ping(&self) -> Result<()> {
329 self.send_request("ping", Value::Null).await?;
330 Ok(())
331 }
332
333 /// Closes the HTTP transport connection
334 ///
335 /// HTTP connections are stateless and managed by the underlying HTTP client,
336 /// so this method is a no-op but provided for interface compatibility.
337 ///
338 /// # Returns
339 ///
340 /// Always returns Ok(()) as HTTP connections don't require explicit cleanup
341 async fn close(&self) -> Result<()> {
342 // HTTP connections don't need explicit closing
343 Ok(())
344 }
345}
346
347/// Process-based MCP transport using stdin/stdout communication
348///
349/// Provides communication with local MCP servers running as separate processes
350/// using JSON-RPC 2.0 over stdin/stdout pipes. This transport is ideal for
351/// local tools, development servers, and sandboxed environments where you need
352/// process isolation and direct control over the MCP server lifecycle.
353///
354/// # Features
355///
356/// - **Process Isolation**: Each MCP server runs in its own process for security
357/// - **No Network Overhead**: Direct pipe communication for maximum performance
358/// - **Environment Control**: Full control over working directory and environment variables
359/// - **Resource Management**: Automatic process cleanup and lifecycle management
360/// - **Synchronous Communication**: Request/response correlation over stdin/stdout
361/// - **Error Handling**: Comprehensive process and communication error handling
362///
363/// # Use Cases
364///
365/// - **Local Development**: Running MCP servers during development and testing
366/// - **Filesystem Tools**: Local file operations and system utilities
367/// - **Sandboxed Execution**: Isolated execution environments for security
368/// - **Custom Tools**: Private or proprietary MCP servers
369/// - **CI/CD Integration**: Running MCP servers in automated environments
370///
371/// # Example Usage
372///
373/// ```rust,no_run
374/// use mistralrs_mcp::transport::{ProcessTransport, McpTransport};
375/// use std::collections::HashMap;
376///
377/// #[tokio::main]
378/// async fn main() -> anyhow::Result<()> {
379/// // Basic process transport
380/// let transport = ProcessTransport::new(
381/// "mcp-server-filesystem".to_string(),
382/// vec!["--root".to_string(), "/tmp".to_string()],
383/// None,
384/// None
385/// ).await?;
386///
387/// // With custom working directory and environment
388/// let mut env = HashMap::new();
389/// env.insert("MCP_LOG_LEVEL".to_string(), "debug".to_string());
390/// env.insert("MCP_TIMEOUT".to_string(), "30".to_string());
391///
392/// let transport = ProcessTransport::new(
393/// "/usr/local/bin/my-mcp-server".to_string(),
394/// vec!["--config".to_string(), "production.json".to_string()],
395/// Some("/opt/mcp-server".to_string()), // Working directory
396/// Some(env) // Environment variables
397/// ).await?;
398///
399/// // Use the transport for MCP communication
400/// let result = transport.send_request("tools/list", serde_json::Value::Null).await?;
401/// println!("Available tools: {}", result);
402///
403/// Ok(())
404/// }
405/// ```
406///
407/// # Process Management
408///
409/// The transport automatically manages the child process lifecycle:
410/// - Spawns the process with configured arguments and environment
411/// - Sets up stdin/stdout pipes for JSON-RPC communication
412/// - Monitors process health and handles crashes
413/// - Cleans up resources when the transport is dropped or closed
414///
415/// # Communication Protocol
416///
417/// Uses JSON-RPC 2.0 over stdin/stdout with line-delimited messages:
418/// - Each request is a single line of JSON sent to stdin
419/// - Each response is a single line of JSON read from stdout
420/// - Stderr is captured for debugging and error reporting
421pub struct ProcessTransport {
422 child: std::sync::Arc<tokio::sync::Mutex<tokio::process::Child>>,
423 stdin: std::sync::Arc<tokio::sync::Mutex<tokio::process::ChildStdin>>,
424 stdout_reader:
425 std::sync::Arc<tokio::sync::Mutex<tokio::io::BufReader<tokio::process::ChildStdout>>>,
426}
427
428impl ProcessTransport {
429 /// Creates a new process transport by spawning an MCP server process
430 ///
431 /// This constructor spawns a new process with the specified command, arguments,
432 /// and environment, then sets up stdin/stdout pipes for JSON-RPC communication.
433 /// The process is ready to receive MCP requests immediately after creation.
434 ///
435 /// # Arguments
436 ///
437 /// * `command` - The command to execute (e.g., "mcp-server-filesystem", "/usr/bin/python")
438 /// * `args` - Command-line arguments to pass to the process
439 /// * `work_dir` - Optional working directory for the process (defaults to current directory)
440 /// * `env` - Optional environment variables to set for the process
441 ///
442 /// # Returns
443 ///
444 /// A configured ProcessTransport with the spawned process ready for communication
445 ///
446 /// # Errors
447 ///
448 /// - Command not found or not executable
449 /// - Permission denied errors
450 /// - Process spawn failures
451 /// - Pipe setup errors (stdin/stdout/stderr)
452 /// - Working directory access errors
453 ///
454 /// # Example
455 ///
456 /// ```rust,no_run
457 /// use mistralrs_mcp::transport::ProcessTransport;
458 /// use std::collections::HashMap;
459 ///
460 /// #[tokio::main]
461 /// async fn main() -> anyhow::Result<()> {
462 /// // Simple filesystem server
463 /// let transport = ProcessTransport::new(
464 /// "mcp-server-filesystem".to_string(),
465 /// vec!["--root".to_string(), "/home/user/documents".to_string()],
466 /// None,
467 /// None
468 /// ).await?;
469 ///
470 /// // Python-based MCP server with custom environment
471 /// let mut env = HashMap::new();
472 /// env.insert("PYTHONPATH".to_string(), "/opt/mcp-servers".to_string());
473 /// env.insert("MCP_DEBUG".to_string(), "1".to_string());
474 ///
475 /// let transport = ProcessTransport::new(
476 /// "python".to_string(),
477 /// vec!["-m".to_string(), "my_mcp_server".to_string(), "--port".to_string(), "8080".to_string()],
478 /// Some("/opt/mcp-servers".to_string()),
479 /// Some(env)
480 /// ).await?;
481 ///
482 /// Ok(())
483 /// }
484 /// ```
485 pub async fn new(
486 command: String,
487 args: Vec<String>,
488 work_dir: Option<String>,
489 env: Option<HashMap<String, String>>,
490 ) -> Result<Self> {
491 use tokio::io::BufReader;
492 use tokio::process::Command;
493
494 let mut cmd = Command::new(command);
495 cmd.args(args)
496 .stdin(std::process::Stdio::piped())
497 .stdout(std::process::Stdio::piped())
498 .stderr(std::process::Stdio::piped());
499
500 if let Some(dir) = work_dir {
501 cmd.current_dir(dir);
502 }
503
504 if let Some(env_vars) = env {
505 for (key, value) in env_vars {
506 cmd.env(key, value);
507 }
508 }
509
510 let mut child = cmd.spawn()?;
511 let stdin = child
512 .stdin
513 .take()
514 .ok_or_else(|| anyhow::anyhow!("Failed to get stdin handle"))?;
515 let stdout = child
516 .stdout
517 .take()
518 .ok_or_else(|| anyhow::anyhow!("Failed to get stdout handle"))?;
519 let stdout_reader = BufReader::new(stdout);
520
521 Ok(Self {
522 child: std::sync::Arc::new(tokio::sync::Mutex::new(child)),
523 stdin: std::sync::Arc::new(tokio::sync::Mutex::new(stdin)),
524 stdout_reader: std::sync::Arc::new(tokio::sync::Mutex::new(stdout_reader)),
525 })
526 }
527}
528
529#[async_trait::async_trait]
530impl McpTransport for ProcessTransport {
531 /// Sends an MCP request to the child process and returns the response
532 ///
533 /// This method implements JSON-RPC 2.0 over stdin/stdout pipes. It sends
534 /// a line-delimited JSON request to the process stdin and reads the
535 /// corresponding response from stdout. Communication is synchronous with
536 /// proper request/response correlation.
537 ///
538 /// # Arguments
539 ///
540 /// * `method` - The MCP method name (e.g., "tools/list", "tools/call", "resources/read")
541 /// * `params` - JSON parameters for the method call
542 ///
543 /// # Returns
544 ///
545 /// The result portion of the JSON-RPC response
546 ///
547 /// # Errors
548 ///
549 /// - Process communication errors (broken pipes)
550 /// - Process crashes or unexpected termination
551 /// - JSON serialization/deserialization errors
552 /// - MCP server errors (returned in JSON-RPC error field)
553 /// - I/O errors on stdin/stdout
554 ///
555 /// # Example
556 ///
557 /// ```rust,no_run
558 /// use mistralrs_mcp::transport::{ProcessTransport, McpTransport};
559 /// use serde_json::json;
560 ///
561 /// #[tokio::main]
562 /// async fn main() -> anyhow::Result<()> {
563 /// let transport = ProcessTransport::new(
564 /// "mcp-server-filesystem".to_string(),
565 /// vec!["--root".to_string(), "/tmp".to_string()],
566 /// None,
567 /// None
568 /// ).await?;
569 ///
570 /// // List available tools
571 /// let tools = transport.send_request("tools/list", serde_json::Value::Null).await?;
572 ///
573 /// // Call a specific tool
574 /// let params = json!({
575 /// "name": "read_file",
576 /// "arguments": {"path": "/tmp/example.txt"}
577 /// });
578 /// let result = transport.send_request("tools/call", params).await?;
579 ///
580 /// Ok(())
581 /// }
582 /// ```
583 async fn send_request(&self, method: &str, params: Value) -> Result<Value> {
584 use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
585
586 // Ensure params is an object, not null
587 let params = if params.is_null() {
588 serde_json::json!({})
589 } else {
590 params
591 };
592
593 let request_body = serde_json::json!({
594 "jsonrpc": "2.0",
595 "id": 1,
596 "method": method,
597 "params": params
598 });
599
600 // Send request via stdin
601 let request_line = serde_json::to_string(&request_body)? + "\n";
602
603 let mut stdin = self.stdin.lock().await;
604 stdin.write_all(request_line.as_bytes()).await?;
605 stdin.flush().await?;
606 drop(stdin);
607
608 // Read response from stdout
609 let mut stdout_reader = self.stdout_reader.lock().await;
610 let mut response_line = String::new();
611 stdout_reader.read_line(&mut response_line).await?;
612 drop(stdout_reader);
613
614 let response_body: Value = serde_json::from_str(response_line.trim())?;
615
616 // Check for JSON-RPC errors
617 if let Some(error) = response_body.get("error") {
618 return Err(anyhow::anyhow!("MCP server error: {}", error));
619 }
620
621 response_body
622 .get("result")
623 .cloned()
624 .ok_or_else(|| anyhow::anyhow!("No result in MCP response"))
625 }
626
627 /// Tests the process connection by sending a ping request
628 ///
629 /// Sends a "ping" method call to verify that the MCP server process is
630 /// responsive and the stdin/stdout communication is working properly.
631 ///
632 /// # Returns
633 ///
634 /// Ok(()) if the ping was successful
635 ///
636 /// # Errors
637 ///
638 /// - Process communication errors
639 /// - Process crashed or terminated
640 /// - Broken stdin/stdout pipes
641 async fn ping(&self) -> Result<()> {
642 self.send_request("ping", Value::Null).await?;
643 Ok(())
644 }
645
646 /// Terminates the child process and cleans up resources
647 ///
648 /// This method forcefully terminates the MCP server process and closes
649 /// all associated pipes. Any pending requests will fail after this call.
650 /// The transport cannot be used after closing.
651 ///
652 /// # Returns
653 ///
654 /// Ok(()) if the process was terminated successfully
655 ///
656 /// # Errors
657 ///
658 /// - Process termination errors
659 /// - Resource cleanup failures
660 ///
661 /// # Note
662 ///
663 /// This method sends SIGKILL to the process, which may not allow for
664 /// graceful cleanup. Consider implementing graceful shutdown through
665 /// MCP protocol methods before calling this method.
666 async fn close(&self) -> Result<()> {
667 let mut child = self.child.lock().await;
668 child.kill().await?;
669 Ok(())
670 }
671}
672
673/// WebSocket-based MCP transport
674///
675/// Provides real-time bidirectional communication with MCP servers over WebSocket connections.
676/// This transport supports secure connections (WSS), Bearer token authentication, and concurrent
677/// request/response handling with proper JSON-RPC 2.0 message correlation.
678///
679/// # Features
680///
681/// - **Async WebSocket Communication**: Built on tokio-tungstenite for high-performance async I/O
682/// - **Request/Response Matching**: Automatic correlation of responses using atomic request IDs
683/// - **Bearer Token Support**: Authentication via Authorization header during handshake
684/// - **Connection Management**: Proper ping/pong and connection lifecycle handling
685/// - **Concurrent Operations**: Split stream architecture allows simultaneous read/write operations
686///
687/// # Architecture
688///
689/// The transport uses a split-stream design where the WebSocket connection is divided into
690/// separate read and write halves, each protected by async mutexes. This allows concurrent
691/// operations while maintaining thread safety. Request IDs are generated atomically to ensure
692/// unique identification of requests and proper response correlation.
693///
694/// # Example Usage
695///
696/// ```rust,no_run
697/// use mistralrs_mcp::transport::{WebSocketTransport, McpTransport};
698/// use std::collections::HashMap;
699///
700/// #[tokio::main]
701/// async fn main() -> anyhow::Result<()> {
702/// // Create headers with Bearer token
703/// let mut headers = HashMap::new();
704/// headers.insert("Authorization".to_string(), "Bearer your-token".to_string());
705///
706/// // Connect to WebSocket MCP server
707/// let transport = WebSocketTransport::new(
708/// "wss://api.example.com/mcp".to_string(),
709/// Some(30), // 30 second timeout
710/// Some(headers)
711/// ).await?;
712///
713/// // Use the transport for MCP communication
714/// let result = transport.send_request("tools/list", serde_json::Value::Null).await?;
715/// println!("Available tools: {}", result);
716///
717/// Ok(())
718/// }
719/// ```
720///
721/// # Protocol Compliance
722///
723/// This transport implements the Model Context Protocol (MCP) specification over WebSocket,
724/// adhering to JSON-RPC 2.0 message format with proper error handling and response correlation.
725pub struct WebSocketTransport {
726 write: std::sync::Arc<
727 tokio::sync::Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
728 >,
729 read:
730 std::sync::Arc<tokio::sync::Mutex<SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>>,
731 request_id: std::sync::Arc<std::sync::atomic::AtomicU64>,
732}
733
734impl WebSocketTransport {
735 /// Creates a new WebSocket transport connection to an MCP server
736 ///
737 /// # Arguments
738 ///
739 /// * `url` - WebSocket URL (ws:// or wss://)
740 /// * `_timeout_secs` - Connection timeout (currently unused, reserved for future use)
741 /// * `headers` - Optional HTTP headers for WebSocket handshake (e.g., Bearer tokens)
742 ///
743 /// # Returns
744 ///
745 /// A configured WebSocketTransport ready for MCP communication
746 ///
747 /// # Errors
748 ///
749 /// - Invalid URL format
750 /// - WebSocket connection failure
751 /// - Header parsing errors
752 /// - Network connectivity issues
753 ///
754 /// # Example
755 ///
756 /// ```rust,no_run
757 /// use mistralrs_mcp::transport::WebSocketTransport;
758 /// use std::collections::HashMap;
759 ///
760 /// #[tokio::main]
761 /// async fn main() -> anyhow::Result<()> {
762 /// let mut headers = HashMap::new();
763 /// headers.insert("Authorization".to_string(), "Bearer token123".to_string());
764 ///
765 /// let transport = WebSocketTransport::new(
766 /// "wss://mcp.example.com/api".to_string(),
767 /// Some(30),
768 /// Some(headers)
769 /// ).await?;
770 ///
771 /// Ok(())
772 /// }
773 /// ```
774 pub async fn new(
775 url: String,
776 _timeout_secs: Option<u64>,
777 headers: Option<HashMap<String, String>>,
778 ) -> Result<Self> {
779 // Create request with headers
780 let uri: Uri = url
781 .parse()
782 .map_err(|e| anyhow::anyhow!("Invalid WebSocket URL: {}", e))?;
783 let mut request = Request::builder()
784 .uri(uri)
785 .body(())
786 .map_err(|e| anyhow::anyhow!("Failed to create WebSocket request: {}", e))?;
787
788 // Add headers if provided
789 if let Some(headers) = headers {
790 let req_headers = request.headers_mut();
791 for (key, value) in headers {
792 let header_name = key
793 .parse::<http::header::HeaderName>()
794 .map_err(|e| anyhow::anyhow!("Invalid header key: {}", e))?;
795 let header_value = value
796 .parse::<http::header::HeaderValue>()
797 .map_err(|e| anyhow::anyhow!("Invalid header value: {}", e))?;
798 req_headers.insert(header_name, header_value);
799 }
800 }
801
802 // Connect to WebSocket
803 let (ws_stream, _) = connect_async(request)
804 .await
805 .map_err(|e| anyhow::anyhow!("WebSocket connection failed: {}", e))?;
806
807 // Split the stream
808 let (write, read) = ws_stream.split();
809
810 Ok(Self {
811 write: std::sync::Arc::new(tokio::sync::Mutex::new(write)),
812 read: std::sync::Arc::new(tokio::sync::Mutex::new(read)),
813 request_id: std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1)),
814 })
815 }
816}
817
818#[async_trait::async_trait]
819impl McpTransport for WebSocketTransport {
820 /// Sends an MCP request over WebSocket and waits for the corresponding response
821 ///
822 /// This method implements the JSON-RPC 2.0 protocol over WebSocket, handling:
823 /// - Unique request ID generation for response correlation
824 /// - Concurrent request processing with proper message ordering
825 /// - Error handling for both transport and protocol errors
826 /// - Message filtering to match responses with requests
827 ///
828 /// # Arguments
829 ///
830 /// * `method` - The MCP method name (e.g., "tools/list", "tools/call")
831 /// * `params` - JSON parameters for the method call
832 ///
833 /// # Returns
834 ///
835 /// The result portion of the JSON-RPC response
836 ///
837 /// # Errors
838 ///
839 /// - WebSocket connection errors
840 /// - JSON serialization/deserialization errors
841 /// - MCP server errors (returned in JSON-RPC error field)
842 /// - Timeout or connection closure
843 ///
844 /// # Example
845 ///
846 /// ```rust,no_run
847 /// use mistralrs_mcp::transport::{WebSocketTransport, McpTransport};
848 /// use serde_json::json;
849 ///
850 /// #[tokio::main]
851 /// async fn main() -> anyhow::Result<()> {
852 /// let transport = WebSocketTransport::new(
853 /// "wss://api.example.com/mcp".to_string(),
854 /// None,
855 /// None
856 /// ).await?;
857 ///
858 /// // List available tools
859 /// let tools = transport.send_request("tools/list", serde_json::Value::Null).await?;
860 ///
861 /// // Call a specific tool
862 /// let params = json!({"query": "example search"});
863 /// let result = transport.send_request("tools/call", params).await?;
864 ///
865 /// Ok(())
866 /// }
867 /// ```
868 async fn send_request(&self, method: &str, params: Value) -> Result<Value> {
869 // Ensure params is an object, not null
870 let params = if params.is_null() {
871 serde_json::json!({})
872 } else {
873 params
874 };
875
876 // Generate unique request ID
877 let id = self
878 .request_id
879 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
880
881 let request_body = serde_json::json!({
882 "jsonrpc": "2.0",
883 "id": id,
884 "method": method,
885 "params": params
886 });
887
888 // Send request
889 let message = Message::Text(serde_json::to_string(&request_body)?);
890
891 {
892 let mut write = self.write.lock().await;
893 write
894 .send(message)
895 .await
896 .map_err(|e| anyhow::anyhow!("Failed to send WebSocket message: {}", e))?;
897 }
898
899 // Read response
900 loop {
901 let mut read = self.read.lock().await;
902 let msg = read
903 .next()
904 .await
905 .ok_or_else(|| anyhow::anyhow!("WebSocket connection closed"))?
906 .map_err(|e| anyhow::anyhow!("WebSocket read error: {}", e))?;
907 drop(read);
908
909 match msg {
910 Message::Text(text) => {
911 let response_body: Value = serde_json::from_str(&text)?;
912
913 // Check if this is the response to our request
914 if let Some(response_id) = response_body.get("id").and_then(|v| v.as_u64()) {
915 if response_id == id {
916 // Check for JSON-RPC errors
917 if let Some(error) = response_body.get("error") {
918 return Err(anyhow::anyhow!("MCP server error: {}", error));
919 }
920
921 return response_body
922 .get("result")
923 .cloned()
924 .ok_or_else(|| anyhow::anyhow!("No result in MCP response"));
925 }
926 }
927 // If it's not our response, continue reading
928 }
929 Message::Binary(_) => {
930 // Handle binary messages if needed, for now skip
931 continue;
932 }
933 Message::Close(_) => {
934 return Err(anyhow::anyhow!("WebSocket connection closed by server"));
935 }
936 Message::Ping(_) | Message::Pong(_) => {
937 // Handle ping/pong frames, continue reading
938 continue;
939 }
940 Message::Frame(_) => {
941 // Raw frames, continue reading
942 continue;
943 }
944 }
945 }
946 }
947
948 /// Sends a WebSocket ping frame to test connection health
949 ///
950 /// This method sends a ping frame to the server and expects a pong response,
951 /// which helps verify that the WebSocket connection is still active and responsive.
952 ///
953 /// # Returns
954 ///
955 /// Ok(()) if the ping was sent successfully
956 ///
957 /// # Errors
958 ///
959 /// - WebSocket send errors
960 /// - Connection closure
961 async fn ping(&self) -> Result<()> {
962 let ping_message = Message::Ping(vec![]);
963 let mut write = self.write.lock().await;
964 write
965 .send(ping_message)
966 .await
967 .map_err(|e| anyhow::anyhow!("Failed to send ping: {}", e))?;
968 Ok(())
969 }
970
971 /// Gracefully closes the WebSocket connection
972 ///
973 /// Sends a close frame to the server to properly terminate the connection
974 /// according to the WebSocket protocol. The server should respond with its
975 /// own close frame to complete the closing handshake.
976 ///
977 /// # Returns
978 ///
979 /// Ok(()) if the close frame was sent successfully
980 ///
981 /// # Errors
982 ///
983 /// - WebSocket send errors
984 /// - Connection already closed
985 async fn close(&self) -> Result<()> {
986 let close_message = Message::Close(None);
987 let mut write = self.write.lock().await;
988 write
989 .send(close_message)
990 .await
991 .map_err(|e| anyhow::anyhow!("Failed to send close message: {}", e))?;
992 Ok(())
993 }
994}