mistralrs_server_core/streaming.rs
1//! SSE streaming utilities.
2
3use std::env;
4
5use mistralrs_core::Response;
6use tokio::sync::mpsc::Receiver;
7
8use crate::types::SharedMistralRsState;
9
/// Fallback keep-alive interval, in milliseconds, for Server-Sent Events (SSE) streams.
///
/// Equal to ten seconds.
pub const DEFAULT_KEEP_ALIVE_INTERVAL_MS: u64 = 10 * 1_000;
12
/// Lifecycle phase of a streaming response.
///
/// The variants carry no data, so the type is cheap to copy and compare; the
/// derived `Debug` impl makes the current phase easy to include in logs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DoneState {
    /// The stream is actively processing and sending response chunks.
    Running,
    /// The stream has finished processing and is about to send the `[DONE]` message.
    SendingDone,
    /// The stream has completed entirely.
    Done,
}
22
23/// A streaming response handler.
24///
25/// It processes incoming response chunks from a model and converts them
26/// into Server-Sent Events (SSE) format for real-time streaming to clients.
27pub struct BaseStreamer<R, C, D> {
28 /// Channel receiver for incoming model responses
29 pub rx: Receiver<Response>,
30 /// Current state of the streaming operation
31 pub done_state: DoneState,
32 /// Underlying mistral.rs instance
33 pub state: SharedMistralRsState,
34 /// Whether to store chunks for the completion callback
35 pub store_chunks: bool,
36 /// All chunks received during streaming (if `store_chunks` is true)
37 pub chunks: Vec<R>,
38 /// Optional callback to process each chunk before sending
39 pub on_chunk: Option<C>,
40 /// Optional callback to execute when streaming completes
41 pub on_done: Option<D>,
42}
43
44/// Generic function to create a SSE streamer with optional callbacks.
45pub(crate) fn base_create_streamer<R, C, D>(
46 rx: Receiver<Response>,
47 state: SharedMistralRsState,
48 on_chunk: Option<C>,
49 on_done: Option<D>,
50) -> BaseStreamer<R, C, D> {
51 let store_chunks = on_done.is_some();
52
53 BaseStreamer {
54 rx,
55 done_state: DoneState::Running,
56 store_chunks,
57 state,
58 chunks: Vec::new(),
59 on_chunk,
60 on_done,
61 }
62}
63
64/// Gets the keep-alive interval for SSE streams from environment or default.
65pub fn get_keep_alive_interval() -> u64 {
66 env::var("KEEP_ALIVE_INTERVAL")
67 .map(|val| {
68 val.parse::<u64>().unwrap_or_else(|e| {
69 tracing::warn!("Failed to parse KEEP_ALIVE_INTERVAL: {}. Using default.", e);
70 DEFAULT_KEEP_ALIVE_INTERVAL_MS
71 })
72 })
73 .unwrap_or(DEFAULT_KEEP_ALIVE_INTERVAL_MS)
74}