mistralrs_core/
sequence.rs

1use crate::{
2    get_mut_group,
3    pipeline::{text_models_inputs_processor::PagedAttentionMeta, LayerCaches},
4    response::{ChatCompletionChunkResponse, Choice, ChunkChoice, Response, SYSTEM_FINGERPRINT},
5    sampler::{Logprobs, Sampler},
6    ChatCompletionResponse, Usage,
7};
8use crate::{
9    paged_attention::{BlockEngineSequence, LogicalTokenBlock},
10    pipeline::{DiffusionGenerationParams, KvCache},
11    response::CompletionChoice,
12    tools::ToolCallingMatcher,
13    CompletionChunkChoice, CompletionChunkResponse, CompletionResponse, ImageChoice,
14    ImageGenerationResponse, ImageGenerationResponseFormat,
15};
16use candle_core::Tensor;
17use std::{
18    fmt::Display,
19    sync::{Arc, RwLock},
20    time::{SystemTime, UNIX_EPOCH},
21};
22use tokio::sync::{
23    mpsc::{error::SendError, Sender},
24    Mutex, MutexGuard,
25};
26
/// Why a sequence stopped generating.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum StopReason {
    /// One of the EOS tokens was sampled.
    Eos,
    /// A user-supplied stop token was sampled.
    StopTok(u32),
    /// The requested maximum number of generated tokens was reached.
    Length(usize),
    /// The model length limit passed to `Sequence::is_done` was reached.
    ModelLength(usize),
    /// A stop string was found in the completion bytes.
    StopString {
        /// Index into the sequence's stop strings.
        stop_string_idx: usize,
        /// Byte offset of the match inside the completion bytes.
        completion_bytes_pos: usize,
    },
    /// The request was canceled.
    Canceled,
    /// An image generation step finished.
    GeneratedImage,
}

impl Display for StopReason {
    /// Formats the reason as an API finish-reason string.
    ///
    /// `"stop"`, `"length"`, `"canceled"` or `"generated-image"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let reason = match self {
            StopReason::Eos | StopReason::StopTok(_) | StopReason::StopString { .. } => "stop",
            StopReason::Length(_) | StopReason::ModelLength(_) => "length",
            StopReason::Canceled => "canceled",
            StopReason::GeneratedImage => "generated-image",
        };
        f.write_str(reason)
    }
}

/// Lifecycle state of a [`Sequence`].
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum SequenceState {
    Done(StopReason),
    RunningPrompt,
    RunningCompletion,
    Waiting,
    Error,
    RunningPrefillPrompt,
    // For PagedAttention:
    FinishedAborted,
    FinishedIgnored,
    Swapped,
}
66
/// Optional constrained-decoding matcher attached to a sequence.
pub enum SequenceRecognizer {
    /// An `llguidance` matcher (boxed).
    Llguidance(Box<llguidance::Matcher>),
    /// No constraint.
    None,
}
71
72enum SequenceCustomMetadata {
73    PagedAttention {
74        logical_token_blocks: Vec<LogicalTokenBlock>,
75        block_size: usize,
76    },
77    None,
78}
79
80macro_rules! blocks_to_add_new_tok {
81    ($logical_token_blocks:expr) => {{
82        let last = $logical_token_blocks.last();
83        if !last.is_some_and(|last| last.is_full() || last.is_empty()) {
84            // If we have space
85            0
86        } else {
87            1
88        }
89    }};
90}
91
92impl SequenceCustomMetadata {
93    fn append_token_to_blocks(&mut self, tok: usize) {
94        match self {
95            Self::PagedAttention {
96                logical_token_blocks,
97                block_size,
98            } => {
99                let last = logical_token_blocks.last_mut();
100                match last {
101                    Some(last) => {
102                        last.append_token_id(tok);
103                    }
104                    None => {
105                        logical_token_blocks.push(LogicalTokenBlock::new(*block_size));
106                        logical_token_blocks
107                            .last_mut()
108                            .unwrap()
109                            .append_token_id(tok);
110                    }
111                }
112                if logical_token_blocks.last().as_ref().unwrap().is_full() {
113                    logical_token_blocks.push(LogicalTokenBlock::new(*block_size));
114                }
115            }
116            Self::None => (),
117        }
118    }
119
120    fn pop_token_from_blocks(&mut self) {
121        match self {
122            Self::PagedAttention {
123                logical_token_blocks,
124                block_size: _,
125            } => {
126                let last = logical_token_blocks.last_mut().unwrap();
127                last.pop_token();
128            }
129            Self::None => (),
130        }
131    }
132
133    fn append_tokens_to_blocks(&mut self, toks: Vec<usize>) {
134        for tok in toks {
135            self.append_token_to_blocks(tok);
136        }
137    }
138
139    fn remove_tokens_from_blocks(&mut self, n: usize) {
140        for _ in 0..n {
141            self.pop_token_from_blocks();
142        }
143    }
144}
145
/// How a sequence is stepped by the pipeline.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SeqStepType {
    /// A prompt step followed by decoding steps.
    PromptAndDecode,
    /// A single step.
    OneShot,
}
151
/// A single generation request tracked by the scheduler and pipelines.
///
/// Holds the prompt and generated tokens, per-sequence KV caches, sampling/stop configuration,
/// the response channel, and a shared handle to the [`SequenceGroup`] that aggregates this
/// sequence's output.
pub struct Sequence {
    // Metadata, const
    /// Sequence id; also returned as the block-engine id and used as the streaming response id.
    id: usize,
    /// Number of leading entries of `tokens` that belong to the prompt.
    prompt_len: usize,
    /// Maximum number of generated tokens, if limited.
    max_len: Option<usize>,
    /// Start timestamp. `update_time_info` subtracts it from the current UNIX time in ms, so
    /// presumably ms since the epoch.
    timestamp: u128,
    sampler: Arc<Sampler>,
    /// Tokens that stop generation with `StopReason::StopTok`.
    stop_tokens: Vec<u32>,
    /// Strings searched for in `completion_bytes`; a match yields `StopReason::StopString`.
    stop_strings: Vec<String>,
    return_logprobs: bool,
    /// Channel on which responses for this sequence are sent.
    responder: Sender<Response>,
    response_index: usize,
    creation_time: u64,
    /// The initial prompt text.
    prompt: String,
    sequence_stepping_type: SeqStepType,
    pub(crate) return_raw_logits: bool,
    /// Set by `prefill_v2`.
    token_offset: usize,
    eos_tokens: Vec<u32>,

