mistralrs_mcp/transport.rs
1use anyhow::Result;
2use futures_util::stream::{SplitSink, SplitStream};
3use futures_util::{SinkExt, StreamExt};
4use http::{Request, Uri};
5use serde_json::Value;
6use std::collections::HashMap;
7use std::time::Duration;
8use tokio::net::TcpStream;
9use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
10
/// Transport layer for MCP communication
///
/// Abstracts how JSON-RPC 2.0 messages reach an MCP server. The
/// implementations in this file are `HttpTransport`, `ProcessTransport` and
/// `WebSocketTransport`.
#[async_trait::async_trait]
pub trait McpTransport: Send + Sync {
    /// Send a JSON-RPC request and receive a response
    ///
    /// The implementations in this file return the `result` member of the
    /// response and turn an `error` member into an `Err`.
    async fn send_request(&self, method: &str, params: Value) -> Result<Value>;

    /// Check if the transport connection is healthy
    async fn ping(&self) -> Result<()>;

    /// Close the transport connection
    async fn close(&self) -> Result<()>;

    /// Send the initialization notification
    ///
    /// The implementations in this file send a `notifications/initialized`
    /// message, which carries no `id` and therefore expects no response.
    async fn send_initialization_notification(&self) -> Result<()>;
}
26
/// HTTP-based MCP transport
///
/// Provides communication with MCP servers over HTTP using the JSON-RPC 2.0 protocol.
/// Every message is POSTed to a single URL. The server may answer with a regular
/// JSON body or with a Server-Sent Events (`text/event-stream`) body.
///
/// # Features
///
/// - **HTTP/HTTPS**: Requests are issued through a shared `reqwest::Client`
/// - **Server-Sent Events**: `text/event-stream` bodies are read in full and the
///   JSON-RPC message is extracted from their `data:` fields
/// - **Custom Headers**: Every configured header is attached to every request, which
///   is how a caller supplies an `Authorization: Bearer ...` token, an API key, etc.
/// - **Configurable Timeout**: One timeout, chosen at construction, applies to each request
/// - **JSON-RPC Errors**: An `error` member in the response is returned as an `Err`
///
/// # Use Cases
///
/// - **Public MCP APIs**: Connect to hosted MCP services
/// - **RESTful Services**: Integration with REST-based tool providers
/// - **Load-Balanced Servers**: Works well behind HTTP load balancers
/// - **Development/Testing**: Easy debugging with standard HTTP tools
///
/// # Example Usage
///
/// ```rust,no_run
/// use mistralrs_mcp::transport::{HttpTransport, McpTransport};
/// use std::collections::HashMap;
///
/// #[tokio::main]
/// async fn main() -> anyhow::Result<()> {
///     // Create headers with Bearer token and API version
///     let mut headers = HashMap::new();
///     headers.insert("Authorization".to_string(), "Bearer your-api-token".to_string());
///     headers.insert("X-API-Version".to_string(), "v1".to_string());
///
///     // Connect to HTTP MCP server
///     let transport = HttpTransport::new(
///         "https://api.example.com/mcp".to_string(),
///         Some(30), // 30 second timeout
///         Some(headers)
///     )?;
///
///     // Use the transport for MCP communication
///     let result = transport.send_request("tools/list", serde_json::Value::Null).await?;
///     println!("Available tools: {}", result);
///
///     Ok(())
/// }
/// ```
///
/// # Protocol Support
///
/// This transport implements JSON-RPC 2.0 over HTTP with support for:
/// - Standard JSON responses
/// - Server-Sent Events (SSE) response bodies
/// - Caller-supplied HTTP headers (including bearer tokens)
pub struct HttpTransport {
    /// HTTP client built once in `new`; carries the request timeout.
    client: reqwest::Client,
    /// URL every JSON-RPC message is POSTed to.
    base_url: String,
    /// Extra headers attached to every outgoing request.
    headers: HashMap<String, String>,
    /// Source of JSON-RPC request ids; starts at 1 and increases by one per request.
    request_id: std::sync::Arc<std::sync::atomic::AtomicU64>,
}
90
91impl HttpTransport {
92 /// Creates a new HTTP transport for MCP communication
93 ///
94 /// # Arguments
95 ///
96 /// * `base_url` - Base URL of the MCP server (http:// or https://)
97 /// * `timeout_secs` - Optional timeout for HTTP requests in seconds (defaults to 30s)
98 /// * `headers` - Optional custom headers to include in all requests
99 ///
100 /// # Returns
101 ///
102 /// A configured HttpTransport ready for MCP communication
103 ///
104 /// # Errors
105 ///
106 /// - Invalid URL format
107 /// - HTTP client configuration errors
108 /// - TLS/SSL setup failures
109 ///
110 /// # Example
111 ///
112 /// ```rust,no_run
113 /// use mistralrs_mcp::transport::HttpTransport;
114 /// use std::collections::HashMap;
115 ///
116 /// // Basic HTTP transport
117 /// let transport = HttpTransport::new(
118 /// "https://api.example.com/mcp".to_string(),
119 /// Some(60), // 1 minute timeout
120 /// None
121 /// )?;
122 ///
123 /// // With custom headers and authentication
124 /// let mut headers = HashMap::new();
125 /// headers.insert("Authorization".to_string(), "Bearer token123".to_string());
126 /// headers.insert("X-Client-Version".to_string(), "1.0.0".to_string());
127 ///
128 /// let transport = HttpTransport::new(
129 /// "https://secure-api.example.com/mcp".to_string(),
130 /// Some(30),
131 /// Some(headers)
132 /// )?;
133 /// # Ok::<(), anyhow::Error>(())
134 /// ```
135 pub fn new(
136 base_url: String,
137 timeout_secs: Option<u64>,
138 headers: Option<HashMap<String, String>>,
139 ) -> Result<Self> {
140 let timeout = timeout_secs
141 .map(Duration::from_secs)
142 .unwrap_or(Duration::from_secs(30));
143 let client = reqwest::Client::builder().timeout(timeout).build()?;
144
145 Ok(Self {
146 client,
147 base_url,
148 headers: headers.unwrap_or_default(),
149 request_id: std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1)),
150 })
151 }
152
153 /// Parse Server-Sent Events response to extract JSON-RPC message
154 ///
155 /// Handles SSE format used by some MCP servers for streaming responses.
156 /// SSE format: `data: <json>\n\n` or `event: <type>\ndata: <json>\n\n`
157 ///
158 /// # Arguments
159 ///
160 /// * `sse_text` - Raw SSE response text from the server
161 ///
162 /// # Returns
163 ///
164 /// Parsed JSON value from the SSE data field
165 ///
166 /// # Errors
167 ///
168 /// - No valid JSON data found in SSE response
169 /// - Malformed SSE format
170 /// - JSON parsing errors
171 fn parse_sse_response(sse_text: &str) -> Result<Value> {
172 // SSE format: data: <json>\n\n or event: <type>\ndata: <json>\n\n
173 let mut json_data = None;
174
175 for line in sse_text.lines() {
176 let line = line.trim();
177
178 // Skip empty lines and comments
179 if line.is_empty() || line.starts_with(':') {
180 continue;
181 }
182
183 // Parse SSE field
184 if let Some((field, value)) = line.split_once(':') {
185 let field = field.trim();
186 let value = value.trim();
187
188 match field {
189 "data" => {
190 // Try to parse the JSON data
191 if let Ok(parsed) = serde_json::from_str::<Value>(value) {
192 json_data = Some(parsed);
193 break;
194 }
195 }
196 "event" => {
197 // Handle different event types if needed
198 continue;
199 }
200 _ => {
201 // Ignore other SSE fields like id, retry, etc.
202 continue;
203 }
204 }
205 }
206 }
207
208 json_data.ok_or_else(|| anyhow::anyhow!("No valid JSON data found in SSE response"))
209 }
210}
211
212#[async_trait::async_trait]
213impl McpTransport for HttpTransport {
214 /// Sends an MCP request over HTTP and returns the response
215 ///
216 /// This method implements JSON-RPC 2.0 over HTTP with support for both
217 /// standard JSON responses and Server-Sent Events (SSE). It handles
218 /// authentication, custom headers, and comprehensive error reporting.
219 ///
220 /// # Arguments
221 ///
222 /// * `method` - The MCP method name (e.g., "tools/list", "tools/call", "resources/read")
223 /// * `params` - JSON parameters for the method call
224 ///
225 /// # Returns
226 ///
227 /// The result portion of the JSON-RPC response
228 ///
229 /// # Errors
230 ///
231 /// - HTTP connection errors (network issues, DNS resolution)
232 /// - HTTP status errors (4xx, 5xx responses)
233 /// - JSON serialization/deserialization errors
234 /// - MCP server errors (returned in JSON-RPC error field)
235 /// - SSE parsing errors for streaming responses
236 ///
237 /// # Example
238 ///
239 /// ```rust,no_run
240 /// use mistralrs_mcp::transport::{HttpTransport, McpTransport};
241 /// use serde_json::json;
242 ///
243 /// #[tokio::main]
244 /// async fn main() -> anyhow::Result<()> {
245 /// let transport = HttpTransport::new(
246 /// "https://api.example.com/mcp".to_string(),
247 /// None,
248 /// None
249 /// )?;
250 ///
251 /// // List available tools
252 /// let tools = transport.send_request("tools/list", serde_json::Value::Null).await?;
253 ///
254 /// // Call a specific tool
255 /// let params = json!({
256 /// "name": "search",
257 /// "arguments": {"query": "example search"}
258 /// });
259 /// let result = transport.send_request("tools/call", params).await?;
260 ///
261 /// Ok(())
262 /// }
263 /// ```
264 async fn send_request(&self, method: &str, params: Value) -> Result<Value> {
265 // Ensure params is an object, not null
266 let params = if params.is_null() {
267 serde_json::json!({})
268 } else {
269 params
270 };
271
272 let id = self
273 .request_id
274 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
275 let request_body = serde_json::json!({
276 "jsonrpc": "2.0",
277 "id": id,
278 "method": method,
279 "params": params
280 });
281
282 let mut request_builder = self
283 .client
284 .post(&self.base_url)
285 .json(&request_body)
286 .header("Accept", "application/json, text/event-stream");
287
288 // Add custom headers
289 for (key, value) in &self.headers {
290 request_builder = request_builder.header(key, value);
291 }
292
293 let response = request_builder.send().await?;
294
295 // Check content type and handle accordingly
296 let content_type = response
297 .headers()
298 .get("content-type")
299 .and_then(|v| v.to_str().ok())
300 .unwrap_or("");
301
302 let response_body: Value = if content_type.contains("text/event-stream") {
303 // Handle Server-Sent Events
304 let response_text = response.text().await?;
305 Self::parse_sse_response(&response_text)?
306 } else {
307 // Handle regular JSON response
308 response.json().await?
309 };
310
311 // Check for JSON-RPC errors
312 if let Some(error) = response_body.get("error") {
313 return Err(anyhow::anyhow!("MCP server error: {}", error));
314 }
315
316 response_body
317 .get("result")
318 .cloned()
319 .ok_or_else(|| anyhow::anyhow!("No result in MCP response"))
320 }
321
322 /// Tests the HTTP connection by sending a ping request
323 ///
324 /// Sends a "ping" method call to verify that the MCP server is responsive
325 /// and the HTTP connection is working properly.
326 ///
327 /// # Returns
328 ///
329 /// Ok(()) if the ping was successful
330 ///
331 /// # Errors
332 ///
333 /// - HTTP connection errors
334 /// - Server unavailable or unresponsive
335 /// - Authentication failures
336 async fn ping(&self) -> Result<()> {
337 self.send_request("ping", Value::Null).await?;
338 Ok(())
339 }
340
341 /// Closes the HTTP transport connection
342 ///
343 /// HTTP connections are stateless and managed by the underlying HTTP client,
344 /// so this method is a no-op but provided for interface compatibility.
345 ///
346 /// # Returns
347 ///
348 /// Always returns Ok(()) as HTTP connections don't require explicit cleanup
349 async fn close(&self) -> Result<()> {
350 // HTTP connections don't need explicit closing
351 Ok(())
352 }
353
354 /// Sends the server a initialization notification to let it know we are done initializing
355 async fn send_initialization_notification(&self) -> Result<()> {
356 let request_body = serde_json::json!({
357 "jsonrpc": "2.0",
358 "method": "notifications/initialized",
359 });
360
361 let mut request_builder = self
362 .client
363 .post(&self.base_url)
364 .json(&request_body)
365 .header("Accept", "application/json, text/event-stream");
366
367 // Add custom headers
368 for (key, value) in &self.headers {
369 request_builder = request_builder.header(key, value);
370 }
371
372 request_builder.send().await?;
373 Ok(())
374 }
375}
376
/// Process-based MCP transport using stdin/stdout communication
///
/// Provides communication with local MCP servers running as separate processes
/// using JSON-RPC 2.0 over stdin/stdout pipes. This transport is suited to
/// local tools and development servers where the MCP server should run as a
/// child process of the client.
///
/// # Features
///
/// - **Separate Process**: Each MCP server runs as its own child process
/// - **No Network**: Messages travel over the child's stdin/stdout pipes
/// - **Environment Control**: Working directory and extra environment variables are configurable
/// - **Explicit Shutdown**: `close` kills the child process
/// - **Request/Response Exchange**: One line is written to stdin per request and the
///   response is read from stdout
/// - **Error Handling**: I/O failures, JSON errors and JSON-RPC `error` members are
///   returned as errors
///
/// # Use Cases
///
/// - **Local Development**: Running MCP servers during development and testing
/// - **Filesystem Tools**: Local file operations and system utilities
/// - **Custom Tools**: Private or proprietary MCP servers
/// - **CI/CD Integration**: Running MCP servers in automated environments
///
/// # Example Usage
///
/// ```rust,no_run
/// use mistralrs_mcp::transport::{ProcessTransport, McpTransport};
/// use std::collections::HashMap;
///
/// #[tokio::main]
/// async fn main() -> anyhow::Result<()> {
///     // Basic process transport
///     let transport = ProcessTransport::new(
///         "mcp-server-filesystem".to_string(),
///         vec!["--root".to_string(), "/tmp".to_string()],
///         None,
///         None
///     ).await?;
///
///     // With custom working directory and environment
///     let mut env = HashMap::new();
///     env.insert("MCP_LOG_LEVEL".to_string(), "debug".to_string());
///     env.insert("MCP_TIMEOUT".to_string(), "30".to_string());
///
///     let transport = ProcessTransport::new(
///         "/usr/local/bin/my-mcp-server".to_string(),
///         vec!["--config".to_string(), "production.json".to_string()],
///         Some("/opt/mcp-server".to_string()), // Working directory
///         Some(env)                            // Environment variables
///     ).await?;
///
///     // Use the transport for MCP communication
///     let result = transport.send_request("tools/list", serde_json::Value::Null).await?;
///     println!("Available tools: {}", result);
///
///     Ok(())
/// }
/// ```
///
/// # Process Management
///
/// - The process is spawned with the configured arguments and environment
/// - stdin/stdout pipes are set up for JSON-RPC communication
/// - There is no background health monitoring; a server that has exited shows
///   up as an error from the next request
/// - Call `close` to terminate the child process explicitly
///
/// # Communication Protocol
///
/// Uses JSON-RPC 2.0 over stdin/stdout with line-delimited messages:
/// - Each request is a single line of JSON sent to stdin
/// - Each response is a single line of JSON read from stdout
/// - The child's stderr output is not read or surfaced by this transport
pub struct ProcessTransport {
    /// Handle of the spawned server process; used by `close` to kill it.
    child: std::sync::Arc<tokio::sync::Mutex<tokio::process::Child>>,
    /// Source of JSON-RPC request ids; starts at 1 and increases by one per request.
    request_id: std::sync::Arc<std::sync::atomic::AtomicU64>,
    /// Write end of the child's stdin pipe.
    stdin: std::sync::Arc<tokio::sync::Mutex<tokio::process::ChildStdin>>,
    /// Buffered read end of the child's stdout pipe.
    stdout_reader:
        std::sync::Arc<tokio::sync::Mutex<tokio::io::BufReader<tokio::process::ChildStdout>>>,
}
458
459impl ProcessTransport {
460 /// Creates a new process transport by spawning an MCP server process
461 ///
462 /// This constructor spawns a new process with the specified command, arguments,
463 /// and environment, then sets up stdin/stdout pipes for JSON-RPC communication.
464 /// The process is ready to receive MCP requests immediately after creation.
465 ///
466 /// # Arguments
467 ///
468 /// * `command` - The command to execute (e.g., "mcp-server-filesystem", "/usr/bin/python")
469 /// * `args` - Command-line arguments to pass to the process
470 /// * `work_dir` - Optional working directory for the process (defaults to current directory)
471 /// * `env` - Optional environment variables to set for the process
472 ///
473 /// # Returns
474 ///
475 /// A configured ProcessTransport with the spawned process ready for communication
476 ///
477 /// # Errors
478 ///
479 /// - Command not found or not executable
480 /// - Permission denied errors
481 /// - Process spawn failures
482 /// - Pipe setup errors (stdin/stdout/stderr)
483 /// - Working directory access errors
484 ///
485 /// # Example
486 ///
487 /// ```rust,no_run
488 /// use mistralrs_mcp::transport::ProcessTransport;
489 /// use std::collections::HashMap;
490 ///
491 /// #[tokio::main]
492 /// async fn main() -> anyhow::Result<()> {
493 /// // Simple filesystem server
494 /// let transport = ProcessTransport::new(
495 /// "mcp-server-filesystem".to_string(),
496 /// vec!["--root".to_string(), "/home/user/documents".to_string()],
497 /// None,
498 /// None
499 /// ).await?;
500 ///
501 /// // Python-based MCP server with custom environment
502 /// let mut env = HashMap::new();
503 /// env.insert("PYTHONPATH".to_string(), "/opt/mcp-servers".to_string());
504 /// env.insert("MCP_DEBUG".to_string(), "1".to_string());
505 ///
506 /// let transport = ProcessTransport::new(
507 /// "python".to_string(),
508 /// vec!["-m".to_string(), "my_mcp_server".to_string(), "--port".to_string(), "8080".to_string()],
509 /// Some("/opt/mcp-servers".to_string()),
510 /// Some(env)
511 /// ).await?;
512 ///
513 /// Ok(())
514 /// }
515 /// ```
516 pub async fn new(
517 command: String,
518 args: Vec<String>,
519 work_dir: Option<String>,
520 env: Option<HashMap<String, String>>,
521 ) -> Result<Self> {
522 use tokio::io::BufReader;
523 use tokio::process::Command;
524
525 let mut cmd = Command::new(command);
526 cmd.args(args)
527 .stdin(std::process::Stdio::piped())
528 .stdout(std::process::Stdio::piped())
529 .stderr(std::process::Stdio::piped());
530
531 if let Some(dir) = work_dir {
532 cmd.current_dir(dir);
533 }
534
535 if let Some(env_vars) = env {
536 for (key, value) in env_vars {
537 cmd.env(key, value);
538 }
539 }
540
541 let mut child = cmd.spawn()?;
542 let stdin = child
543 .stdin
544 .take()
545 .ok_or_else(|| anyhow::anyhow!("Failed to get stdin handle"))?;
546 let stdout = child
547 .stdout
548 .take()
549 .ok_or_else(|| anyhow::anyhow!("Failed to get stdout handle"))?;
550 let stdout_reader = BufReader::new(stdout);
551
552 Ok(Self {
553 child: std::sync::Arc::new(tokio::sync::Mutex::new(child)),
554 request_id: std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1)),
555 stdin: std::sync::Arc::new(tokio::sync::Mutex::new(stdin)),
556 stdout_reader: std::sync::Arc::new(tokio::sync::Mutex::new(stdout_reader)),
557 })
558 }
559}
560
561#[async_trait::async_trait]
562impl McpTransport for ProcessTransport {
563 /// Sends an MCP request to the child process and returns the response
564 ///
565 /// This method implements JSON-RPC 2.0 over stdin/stdout pipes. It sends
566 /// a line-delimited JSON request to the process stdin and reads the
567 /// corresponding response from stdout. Communication is synchronous with
568 /// proper request/response correlation.
569 ///
570 /// # Arguments
571 ///
572 /// * `method` - The MCP method name (e.g., "tools/list", "tools/call", "resources/read")
573 /// * `params` - JSON parameters for the method call
574 ///
575 /// # Returns
576 ///
577 /// The result portion of the JSON-RPC response
578 ///
579 /// # Errors
580 ///
581 /// - Process communication errors (broken pipes)
582 /// - Process crashes or unexpected termination
583 /// - JSON serialization/deserialization errors
584 /// - MCP server errors (returned in JSON-RPC error field)
585 /// - I/O errors on stdin/stdout
586 ///
587 /// # Example
588 ///
589 /// ```rust,no_run
590 /// use mistralrs_mcp::transport::{ProcessTransport, McpTransport};
591 /// use serde_json::json;
592 ///
593 /// #[tokio::main]
594 /// async fn main() -> anyhow::Result<()> {
595 /// let transport = ProcessTransport::new(
596 /// "mcp-server-filesystem".to_string(),
597 /// vec!["--root".to_string(), "/tmp".to_string()],
598 /// None,
599 /// None
600 /// ).await?;
601 ///
602 /// // List available tools
603 /// let tools = transport.send_request("tools/list", serde_json::Value::Null).await?;
604 ///
605 /// // Call a specific tool
606 /// let params = json!({
607 /// "name": "read_file",
608 /// "arguments": {"path": "/tmp/example.txt"}
609 /// });
610 /// let result = transport.send_request("tools/call", params).await?;
611 ///
612 /// Ok(())
613 /// }
614 /// ```
615 async fn send_request(&self, method: &str, params: Value) -> Result<Value> {
616 use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
617
618 // Ensure params is an object, not null
619 let params = if params.is_null() {
620 serde_json::json!({})
621 } else {
622 params
623 };
624 // Generate unique request ID
625 let id = self
626 .request_id
627 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
628 let request_body = serde_json::json!({
629 "jsonrpc": "2.0",
630 "id": id,
631 "method": method,
632 "params": params
633 });
634
635 // Send request via stdin
636 let request_line = serde_json::to_string(&request_body)? + "\n";
637 let mut stdin = self.stdin.lock().await;
638 stdin.write_all(request_line.as_bytes()).await?;
639 stdin.flush().await?;
640 drop(stdin);
641
642 // Read response from stdout
643 let mut stdout_reader = self.stdout_reader.lock().await;
644 let mut response_line = String::new();
645 stdout_reader.read_line(&mut response_line).await?;
646 drop(stdout_reader);
647
648 let response_body: Value = serde_json::from_str(response_line.trim())?;
649
650 // Check for JSON-RPC errors
651 if let Some(error) = response_body.get("error") {
652 return Err(anyhow::anyhow!("MCP server error: {}", error));
653 }
654
655 response_body
656 .get("result")
657 .cloned()
658 .ok_or_else(|| anyhow::anyhow!("No result in MCP response"))
659 }
660
661 /// Tests the process connection by sending a ping request
662 ///
663 /// Sends a "ping" method call to verify that the MCP server process is
664 /// responsive and the stdin/stdout communication is working properly.
665 ///
666 /// # Returns
667 ///
668 /// Ok(()) if the ping was successful
669 ///
670 /// # Errors
671 ///
672 /// - Process communication errors
673 /// - Process crashed or terminated
674 /// - Broken stdin/stdout pipes
675 async fn ping(&self) -> Result<()> {
676 self.send_request("ping", Value::Null).await?;
677 Ok(())
678 }
679
680 /// Terminates the child process and cleans up resources
681 ///
682 /// This method forcefully terminates the MCP server process and closes
683 /// all associated pipes. Any pending requests will fail after this call.
684 /// The transport cannot be used after closing.
685 ///
686 /// # Returns
687 ///
688 /// Ok(()) if the process was terminated successfully
689 ///
690 /// # Errors
691 ///
692 /// - Process termination errors
693 /// - Resource cleanup failures
694 ///
695 /// # Note
696 ///
697 /// This method sends SIGKILL to the process, which may not allow for
698 /// graceful cleanup. Consider implementing graceful shutdown through
699 /// MCP protocol methods before calling this method.
700 async fn close(&self) -> Result<()> {
701 let mut child = self.child.lock().await;
702 child.kill().await?;
703 Ok(())
704 }
705
706 /// Sends the server a initialization notification to let it know we are done initializing
707 async fn send_initialization_notification(&self) -> Result<()> {
708 use tokio::io::AsyncWriteExt;
709 let mut stdin = self.stdin.lock().await;
710 stdin
711 .write_all(
712 format!(
713 "{}\n",
714 serde_json::json!({"jsonrpc": "2.0", "method": "notifications/initialized"})
715 )
716 .as_bytes(),
717 )
718 .await?;
719 stdin.flush().await?;
720 drop(stdin);
721 Ok(())
722 }
723}
724
/// WebSocket-based MCP transport
///
/// Provides bidirectional communication with MCP servers over WebSocket connections.
/// This transport supports `ws://` and `wss://` URLs, caller-supplied handshake headers
/// (e.g. a Bearer token), and matches responses to requests by JSON-RPC id.
///
/// # Features
///
/// - **Async WebSocket Communication**: Built on tokio-tungstenite
/// - **Request/Response Matching**: A request waits for the response carrying its own id
/// - **Handshake Headers**: Custom headers such as `Authorization` are sent during the handshake
/// - **Connection Management**: `ping` sends a ping frame, `close` sends a close frame
/// - **Split Stream**: The connection is split into a write half and a read half
///
/// # Architecture
///
/// The WebSocket connection is divided into separate read and write halves, each
/// protected by its own async mutex, so sending (including `ping` and `close`) does
/// not have to wait for a pending read. Request ids come from an atomic counter.
///
/// There is no background reader: a request reads frames itself until its response
/// arrives, and text messages that are not that response are discarded.
///
/// # Example Usage
///
/// ```rust,no_run
/// use mistralrs_mcp::transport::{WebSocketTransport, McpTransport};
/// use std::collections::HashMap;
///
/// #[tokio::main]
/// async fn main() -> anyhow::Result<()> {
///     // Create headers with Bearer token
///     let mut headers = HashMap::new();
///     headers.insert("Authorization".to_string(), "Bearer your-token".to_string());
///
///     // Connect to WebSocket MCP server
///     let transport = WebSocketTransport::new(
///         "wss://api.example.com/mcp".to_string(),
///         Some(30), // 30 second timeout
///         Some(headers)
///     ).await?;
///
///     // Use the transport for MCP communication
///     let result = transport.send_request("tools/list", serde_json::Value::Null).await?;
///     println!("Available tools: {}", result);
///
///     Ok(())
/// }
/// ```
///
/// # Protocol
///
/// Messages are JSON-RPC 2.0 objects sent as WebSocket text frames.
pub struct WebSocketTransport {
    /// Write half of the connection.
    write: std::sync::Arc<
        tokio::sync::Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
    >,
    /// Read half of the connection.
    read:
        std::sync::Arc<tokio::sync::Mutex<SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>>,
    /// Source of JSON-RPC request ids; starts at 1 and increases by one per request.
    request_id: std::sync::Arc<std::sync::atomic::AtomicU64>,
}
785
786impl WebSocketTransport {
787 /// Creates a new WebSocket transport connection to an MCP server
788 ///
789 /// # Arguments
790 ///
791 /// * `url` - WebSocket URL (ws:// or wss://)
792 /// * `_timeout_secs` - Connection timeout (currently unused, reserved for future use)
793 /// * `headers` - Optional HTTP headers for WebSocket handshake (e.g., Bearer tokens)
794 ///
795 /// # Returns
796 ///
797 /// A configured WebSocketTransport ready for MCP communication
798 ///
799 /// # Errors
800 ///
801 /// - Invalid URL format
802 /// - WebSocket connection failure
803 /// - Header parsing errors
804 /// - Network connectivity issues
805 ///
806 /// # Example
807 ///
808 /// ```rust,no_run
809 /// use mistralrs_mcp::transport::WebSocketTransport;
810 /// use std::collections::HashMap;
811 ///
812 /// #[tokio::main]
813 /// async fn main() -> anyhow::Result<()> {
814 /// let mut headers = HashMap::new();
815 /// headers.insert("Authorization".to_string(), "Bearer token123".to_string());
816 ///
817 /// let transport = WebSocketTransport::new(
818 /// "wss://mcp.example.com/api".to_string(),
819 /// Some(30),
820 /// Some(headers)
821 /// ).await?;
822 ///
823 /// Ok(())
824 /// }
825 /// ```
826 pub async fn new(
827 url: String,
828 _timeout_secs: Option<u64>,
829 headers: Option<HashMap<String, String>>,
830 ) -> Result<Self> {
831 // Create request with headers
832 let uri: Uri = url
833 .parse()
834 .map_err(|e| anyhow::anyhow!("Invalid WebSocket URL: {}", e))?;
835 let mut request = Request::builder()
836 .uri(uri)
837 .body(())
838 .map_err(|e| anyhow::anyhow!("Failed to create WebSocket request: {}", e))?;
839
840 // Add headers if provided
841 if let Some(headers) = headers {
842 let req_headers = request.headers_mut();
843 for (key, value) in headers {
844 let header_name = key
845 .parse::<http::header::HeaderName>()
846 .map_err(|e| anyhow::anyhow!("Invalid header key: {}", e))?;
847 let header_value = value
848 .parse::<http::header::HeaderValue>()
849 .map_err(|e| anyhow::anyhow!("Invalid header value: {}", e))?;
850 req_headers.insert(header_name, header_value);
851 }
852 }
853
854 // Connect to WebSocket
855 let (ws_stream, _) = connect_async(request)
856 .await
857 .map_err(|e| anyhow::anyhow!("WebSocket connection failed: {}", e))?;
858
859 // Split the stream
860 let (write, read) = ws_stream.split();
861
862 Ok(Self {
863 write: std::sync::Arc::new(tokio::sync::Mutex::new(write)),
864 read: std::sync::Arc::new(tokio::sync::Mutex::new(read)),
865 request_id: std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1)),
866 })
867 }
868}
869
870#[async_trait::async_trait]
871impl McpTransport for WebSocketTransport {
872 /// Sends an MCP request over WebSocket and waits for the corresponding response
873 ///
874 /// This method implements the JSON-RPC 2.0 protocol over WebSocket, handling:
875 /// - Unique request ID generation for response correlation
876 /// - Concurrent request processing with proper message ordering
877 /// - Error handling for both transport and protocol errors
878 /// - Message filtering to match responses with requests
879 ///
880 /// # Arguments
881 ///
882 /// * `method` - The MCP method name (e.g., "tools/list", "tools/call")
883 /// * `params` - JSON parameters for the method call
884 ///
885 /// # Returns
886 ///
887 /// The result portion of the JSON-RPC response
888 ///
889 /// # Errors
890 ///
891 /// - WebSocket connection errors
892 /// - JSON serialization/deserialization errors
893 /// - MCP server errors (returned in JSON-RPC error field)
894 /// - Timeout or connection closure
895 ///
896 /// # Example
897 ///
898 /// ```rust,no_run
899 /// use mistralrs_mcp::transport::{WebSocketTransport, McpTransport};
900 /// use serde_json::json;
901 ///
902 /// #[tokio::main]
903 /// async fn main() -> anyhow::Result<()> {
904 /// let transport = WebSocketTransport::new(
905 /// "wss://api.example.com/mcp".to_string(),
906 /// None,
907 /// None
908 /// ).await?;
909 ///
910 /// // List available tools
911 /// let tools = transport.send_request("tools/list", serde_json::Value::Null).await?;
912 ///
913 /// // Call a specific tool
914 /// let params = json!({"query": "example search"});
915 /// let result = transport.send_request("tools/call", params).await?;
916 ///
917 /// Ok(())
918 /// }
919 /// ```
920 async fn send_request(&self, method: &str, params: Value) -> Result<Value> {
921 // Ensure params is an object, not null
922 let params = if params.is_null() {
923 serde_json::json!({})
924 } else {
925 params
926 };
927
928 // Generate unique request ID
929 let id = self
930 .request_id
931 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
932
933 let request_body = serde_json::json!({
934 "jsonrpc": "2.0",
935 "id": id,
936 "method": method,
937 "params": params
938 });
939
940 // Send request
941 let message = Message::Text(serde_json::to_string(&request_body)?);
942
943 {
944 let mut write = self.write.lock().await;
945 write
946 .send(message)
947 .await
948 .map_err(|e| anyhow::anyhow!("Failed to send WebSocket message: {}", e))?;
949 }
950
951 // Read response
952 loop {
953 let mut read = self.read.lock().await;
954 let msg = read
955 .next()
956 .await
957 .ok_or_else(|| anyhow::anyhow!("WebSocket connection closed"))?
958 .map_err(|e| anyhow::anyhow!("WebSocket read error: {}", e))?;
959 drop(read);
960
961 match msg {
962 Message::Text(text) => {
963 let response_body: Value = serde_json::from_str(&text)?;
964
965 // Check if this is the response to our request
966 if let Some(response_id) = response_body.get("id").and_then(|v| v.as_u64()) {
967 if response_id == id {
968 // Check for JSON-RPC errors
969 if let Some(error) = response_body.get("error") {
970 return Err(anyhow::anyhow!("MCP server error: {}", error));
971 }
972
973 return response_body
974 .get("result")
975 .cloned()
976 .ok_or_else(|| anyhow::anyhow!("No result in MCP response"));
977 }
978 }
979 // If it's not our response, continue reading
980 }
981 Message::Binary(_) => {
982 // Handle binary messages if needed, for now skip
983 continue;
984 }
985 Message::Close(_) => {
986 return Err(anyhow::anyhow!("WebSocket connection closed by server"));
987 }
988 Message::Ping(_) | Message::Pong(_) => {
989 // Handle ping/pong frames, continue reading
990 continue;
991 }
992 Message::Frame(_) => {
993 // Raw frames, continue reading
994 continue;
995 }
996 }
997 }
998 }
999
1000 /// Sends a WebSocket ping frame to test connection health
1001 ///
1002 /// This method sends a ping frame to the server and expects a pong response,
1003 /// which helps verify that the WebSocket connection is still active and responsive.
1004 ///
1005 /// # Returns
1006 ///
1007 /// Ok(()) if the ping was sent successfully
1008 ///
1009 /// # Errors
1010 ///
1011 /// - WebSocket send errors
1012 /// - Connection closure
1013 async fn ping(&self) -> Result<()> {
1014 let ping_message = Message::Ping(vec![]);
1015 let mut write = self.write.lock().await;
1016 write
1017 .send(ping_message)
1018 .await
1019 .map_err(|e| anyhow::anyhow!("Failed to send ping: {}", e))?;
1020 Ok(())
1021 }
1022
1023 /// Gracefully closes the WebSocket connection
1024 ///
1025 /// Sends a close frame to the server to properly terminate the connection
1026 /// according to the WebSocket protocol. The server should respond with its
1027 /// own close frame to complete the closing handshake.
1028 ///
1029 /// # Returns
1030 ///
1031 /// Ok(()) if the close frame was sent successfully
1032 ///
1033 /// # Errors
1034 ///
1035 /// - WebSocket send errors
1036 /// - Connection already closed
1037 async fn close(&self) -> Result<()> {
1038 let close_message = Message::Close(None);
1039 let mut write = self.write.lock().await;
1040 write
1041 .send(close_message)
1042 .await
1043 .map_err(|e| anyhow::anyhow!("Failed to send close message: {}", e))?;
1044 Ok(())
1045 }
1046
1047 /// Sends the server a initialization notification to let it know we are done initializing
1048 async fn send_initialization_notification(&self) -> Result<()> {
1049 let request_body = serde_json::json!({
1050 "jsonrpc": "2.0",
1051 "method": "notifications/initialized",
1052 });
1053
1054 // Send request
1055 let message = Message::Text(serde_json::to_string(&request_body)?);
1056
1057 {
1058 let mut write = self.write.lock().await;
1059 write
1060 .send(message)
1061 .await
1062 .map_err(|e| anyhow::anyhow!("Failed to send WebSocket message: {}", e))?;
1063 }
1064 Ok(())
1065 }
1066}