mistralrs_server_core/
streaming.rs

1//! SSE streaming utilities.
2
3use std::env;
4
5use mistralrs_core::Response;
6use tokio::sync::mpsc::Receiver;
7
8use crate::types::SharedMistralRsState;
9
/// Fallback interval between SSE keep-alive messages, expressed in milliseconds
/// (ten seconds).
pub const DEFAULT_KEEP_ALIVE_INTERVAL_MS: u64 = 10 * 1_000;
12
/// Represents the current state of a streaming response.
///
/// The variants are field-less, so the type derives `Copy`/`PartialEq` to let
/// callers compare and copy states cheaply, and `Debug` so the state can be
/// logged and asserted on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DoneState {
    /// The stream is actively processing and sending response chunks.
    Running,
    /// The stream has finished processing and is about to send the `[DONE]` message.
    SendingDone,
    /// The stream has completed entirely.
    Done,
}
22
23/// A streaming response handler.
24///
25/// It processes incoming response chunks from a model and converts them
26/// into Server-Sent Events (SSE) format for real-time streaming to clients.
27pub struct BaseStreamer<R, C, D> {
28    /// Channel receiver for incoming model responses
29    pub rx: Receiver<Response>,
30    /// Current state of the streaming operation
31    pub done_state: DoneState,
32    /// Underlying mistral.rs instance
33    pub state: SharedMistralRsState,
34    /// Whether to store chunks for the completion callback
35    pub store_chunks: bool,
36    /// All chunks received during streaming (if `store_chunks` is true)
37    pub chunks: Vec<R>,
38    /// Optional callback to process each chunk before sending
39    pub on_chunk: Option<C>,
40    /// Optional callback to execute when streaming completes
41    pub on_done: Option<D>,
42}
43
44/// Generic function to create a SSE streamer with optional callbacks.
45pub(crate) fn base_create_streamer<R, C, D>(
46    rx: Receiver<Response>,
47    state: SharedMistralRsState,
48    on_chunk: Option<C>,
49    on_done: Option<D>,
50) -> BaseStreamer<R, C, D> {
51    let store_chunks = on_done.is_some();
52
53    BaseStreamer {
54        rx,
55        done_state: DoneState::Running,
56        store_chunks,
57        state,
58        chunks: Vec::new(),
59        on_chunk,
60        on_done,
61    }
62}
63
64/// Gets the keep-alive interval for SSE streams from environment or default.
65pub fn get_keep_alive_interval() -> u64 {
66    env::var("KEEP_ALIVE_INTERVAL")
67        .map(|val| {
68            val.parse::<u64>().unwrap_or_else(|e| {
69                tracing::warn!("Failed to parse KEEP_ALIVE_INTERVAL: {}. Using default.", e);
70                DEFAULT_KEEP_ALIVE_INTERVAL_MS
71            })
72        })
73        .unwrap_or(DEFAULT_KEEP_ALIVE_INTERVAL_MS)
74}