    // Image generation
    image_gen_response_format: Option<ImageGenerationResponseFormat>,
    diffusion_params: Option<DiffusionGenerationParams>,

    // Completion requests
    /// Appended to the text of completion choices.
    suffix: Option<String>,
    /// Prepended to the text of completion choices.
    prefix: Option<String>,

    // Speculative
    /// Set by `add_tmp_tok`, cleared by `remove_tmp_tok`; while set, `len()` returns `tokens.len()`.
    is_tmp: bool,

    // Prefix caching
    /// When set, `len()` and `get_toks()` use these tokens instead of `tokens`; cleared by `add_token`.
    prefill_prompt_toks: Option<Vec<u32>>,

    // Cache
    normal_cache: Vec<Option<KvCache>>,
    normal_draft_cache: Vec<Option<KvCache>>,
    scaling_cache: Option<Tensor>,
    cache: LayerCaches,
    draft_cache: LayerCaches,
    /// Only `Some` for X-LoRA sequences.
    xlora_cache: Option<LayerCaches>,

    // Preallocated KV cache (k,v)
    seq_preallocated_cache: Option<(Tensor, Tensor)>,

    // Mutables
    /// Prompt tokens followed by generated tokens.
    tokens: Vec<u32>,
    logprobs: Vec<Logprobs>,
    cumulative_logprob: f32,
    last_logprob: f32,
    last_completion_bytes_len: usize,
    last_is_done: Option<StopReason>,
    /// Decoded completion bytes (EOS/stop tokens excluded); used for stop-string search and as the response buffer.
    completion_bytes: Vec<u8>,
    /// Offset into `completion_bytes` up to which output has already been streamed.
    stream_idx: usize,
    pub recognizer: SequenceRecognizer,
    /// The number of scheduling passes in which this sequence was not scheduled.
    scheduling_urgency: usize,
    input_images: Option<Vec<image::DynamicImage>>,
    pub cached_pixel_values: Option<Tensor>,
    pub cached_img_thw: Option<Tensor>,
    pub cached_vid_thw: Option<Tensor>,
    /// When true, `take_images` moves the images out instead of cloning them.
    pub has_changed_prompt: bool,

    // GPU things
    pub prompt_tok_per_sec: f32,
    pub prompt_timestamp: Option<u128>,
    pub total_prompt_time: Option<u128>,
    /// Group shared with sibling sequences of the same request.
    group: Arc<Mutex<SequenceGroup>>,
    state: RwLock<SequenceState>,

    // Custom backend metadata
    custom_metadata: SequenceCustomMetadata,

    // Tool calls
    pub tools: Option<Arc<ToolCallingMatcher>>,
}
226
227impl BlockEngineSequence for Sequence {
228    fn blocks_to_add_new_tok(&self) -> usize {
229        match &self.custom_metadata {
230            SequenceCustomMetadata::PagedAttention {
231                logical_token_blocks,
232                block_size: _,
233            } => {
234                blocks_to_add_new_tok!(logical_token_blocks)
235            }
236            SequenceCustomMetadata::None => unreachable!(),
237        }
238    }
239
240    fn get_id(&self) -> usize {
241        self.id
242    }
243
244    fn get_logical_token_blocks(&self) -> usize {
245        match &self.custom_metadata {
246            SequenceCustomMetadata::PagedAttention {
247                logical_token_blocks,
248                block_size: _,
249            } => logical_token_blocks.len(),
250            SequenceCustomMetadata::None => unreachable!(),
251        }
252    }
253}
254
255impl Sequence {
256    #[allow(clippy::too_many_arguments)]
257    pub fn new_waiting(
258        tokens: Vec<u32>,
259        prompt: String,
260        id: usize,
261        timestamp: u128,
262        layers: usize,
263        responder: Sender<Response>,
264        sampler: Sampler,
265        stop_tokens: Vec<u32>,
266        stop_strings: Vec<String>,
267        max_len: Option<usize>,
268        return_logprobs: bool,
269        is_xlora: bool,
270        group: Arc<Mutex<SequenceGroup>>,
271        response_index: usize,
272        creation_time: u64,
273        recognizer: SequenceRecognizer,
274        suffix: Option<String>,
275        prefix: Option<String>,
276        input_images: Option<Vec<image::DynamicImage>>,
277        // Paged attention
278        block_size: Option<usize>,
279        //
280        tools: Option<Arc<ToolCallingMatcher>>,
281        image_gen_response_format: Option<ImageGenerationResponseFormat>,
282        sequence_stepping_type: SeqStepType,
283        diffusion_params: Option<DiffusionGenerationParams>,
284        // Preallocated KV cache (k,v)
285        seq_preallocated_cache: Option<(Tensor, Tensor)>,
286        //
287        return_raw_logits: bool,
288        eos_tokens: Vec<u32>,
289    ) -> Self {
290        let prompt_len = tokens.len();
291        let mut custom_metadata = if let Some(block_size) = block_size {
292            SequenceCustomMetadata::PagedAttention {
293                logical_token_blocks: Vec::new(),
294                block_size,
295            }
296        } else {
297            SequenceCustomMetadata::None
298        };
299        custom_metadata
300            .append_tokens_to_blocks(tokens.iter().map(|x| *x as usize).collect::<Vec<_>>());
301        Self {
302            tokens,
303            prompt,
304            logprobs: Vec::new(),
305            prompt_len,
306            id,
307            timestamp,
308            state: RwLock::new(SequenceState::Waiting),
309            normal_cache: vec![None; layers],
310            normal_draft_cache: vec![None; layers],
311            cache: vec![None; layers],
312            draft_cache: vec![None; layers],
313            xlora_cache: if is_xlora {
314                Some(vec![None; layers])
315            } else {
316                None
317            },
318            seq_preallocated_cache,
319            responder,
320            sampler: sampler.into(),
321            stop_tokens,
322            stop_strings,
323            max_len,
324            return_logprobs,
325            prompt_tok_per_sec: 0.,
326            prompt_timestamp: None,
327            group,
328            scaling_cache: None,
329            response_index,
330            creation_time,
331            recognizer,
332            prefill_prompt_toks: None,
333            suffix,
334            prefix,
335            cumulative_logprob: 0.,
336            completion_bytes: Vec::new(),
337            stream_idx: 0,
338            last_completion_bytes_len: 0,
339            last_logprob: 0.0,
340            last_is_done: None,
341            is_tmp: false,
342            scheduling_urgency: 0,
343            input_images,
344            custom_metadata,
345            tools,
346            image_gen_response_format,
347            sequence_stepping_type,
348            diffusion_params,
349            cached_pixel_values: None,
350            cached_img_thw: None,
351            cached_vid_thw: None,
352            return_raw_logits,
353            token_offset: 0,
354            eos_tokens,
355            has_changed_prompt: false,
356            total_prompt_time: None,
357        }
358    }
359
360    pub fn add_urgency(mut self) -> Self {
361        self.scheduling_urgency += 1;
362        self
363    }
364
365    pub fn reset_urgency(mut self) -> Self {
366        self.scheduling_urgency = 0;
367        self
368    }
369
370    /// Simple metric: (scheduling urgency) + log2(length)
371    /// Takes into account: urgency (scales linear) and length (scales logarithmic)
372    /// Scaling urgency is the number of scheduling passes where we have not been scheduled.
373    pub fn compute_priority(&self) -> f64 {
374        #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
375        (self.scheduling_urgency as f64) + (self.len() as f64).log2()
376    }
377
378    pub fn prefill(
379        mut self,
380        cache: LayerCaches,
381        xlora_cache: Option<LayerCaches>,
382        toks: Vec<u32>,
383    ) -> Self {
384        self.cache = cache;
385        self.xlora_cache = xlora_cache;
386        self.prefill_prompt_toks = Some(toks);
387        self.set_state(SequenceState::RunningPrefillPrompt);
388        self
389    }
390
391    pub fn prefill_v2(
392        mut self,
393        cache: Vec<Option<KvCache>>,
394        toks: Vec<u32>,
395        offset: usize,
396    ) -> Self {
397        self.normal_cache = cache;
398        self.prefill_prompt_toks = Some(toks);
399        self.set_state(SequenceState::RunningPrefillPrompt);
400        self.token_offset = offset;
401        self
402    }
403
404    /// This is the number of tokens. If the KV cache is Some, then it will use that.
405    pub fn len(&self) -> usize {
406        if let Some(toks) = &self.prefill_prompt_toks {
407            return toks.len();
408        }
409        if self.is_tmp {
410            return self.tokens.len();
411        }
412        // Use xlora cache first because of non granular
413        if self.xlora_cache.as_ref().is_some_and(|c| c[0].is_some()) {
414            self.xlora_cache.as_ref().unwrap()[0]
415                .as_ref()
416                .unwrap()
417                .0
418                .dims()[2]
419                + 1
420        } else if let Some((_, x)) = &self.cache[0] {
421            x.dims()[2] + 1
422        } else {
423            self.tokens.len()
424        }
425    }
426
427    pub fn id(&self) -> &usize {
428        &self.id
429    }
430
431    pub fn is_running(&self) -> bool {
432        matches!(
433            *self.state.read().unwrap(),
434            SequenceState::RunningCompletion | SequenceState::RunningPrompt // | SequenceState::RunningPrefillPrompt
435        )
436    }
437
438    pub fn is_completion(&self) -> bool {
439        matches!(
440            *self.state.read().unwrap(),
441            SequenceState::RunningCompletion
442        )
443    }
444
445    pub fn is_prompt(&self) -> bool {
446        matches!(
447            *self.state.read().unwrap(),
448            SequenceState::RunningPrompt | SequenceState::RunningPrefillPrompt
449        )
450    }
451
452    pub fn is_waiting(&self) -> bool {
453        matches!(*self.state.read().unwrap(), SequenceState::Waiting)
454    }
455
456    pub fn is_finished_paged_attn(&self) -> bool {
457        matches!(
458            *self.state.read().unwrap(),
459            SequenceState::FinishedAborted
460                | SequenceState::FinishedIgnored
461                | SequenceState::Done(_)
462        )
463    }
464
465    pub fn get_toks(&self) -> &[u32] {
466        if let Some(toks) = &self.prefill_prompt_toks {
467            return toks;
468        }
469        &self.tokens
470    }
471
472    pub fn get_initial_prompt(&self) -> &str {
473        &self.prompt
474    }
475
476    pub fn set_initial_prompt(&mut self, new: String) {
477        self.prompt = new;
478    }
479
480    pub fn token_offset(&self) -> usize {
481        self.token_offset
482    }
483
484    /// This will also set prompt_len
485    pub(crate) fn set_toks_and_reallocate(
486        &mut self,
487        toks: Vec<u32>,
488        paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
489    ) {
490        self.tokens.clone_from(&toks);
491        self.prompt_len = self.tokens.len();
492        // Handle possible block engine
493        match &mut self.custom_metadata {
494            SequenceCustomMetadata::PagedAttention {
495                logical_token_blocks,
496                block_size: _,
497            } => {
498                logical_token_blocks.clear();
499            }
500            SequenceCustomMetadata::None => (),
501        }
502        self.custom_metadata
503            .append_tokens_to_blocks(toks.iter().map(|x| *x as usize).collect::<Vec<_>>());
504
505        if let Some(metadata) = paged_attn_metadata {
506            // Free and then reallocate as appropriate
507            metadata.block_engine.free_sequence(*self.id());
508            metadata.block_engine.allocate(self);
509        }
510    }
511
512    pub fn completion_bytes(&self) -> &[u8] {
513        &self.completion_bytes
514    }
515
516    pub fn preallocated_cache(&self) -> Option<&(Tensor, Tensor)> {
517        self.seq_preallocated_cache.as_ref()
518    }
519
520    pub fn normal_cache(&mut self) -> &mut Vec<Option<KvCache>> {
521        &mut self.normal_cache
522    }
523
524    pub fn normal_draft_cache(&mut self) -> &mut Vec<Option<KvCache>> {
525        &mut self.normal_draft_cache
526    }
527
528    pub fn cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
529        &mut self.cache
530    }
531
532    pub fn draft_cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
533        &mut self.draft_cache
534    }
535
536    pub fn xlora_cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
537        self.xlora_cache.as_mut().expect("No X-LoRA cache.")
538    }
539
540    pub fn scaling_cache(&mut self) -> &mut Option<Tensor> {
541        &mut self.scaling_cache
542    }
543
544    pub fn is_xlora(&self) -> bool {
545        self.xlora_cache.is_some()
546    }
547
548    pub fn sampler(&mut self) -> Arc<Sampler> {
549        self.sampler.clone()
550    }
551
552    /// Add a some prefill tokens. Only meant for internal speculative decoding usage.
553    pub fn set_prefill_toks(&mut self, toks: Vec<u32>) {
554        self.prefill_prompt_toks = Some(toks)
555    }
556
557    /// Remove the prefill tokens.
558    pub fn reset_prefill_toks(&mut self) {
559        self.prefill_prompt_toks = None
560    }
561
562    /// Internal api to add one raw token.
563    pub(crate) fn add_tmp_tok(&mut self, tok: u32) {
564        self.is_tmp = true;
565        self.tokens.push(tok);
566        // Handle possible block engine
567        self.custom_metadata.append_token_to_blocks(tok as usize);
568    }
569
570    /// Internal api to remove n raw tokens.
571    pub(crate) fn remove_tmp_tok(&mut self, n: usize) {
572        self.is_tmp = false;
573        self.tokens.truncate(self.tokens.len() - n);
574        // Handle possible block engine
575        self.custom_metadata.remove_tokens_from_blocks(n);
576    }
577
578    pub fn add_token(
579        &mut self,
580        tok: Logprobs,
581        completion_bytes: Vec<u8>,
582        is_done: &Option<StopReason>,
583    ) {
584        let stopped_by_token = matches!(
585            is_done,
586            Some(StopReason::Eos) | Some(StopReason::StopTok(_))
587        );
588        if !stopped_by_token {
589            // Completion bytes is used to check for stop strings, and as the response buffer.
590            // We don't need to add stop tokens to the completion bytes to check for stop strings.
591            // And by not adding it here, we can avoid having to delete these tokens from the output.
592            self.completion_bytes.extend_from_slice(&completion_bytes);
593            self.last_completion_bytes_len = completion_bytes.len();
594        }
595        self.last_logprob = tok.logprob;
596        self.last_is_done = *is_done;
597
598        self.custom_metadata
599            .append_token_to_blocks(tok.token as usize);
600
601        self.cumulative_logprob += tok.logprob;
602        self.tokens.push(tok.token);
603        self.logprobs.push(tok);
604        self.reset_prefill_toks();
605    }
606
607    pub fn responder(&self) -> Sender<Response> {
608        self.responder.clone()
609    }
610
611    pub fn creation_time(&self) -> u64 {
612        self.creation_time
613    }
614
615    pub fn set_state(&self, state: SequenceState) {
616        if matches!(state, SequenceState::Error) {
617            get_mut_group!(self).n_choices -= 1;
618        }
619        *self.state.write().unwrap() = state;
620    }
621
622    pub fn getstate(&self) -> SequenceState {
623        *self.state.read().unwrap()
624    }
625
626    pub fn is_done(
627        &self,
628        tok: u32,
629        eos_tok: Option<&[u32]>,
630        max_model_len: usize,
631    ) -> Option<StopReason> {
632        let is_eos = match eos_tok {
633            Some(eos_tok) => eos_tok.iter().any(|t| *t == tok),
634            None => false,
635        };
636        if is_eos {
637            Some(StopReason::Eos)
638        } else if matches!(
639            &*self.state.read().unwrap(),
640            SequenceState::Done(StopReason::Canceled)
641        ) {
642            Some(StopReason::Canceled)
643        } else if self.stop_tokens.contains(&tok) {
644            Some(StopReason::StopTok(tok))
645        } else if self.max_len.is_some()
646            && self.tokens.len().saturating_sub(self.prompt_len) == self.max_len.unwrap()
647        {
648            // add_token was already called
649            Some(StopReason::Length(self.max_len.unwrap()))
650        } else if self.tokens.len().saturating_sub(self.prompt_len) == max_model_len {
651            Some(StopReason::ModelLength(max_model_len))
652        } else {
653            if !self.stop_strings.is_empty() {
654                for (idx, s) in self.stop_strings.iter().enumerate() {
655                    if let Some(pos) = galil_seiferas::gs_find(&self.completion_bytes, s.as_bytes())
656                    {
657                        return Some(StopReason::StopString {
658                            stop_string_idx: idx,
659                            completion_bytes_pos: pos,
660                        });
661                    }
662                }
663            }
664            None
665        }
666    }
667
668    pub fn logprobs(&self) -> &[Logprobs] {
669        &self.logprobs
670    }
671
672    pub fn return_logprobs(&self) -> bool {
673        self.return_logprobs
674    }
675
676    pub fn prompt_tokens(&self) -> usize {
677        self.prompt_len
678    }
679
680    pub fn stop_strings(&self) -> &[String] {
681        &self.stop_strings
682    }
683
684    /// Returns the delta between the last two decoded sequences
685    pub fn get_delta(
686        &mut self,
687    ) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
688        let new_decoded = self.peek_delta();
689        if matches!(new_decoded, Ok(Some(_))) {
690            self.stream_idx = self.completion_bytes.len();
691        }
692        new_decoded
693    }
694
695    /// Peeks at the delta between the last two decoded sequences, but does not advance the stream index.
696    pub fn peek_delta(&self) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
697        let is_first = self.stream_idx == 0;
698        let new_decoded = String::from_utf8_lossy(&self.completion_bytes[self.stream_idx..]);
699        // Check if the sequence ends with valid utf8, if not skip it as it probably is a multi token sequence
700        if new_decoded.ends_with('�') {
701            return Ok(None);
702        }
703
704        // The first token usually starts with a space. We don't want to add that to the delta.
705        // Since we're using the completion_bytes, we need to take care of that ourselves.
706        // Had we used HF's Tokenizer, it would have taken care of that for us.
707        if is_first {
708            return Ok(Some(new_decoded.trim_start().to_string()));
709        }
710        Ok(Some(new_decoded.to_string()))
711    }
712
713    pub fn timestamp(&self) -> u128 {
714        self.timestamp
715    }
716
717    pub fn prompt_timestamp(&self) -> Option<u128> {
718        self.prompt_timestamp
719    }
720
721    fn update_time_info(&self) {
722        let now = SystemTime::now()
723            .duration_since(UNIX_EPOCH)
724            .expect("Time travel has occurred!")
725            .as_millis();
726
727        if let Some(ts) = self.prompt_timestamp {
728            get_mut_group!(self).total_completion_time = now - ts;
729            get_mut_group!(self).total_prompt_time = self.total_prompt_time.unwrap();
730        }
731
732        get_mut_group!(self).total_time = now - self.timestamp;
733
734        get_mut_group!(self).total_prompt_toks = self.prompt_len;
735        get_mut_group!(self).total_toks = self.len();
736    }
737
738    pub fn add_image_choice_to_group(&self, choice: ImageChoice) {
739        get_mut_group!(self).image_choices.push(choice);
740        self.update_time_info();
741    }
742
743    pub fn add_choice_to_group(&self, choice: Choice) {
744        get_mut_group!(self).choices.push(choice);
745        self.update_time_info();
746    }
747
748    pub fn add_raw_choice_to_group(&self, logit_chunks: Vec<Tensor>) {
749        get_mut_group!(self)
750            .raw_choices
751            .push((logit_chunks, self.tokens.clone()));
752        self.update_time_info();
753    }
754
755    pub fn add_completion_choice_to_group(&self, mut choice: CompletionChoice) {
756        choice.text = format!(
757            "{}{}{}",
758            self.prefix.as_deref().unwrap_or(""),
759            choice.text,
760            self.suffix.as_deref().unwrap_or("")
761        );
762        get_mut_group!(self)
763            .completion_choices
764            .push((self.cumulative_logprob, choice));
765        self.update_time_info();
766    }
767
768    pub fn get_response_index(&self) -> usize {
769        self.response_index
770    }
771
772    pub fn get_mut_group(&self) -> MutexGuard<'_, SequenceGroup> {
773        get_mut_group!(self)
774    }
775
776    pub fn add_streaming_chunk_choice_to_group(&self, chunk: ChunkChoice) {
777        get_mut_group!(self).chat_streaming_chunks.push(chunk);
778        self.update_time_info();
779    }
780
781    pub fn add_streaming_completion_chunk_choice_to_group(&self, chunk: CompletionChunkChoice) {
782        get_mut_group!(self).completion_streaming_chunks.push(chunk);
783        self.update_time_info();
784    }
785
786    pub fn take_images(&mut self) -> Option<Vec<image::DynamicImage>> {
787        // So that we don't keep having an image after the actual prompt
788        if self.has_changed_prompt {
789            // Actual prompt
790            self.input_images.take()
791        } else {
792            // Dummy inputs processing
793            self.input_images.clone()
794        }
795    }
796
797    pub fn clone_images(&mut self) -> Option<Vec<image::DynamicImage>> {
798        self.input_images.clone()
799    }
800
801    pub fn images(&self) -> Option<&[image::DynamicImage]> {
802        self.input_images.as_deref()
803    }
804
805    pub fn has_images(&self) -> bool {
806        self.input_images
807            .as_ref()
808            .is_some_and(|images| !images.is_empty())
809    }
810
811    pub fn image_gen_response_format(&self) -> Option<ImageGenerationResponseFormat> {
812        self.image_gen_response_format
813    }
814
815    pub fn sequence_stepping_type(&self) -> &SeqStepType {
816        &self.sequence_stepping_type
817    }
818
819    pub fn get_diffusion_diffusion_params(&self) -> Option<DiffusionGenerationParams> {
820        self.diffusion_params.clone()
821    }
822
823    pub fn eos_tokens(&self) -> &[u32] {
824        &self.eos_tokens
825    }
826}
827
/// Aggregates the outputs of all sequences belonging to one request.
///
/// Responses are sent once the number of collected choices reaches `n_choices`.
pub struct SequenceGroup {
    n_choices: usize, // The target number of choices to return. Can be decreased if an error is thrown.
    best_of: Option<usize>, // Top n seqs based on cumulative logprobs.
    pub total_prompt_toks: usize,
    pub total_toks: usize,
    // Timing totals; converted with `/ 1000.` to seconds in `get_usage`, so presumably milliseconds.
    pub total_prompt_time: u128,
    pub total_time: u128,
    pub total_completion_time: u128,
    choices: Vec<Choice>,
    image_choices: Vec<ImageChoice>,
    // (logit chunks, tokens) per sequence.
    raw_choices: Vec<(Vec<Tensor>, Vec<u32>)>,
    // (cumulative logprob, choice) per sequence; the logprob is used for `best_of` ranking.
    completion_choices: Vec<(f32, CompletionChoice)>,
    pub chat_streaming_chunks: Vec<ChunkChoice>,
    pub completion_streaming_chunks: Vec<CompletionChunkChoice>,
    pub is_streaming: bool,
    pub is_chat: bool,
}
845
846impl SequenceGroup {
847    pub fn new(
848        n_choices: usize,
849        is_streaming: bool,
850        is_chat: bool,
851        best_of: Option<usize>,
852    ) -> Self {
853        Self {
854            choices: Vec::new(),
855            image_choices: Vec::new(),
856            raw_choices: Vec::new(),
857            completion_choices: Vec::new(),
858            n_choices,
859            total_prompt_toks: 0,
860            total_toks: 0,
861            total_prompt_time: 0,
862            total_time: 0,
863            total_completion_time: 0,
864            chat_streaming_chunks: Vec::new(),
865            completion_streaming_chunks: Vec::new(),
866            is_streaming,
867            is_chat,
868            best_of,
869        }
870    }
871
872    pub fn get_choices(&self) -> &[Choice] {
873        &self.choices
874    }
875
876    /// This may apply the best_of.
877    pub fn get_completion_choices(&self) -> Vec<CompletionChoice> {
878        if let Some(best_of) = self.best_of {
879            let mut choices = self.completion_choices.clone();
880            // Sort by descending logprobs
881            choices.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("No ordering."));
882            choices
883                .into_iter()
884                .take(best_of)
885                .map(|(_, x)| x)
886                .collect::<Vec<_>>()
887        } else {
888            self.completion_choices
889                .clone()
890                .into_iter()
891                .map(|(_, x)| x)
892                .collect::<Vec<_>>()
893        }
894    }
895
896    pub fn get_image_choices(&self) -> &[ImageChoice] {
897        &self.image_choices
898    }
899
900    pub fn get_usage(&self) -> Usage {
901        #[allow(clippy::cast_precision_loss)]
902        Usage {
903            completion_tokens: self.total_toks - self.total_prompt_toks,
904            prompt_tokens: self.total_prompt_toks,
905            total_tokens: self.total_toks,
906            avg_tok_per_sec: (self.total_toks as f32 / self.total_time as f32) * 1000.,
907            avg_prompt_tok_per_sec: (self.total_prompt_toks as f32 / self.total_prompt_time as f32)
908                * 1000.,
909            avg_compl_tok_per_sec: ((self.total_toks - self.total_prompt_toks) as f32
910                / self.total_completion_time as f32)
911                * 1000.,
912            total_time_sec: self.total_time as f32 / 1000.,
913            total_completion_time_sec: self.total_completion_time as f32 / 1000.,
914            total_prompt_time_sec: self.total_prompt_time as f32 / 1000.,
915        }
916    }
917
918    pub async fn maybe_send_chat_done_response(
919        &self,
920        response: ChatCompletionResponse,
921        sender: Sender<Response>,
922    ) -> Result<(), SendError<Response>> {
923        if self.choices.len() == self.n_choices {
924            sender.send(Response::Done(response)).await?;
925        }
926
927        Ok(())
928    }
929
930    pub async fn maybe_send_raw_done_response(
931        &self,
932        sender: Sender<Response>,
933    ) -> Result<(), SendError<Response>> {
934        if self.raw_choices.len() == self.n_choices {
935            assert_eq!(self.raw_choices.len(), 1);
936            let (logits_chunks, tokens) = self.raw_choices[0].clone();
937            sender
938                .send(Response::Raw {
939                    logits_chunks,
940                    tokens,
941                })
942                .await?;
943        }
944
945        Ok(())
946    }
947
948    pub async fn maybe_send_image_gen_response(
949        &self,
950        response: ImageGenerationResponse,
951        sender: Sender<Response>,
952    ) -> Result<(), SendError<Response>> {
953        if self.image_choices.len() == self.n_choices {
954            sender.send(Response::ImageGeneration(response)).await?;
955        }
956
957        Ok(())
958    }
959
960    pub async fn maybe_send_streaming_response(
961        &mut self,
962        seq: &Sequence,
963        model: String,
964        usage_opt: Option<Usage>,
965    ) -> Result<(), Box<SendError<Response>>> {
966        if self.chat_streaming_chunks.len() == self.n_choices && self.is_streaming {
967            let mut swap_streaming_chunks = vec![];
968
969            std::mem::swap(&mut swap_streaming_chunks, &mut self.chat_streaming_chunks);
970
971            seq.responder()
972                .send(Response::Chunk(ChatCompletionChunkResponse {
973                    id: seq.id.to_string(),
974                    choices: swap_streaming_chunks,
975                    created: seq.timestamp,
976                    model: model.clone(),
977                    system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
978                    object: "chat.completion.chunk".to_string(),
979                    usage: usage_opt,
980                }))
981                .await?;
982        } else if self.completion_streaming_chunks.len() == self.n_choices && self.is_streaming {
983            let mut swap_streaming_chunks = vec![];
984
985            std::mem::swap(
986                &mut swap_streaming_chunks,
987                &mut self.completion_streaming_chunks,
988            );
989
990            seq.responder()
991                .send(Response::CompletionChunk(CompletionChunkResponse {
992                    id: seq.id.to_string(),
993                    choices: swap_streaming_chunks,
994                    created: seq.timestamp,
995                    model: model.clone(),
996                    system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
997                    object: "text_completion".to_string(),
998                }))
999                .await?;
1000        }
1001        Ok(())
1002    }
1003
1004    pub async fn maybe_send_completion_done_response(
1005        &self,
1006        response: CompletionResponse,
1007        sender: Sender<Response>,
1008    ) -> Result<(), Box<SendError<Response>>> {
1009        if self.completion_choices.len() == self.n_choices {
1010            sender.send(Response::CompletionDone(response)).await?;
1011        }
1012        Ok(())
1013    }
1014}