mistralrs_core/
sequence.rs

1use crate::{
2    get_mut_group,
3    pipeline::{text_models_inputs_processor::PagedAttentionMeta, LayerCaches},
4    response::{ChatCompletionChunkResponse, Choice, ChunkChoice, Response, SYSTEM_FINGERPRINT},
5    sampler::{Logprobs, Sampler},
6    ChatCompletionResponse, Usage,
7};
8use crate::{
9    paged_attention::{BlockEngineSequence, LogicalTokenBlock},
10    pipeline::{DiffusionGenerationParams, KvCache},
11    response::CompletionChoice,
12    tools::ToolCallingMatcher,
13    CompletionChunkChoice, CompletionChunkResponse, CompletionResponse, ImageChoice,
14    ImageGenerationResponse, ImageGenerationResponseFormat,
15};
16use candle_core::Tensor;
17use std::{
18    fmt::Display,
19    sync::{Arc, RwLock},
20    time::{SystemTime, UNIX_EPOCH},
21};
22use tokio::sync::{
23    mpsc::{error::SendError, Sender},
24    Mutex, MutexGuard,
25};
26
/// Why a sequence stopped producing tokens.
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum StopReason {
    /// An EOS token was generated.
    Eos,
    /// One of the request's stop tokens was generated.
    StopTok(u32),
    /// The request's `max_len` limit was reached.
    Length(usize),
    /// The model's maximum context length was reached.
    ModelLength(usize),
    /// One of the request's stop strings was found in the completion bytes.
    StopString {
        stop_string_idx: usize,
        completion_bytes_pos: usize,
    },
    /// The request was canceled.
    Canceled,
    /// An image generation request finished.
    GeneratedImage,
}

impl Display for StopReason {
    /// Renders the OpenAI-style `finish_reason` string for this stop reason.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            // EOS, stop tokens and stop strings all surface as "stop".
            StopReason::Eos | StopReason::StopTok(_) | StopReason::StopString { .. } => "stop",
            StopReason::Length(_) | StopReason::ModelLength(_) => "length",
            StopReason::Canceled => "canceled",
            StopReason::GeneratedImage => "generated-image",
        };
        write!(f, "{label}")
    }
}
52
/// Scheduling/lifecycle state of a sequence.
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum SequenceState {
    /// Finished, carrying the reason generation stopped.
    Done(StopReason),
    /// Currently running the prompt (prefill) pass.
    RunningPrompt,
    /// Currently generating completion tokens.
    RunningCompletion,
    /// Queued, waiting to be scheduled.
    Waiting,
    /// An error occurred; the group's expected choice count is decremented
    /// when this state is set (see `Sequence::set_state`).
    Error,
    /// Running a prefill pass over injected/cached prompt tokens.
    RunningPrefillPrompt,
    // For PagedAttention:
    /// Aborted by the PagedAttention scheduler.
    FinishedAborted,
    /// Ignored by the PagedAttention scheduler.
    FinishedIgnored,
    /// Swapped out by the PagedAttention scheduler.
    Swapped,
}
66
/// Optional constrained-generation state attached to a sequence.
pub enum SequenceRecognizer {
    /// Constrained generation driven by an `llguidance` constraint.
    Llguidance(Box<llguidance::Constraint>),
    /// No constraint; sampling is unrestricted.
    None,
}
71
/// Backend-specific per-sequence bookkeeping.
enum SequenceCustomMetadata {
    /// PagedAttention backend: the sequence's logical KV-cache blocks.
    PagedAttention {
        logical_token_blocks: Vec<LogicalTokenBlock>,
        // Capacity (in tokens) of each logical block.
        block_size: usize,
    },
    /// Non-paged backends: no extra metadata; block ops are no-ops.
    None,
}
79
/// Number of new logical blocks needed to append one more token:
/// `0` when the tail block exists and has free space (or when there are no
/// blocks at all), `1` when the tail block is full or empty.
macro_rules! blocks_to_add_new_tok {
    ($logical_token_blocks:expr) => {{
        match $logical_token_blocks.last() {
            // A full tail cannot take another token; an empty tail is the
            // preemptively pushed block that still needs a backing block.
            Some(last) if last.is_full() || last.is_empty() => 1,
            // Room remains in the tail (or there are no blocks yet).
            _ => 0,
        }
    }};
}
91
92impl SequenceCustomMetadata {
93    fn append_token_to_blocks(&mut self, tok: usize) {
94        match self {
95            Self::PagedAttention {
96                logical_token_blocks,
97                block_size,
98            } => {
99                let last = logical_token_blocks.last_mut();
100                match last {
101                    Some(last) => {
102                        last.append_token_id(tok);
103                    }
104                    None => {
105                        logical_token_blocks.push(LogicalTokenBlock::new(*block_size));
106                        logical_token_blocks
107                            .last_mut()
108                            .unwrap()
109                            .append_token_id(tok);
110                    }
111                }
112                if logical_token_blocks.last().as_ref().unwrap().is_full() {
113                    logical_token_blocks.push(LogicalTokenBlock::new(*block_size));
114                }
115            }
116            Self::None => (),
117        }
118    }
119
120    fn pop_token_from_blocks(&mut self) {
121        match self {
122            Self::PagedAttention {
123                logical_token_blocks,
124                block_size: _,
125            } => {
126                let last = logical_token_blocks.last_mut().unwrap();
127                last.pop_token();
128            }
129            Self::None => (),
130        }
131    }
132
133    fn append_tokens_to_blocks(&mut self, toks: Vec<usize>) {
134        for tok in toks {
135            self.append_token_to_blocks(tok);
136        }
137    }
138
139    fn remove_tokens_from_blocks(&mut self, n: usize) {
140        for _ in 0..n {
141            self.pop_token_from_blocks();
142        }
143    }
144}
145
/// How a sequence steps through the pipeline.
#[derive(Clone, Copy)]
pub enum SeqStepType {
    /// Standard generation: a prompt pass followed by decode steps.
    PromptAndDecode,
    /// A single forward pass (e.g. image generation — see the diffusion
    /// fields on `Sequence`).
    OneShot,
}
151
/// A single generation request tracked by the engine: its token stream,
/// per-layer KV caches, sampling configuration, response channel, and
/// scheduling state.
pub struct Sequence {
    // Metadata, const
    id: usize,
    // Number of prompt tokens; updated by `set_toks_and_reallocate`.
    prompt_len: usize,
    // Optional cap on the number of generated tokens.
    max_len: Option<usize>,
    // Creation time in ms since the Unix epoch (used by `update_time_info`).
    timestamp: u128,
    sampler: Arc<Sampler>,
    // Token ids that end generation with `StopReason::StopTok`.
    stop_tokens: Vec<u32>,
    // Strings that end generation with `StopReason::StopString`.
    stop_strings: Vec<String>,
    return_logprobs: bool,
    // Channel on which responses/chunks are sent back to the requester.
    responder: Sender<Response>,
    response_index: usize,
    creation_time: u64,
    prompt: String,
    sequence_stepping_type: SeqStepType,
    pub(crate) return_raw_logits: bool,
    // Offset applied by `prefill_v2` when part of the prompt is cached.
    token_offset: usize,
    eos_tokens: Vec<u32>,

    // Image generation
    image_gen_response_format: Option<ImageGenerationResponseFormat>,
    diffusion_params: Option<DiffusionGenerationParams>,

    // Completion requests
    suffix: Option<String>,
    prefix: Option<String>,

    // Speculative
    // True while raw tokens added via `add_tmp_tok` are outstanding.
    is_tmp: bool,

    // Prefix caching
    // When set, overrides `tokens` in `get_toks`/`len` until reset.
    prefill_prompt_toks: Option<Vec<u32>>,

    // Adapter dynamic config
    adapters: Option<Vec<String>>,

    // Cache
    normal_cache: Vec<Option<KvCache>>,
    normal_draft_cache: Vec<Option<KvCache>>,
    scaling_cache: Option<Tensor>,
    cache: LayerCaches,
    draft_cache: LayerCaches,
    // Present iff this is an X-LoRA sequence (see `is_xlora`).
    xlora_cache: Option<LayerCaches>,

    // Preallocated KV cache (k,v)
    seq_preallocated_cache: Option<(Tensor, Tensor)>,

    // Mutables
    tokens: Vec<u32>,
    logprobs: Vec<Logprobs>,
    // Running sum of token logprobs; used for `best_of` ranking.
    cumulative_logprob: f32,
    last_logprob: f32,
    last_completion_bytes_len: usize,
    last_is_done: Option<StopReason>,
    // Decoded output bytes; also scanned for stop strings.
    completion_bytes: Vec<u8>,
    // Byte index into `completion_bytes` already streamed to the client.
    stream_idx: usize,
    pub recognizer: SequenceRecognizer,
    scheduling_urgency: usize, // The number of passes since scheduling
    input_images: Option<Vec<image::DynamicImage>>,
    pub cached_pixel_values: Option<Tensor>,
    pub cached_img_thw: Option<Tensor>,
    pub cached_vid_thw: Option<Tensor>,

    // GPU things
    pub prompt_tok_per_sec: f32,
    pub prompt_timestamp: Option<u128>,
    // Shared with the other choices of the same request.
    group: Arc<Mutex<SequenceGroup>>,
    state: RwLock<SequenceState>,

    // Custom backend metadata
    custom_metadata: SequenceCustomMetadata,

    // Tool calls
    pub tools: Option<Arc<ToolCallingMatcher>>,
}
227
228impl BlockEngineSequence for Sequence {
229    fn blocks_to_add_new_tok(&self) -> usize {
230        match &self.custom_metadata {
231            SequenceCustomMetadata::PagedAttention {
232                logical_token_blocks,
233                block_size: _,
234            } => {
235                blocks_to_add_new_tok!(logical_token_blocks)
236            }
237            SequenceCustomMetadata::None => unreachable!(),
238        }
239    }
240
241    fn get_id(&self) -> usize {
242        self.id
243    }
244
245    fn get_logical_token_blocks(&self) -> usize {
246        match &self.custom_metadata {
247            SequenceCustomMetadata::PagedAttention {
248                logical_token_blocks,
249                block_size: _,
250            } => logical_token_blocks.len(),
251            SequenceCustomMetadata::None => unreachable!(),
252        }
253    }
254}
255
impl Sequence {
    /// Creates a new sequence in the `Waiting` state.
    ///
    /// `tokens` is the tokenized prompt (also used to size `prompt_len`);
    /// `layers` sizes the per-layer caches; `block_size` being `Some` enables
    /// PagedAttention bookkeeping and seeds the logical blocks with the
    /// prompt tokens.
    #[allow(clippy::too_many_arguments)]
    pub fn new_waiting(
        tokens: Vec<u32>,
        prompt: String,
        id: usize,
        timestamp: u128,
        layers: usize,
        responder: Sender<Response>,
        sampler: Sampler,
        stop_tokens: Vec<u32>,
        stop_strings: Vec<String>,
        max_len: Option<usize>,
        return_logprobs: bool,
        is_xlora: bool,
        group: Arc<Mutex<SequenceGroup>>,
        response_index: usize,
        creation_time: u64,
        recognizer: SequenceRecognizer,
        suffix: Option<String>,
        prefix: Option<String>,
        adapters: Option<Vec<String>>,
        input_images: Option<Vec<image::DynamicImage>>,
        // Paged attention
        block_size: Option<usize>,
        //
        tools: Option<Arc<ToolCallingMatcher>>,
        image_gen_response_format: Option<ImageGenerationResponseFormat>,
        sequence_stepping_type: SeqStepType,
        diffusion_params: Option<DiffusionGenerationParams>,
        // Preallocated KV cache (k,v)
        seq_preallocated_cache: Option<(Tensor, Tensor)>,
        //
        return_raw_logits: bool,
        eos_tokens: Vec<u32>,
    ) -> Self {
        let prompt_len = tokens.len();
        // PagedAttention sequences track logical KV blocks; others don't.
        let mut custom_metadata = if let Some(block_size) = block_size {
            SequenceCustomMetadata::PagedAttention {
                logical_token_blocks: Vec::new(),
                block_size,
            }
        } else {
            SequenceCustomMetadata::None
        };
        custom_metadata
            .append_tokens_to_blocks(tokens.iter().map(|x| *x as usize).collect::<Vec<_>>());
        Self {
            tokens,
            prompt,
            logprobs: Vec::new(),
            prompt_len,
            id,
            timestamp,
            state: RwLock::new(SequenceState::Waiting),
            normal_cache: vec![None; layers],
            normal_draft_cache: vec![None; layers],
            cache: vec![None; layers],
            draft_cache: vec![None; layers],
            xlora_cache: if is_xlora {
                Some(vec![None; layers])
            } else {
                None
            },
            seq_preallocated_cache,
            responder,
            sampler: sampler.into(),
            stop_tokens,
            stop_strings,
            max_len,
            return_logprobs,
            prompt_tok_per_sec: 0.,
            prompt_timestamp: None,
            group,
            scaling_cache: None,
            response_index,
            creation_time,
            recognizer,
            prefill_prompt_toks: None,
            suffix,
            prefix,
            cumulative_logprob: 0.,
            completion_bytes: Vec::new(),
            stream_idx: 0,
            last_completion_bytes_len: 0,
            last_logprob: 0.0,
            last_is_done: None,
            is_tmp: false,
            scheduling_urgency: 0,
            adapters,
            input_images,
            custom_metadata,
            tools,
            image_gen_response_format,
            sequence_stepping_type,
            diffusion_params,
            cached_pixel_values: None,
            cached_img_thw: None,
            cached_vid_thw: None,
            return_raw_logits,
            token_offset: 0,
            eos_tokens,
        }
    }

    /// Increments the scheduling urgency (one more pass without being scheduled).
    pub fn add_urgency(mut self) -> Self {
        self.scheduling_urgency += 1;
        self
    }

    /// Resets the scheduling urgency to zero (the sequence was scheduled).
    pub fn reset_urgency(mut self) -> Self {
        self.scheduling_urgency = 0;
        self
    }

    /// Simple metric: (scheduling urgency) + log2(length)
    /// Takes into account: urgency (scales linear) and length (scales logarithmic)
    /// Scaling urgency is the number of scheduling passes where we have not been scheduled.
    pub fn compute_priority(&self) -> f64 {
        #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
        (self.scheduling_urgency as f64) + (self.len() as f64).log2()
    }

    /// Installs prefix-cached caches and the remaining prompt tokens, and
    /// moves the sequence into the `RunningPrefillPrompt` state.
    pub fn prefill(
        mut self,
        cache: LayerCaches,
        xlora_cache: Option<LayerCaches>,
        toks: Vec<u32>,
    ) -> Self {
        self.cache = cache;
        self.xlora_cache = xlora_cache;
        self.prefill_prompt_toks = Some(toks);
        self.set_state(SequenceState::RunningPrefillPrompt);
        self
    }

    /// Like [`Self::prefill`], but for the `normal_cache` path; `offset` is
    /// the number of prompt tokens already covered by the cache.
    pub fn prefill_v2(
        mut self,
        cache: Vec<Option<KvCache>>,
        toks: Vec<u32>,
        offset: usize,
    ) -> Self {
        self.normal_cache = cache;
        self.prefill_prompt_toks = Some(toks);
        self.set_state(SequenceState::RunningPrefillPrompt);
        self.token_offset = offset;
        self
    }

    /// This is the number of tokens. If the KV cache is Some, then it will use that.
    ///
    /// Priority: prefill override tokens, then the raw token list while
    /// speculative tmp tokens are outstanding, then the X-LoRA/regular cache
    /// sequence dimension (+1 for the token not yet in the cache), and
    /// finally the raw token list.
    pub fn len(&self) -> usize {
        if let Some(toks) = &self.prefill_prompt_toks {
            return toks.len();
        }
        if self.is_tmp {
            return self.tokens.len();
        }
        // Use xlora cache first because of non granular
        if self.xlora_cache.as_ref().is_some_and(|c| c[0].is_some()) {
            self.xlora_cache.as_ref().unwrap()[0]
                .as_ref()
                .unwrap()
                .0
                .dims()[2]
                + 1
        } else if let Some((_, x)) = &self.cache[0] {
            x.dims()[2] + 1
        } else {
            self.tokens.len()
        }
    }

    /// Unique id of this sequence.
    pub fn id(&self) -> &usize {
        &self.id
    }

    /// True while the sequence is actively running a prompt or completion pass.
    pub fn is_running(&self) -> bool {
        matches!(
            *self.state.read().unwrap(),
            SequenceState::RunningCompletion | SequenceState::RunningPrompt // | SequenceState::RunningPrefillPrompt
        )
    }

    /// True iff the sequence is in the decode (completion) phase.
    pub fn is_completion(&self) -> bool {
        matches!(
            *self.state.read().unwrap(),
            SequenceState::RunningCompletion
        )
    }

    /// True iff the sequence is in a prompt (prefill) phase.
    pub fn is_prompt(&self) -> bool {
        matches!(
            *self.state.read().unwrap(),
            SequenceState::RunningPrompt | SequenceState::RunningPrefillPrompt
        )
    }

    /// True iff the sequence is queued and not yet scheduled.
    pub fn is_waiting(&self) -> bool {
        matches!(*self.state.read().unwrap(), SequenceState::Waiting)
    }

    /// True iff the sequence is finished from the PagedAttention scheduler's
    /// point of view (done, aborted, or ignored).
    pub fn is_finished_paged_attn(&self) -> bool {
        matches!(
            *self.state.read().unwrap(),
            SequenceState::FinishedAborted
                | SequenceState::FinishedIgnored
                | SequenceState::Done(_)
        )
    }

    /// The tokens to feed to the model: the prefill override if set,
    /// otherwise the full token list.
    pub fn get_toks(&self) -> &[u32] {
        if let Some(toks) = &self.prefill_prompt_toks {
            return toks;
        }
        &self.tokens
    }

    /// The original (untokenized) prompt text.
    pub fn get_initial_prompt(&self) -> &str {
        &self.prompt
    }

    /// Replaces the stored prompt text.
    pub fn set_initial_prompt(&mut self, new: String) {
        self.prompt = new;
    }

    /// Token offset set by `prefill_v2` (tokens already covered by the cache).
    pub fn token_offset(&self) -> usize {
        self.token_offset
    }

    /// This will also set prompt_len
    ///
    /// Replaces the token list, rebuilds the logical blocks from scratch,
    /// and (if paged-attention metadata is given) frees and reallocates the
    /// sequence's physical blocks.
    pub(crate) fn set_toks_and_reallocate(
        &mut self,
        toks: Vec<u32>,
        paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
    ) {
        self.tokens.clone_from(&toks);
        self.prompt_len = self.tokens.len();
        // Handle possible block engine
        match &mut self.custom_metadata {
            SequenceCustomMetadata::PagedAttention {
                logical_token_blocks,
                block_size: _,
            } => {
                logical_token_blocks.clear();
            }
            SequenceCustomMetadata::None => (),
        }
        self.custom_metadata
            .append_tokens_to_blocks(toks.iter().map(|x| *x as usize).collect::<Vec<_>>());

        if let Some(metadata) = paged_attn_metadata {
            // Free and then reallocate as appropriate
            metadata.block_engine.free_sequence(*self.id());
            metadata.block_engine.allocate(self);
        }
    }

    /// Raw decoded completion bytes accumulated so far.
    pub fn completion_bytes(&self) -> &[u8] {
        &self.completion_bytes
    }

    /// The preallocated (k, v) cache tensors, if any.
    pub fn preallocated_cache(&self) -> Option<&(Tensor, Tensor)> {
        self.seq_preallocated_cache.as_ref()
    }

    /// Mutable access to the per-layer normal KV cache.
    pub fn normal_cache(&mut self) -> &mut Vec<Option<KvCache>> {
        &mut self.normal_cache
    }

    /// Mutable access to the draft model's per-layer normal KV cache.
    pub fn normal_draft_cache(&mut self) -> &mut Vec<Option<KvCache>> {
        &mut self.normal_draft_cache
    }

    /// Mutable access to the per-layer (k, v) cache.
    pub fn cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
        &mut self.cache
    }

    /// Mutable access to the draft model's per-layer (k, v) cache.
    pub fn draft_cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
        &mut self.draft_cache
    }

    /// Mutable access to the X-LoRA cache. Panics if this is not an X-LoRA
    /// sequence (see [`Self::is_xlora`]).
    pub fn xlora_cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
        self.xlora_cache.as_mut().expect("No X-LoRA cache.")
    }

    /// Mutable access to the scaling cache tensor.
    pub fn scaling_cache(&mut self) -> &mut Option<Tensor> {
        &mut self.scaling_cache
    }

    /// True iff this sequence carries an X-LoRA cache.
    pub fn is_xlora(&self) -> bool {
        self.xlora_cache.is_some()
    }

    /// A shared handle to this sequence's sampler.
    pub fn sampler(&mut self) -> Arc<Sampler> {
        self.sampler.clone()
    }

    /// Add a some prefill tokens. Only meant for internal speculative decoding usage.
    pub fn set_prefill_toks(&mut self, toks: Vec<u32>) {
        self.prefill_prompt_toks = Some(toks)
    }

    /// Remove the prefill tokens.
    pub fn reset_prefill_toks(&mut self) {
        self.prefill_prompt_toks = None
    }

    /// Internal api to add one raw token.
    pub(crate) fn add_tmp_tok(&mut self, tok: u32) {
        self.is_tmp = true;
        self.tokens.push(tok);
        // Handle possible block engine
        self.custom_metadata.append_token_to_blocks(tok as usize);
    }

    /// Internal api to remove n raw tokens.
    pub(crate) fn remove_tmp_tok(&mut self, n: usize) {
        self.is_tmp = false;
        self.tokens.truncate(self.tokens.len() - n);
        // Handle possible block engine
        self.custom_metadata.remove_tokens_from_blocks(n);
    }

    /// Records a sampled token: updates completion bytes (unless generation
    /// stopped on this very token), logprob bookkeeping, the logical blocks,
    /// and clears any prefill token override.
    pub fn add_token(
        &mut self,
        tok: Logprobs,
        completion_bytes: Vec<u8>,
        is_done: &Option<StopReason>,
    ) {
        let stopped_by_token = matches!(
            is_done,
            Some(StopReason::Eos) | Some(StopReason::StopTok(_))
        );
        if !stopped_by_token {
            // Completion bytes is used to check for stop strings, and as the response buffer.
            // We don't need to add stop tokens to the completion bytes to check for stop strings.
            // And by not adding it here, we can avoid having to delete these tokens from the output.
            self.completion_bytes.extend_from_slice(&completion_bytes);
            self.last_completion_bytes_len = completion_bytes.len();
        }
        self.last_logprob = tok.logprob;
        self.last_is_done = *is_done;

        self.custom_metadata
            .append_token_to_blocks(tok.token as usize);

        self.cumulative_logprob += tok.logprob;
        self.tokens.push(tok.token);
        self.logprobs.push(tok);
        self.reset_prefill_toks();
    }

    /// A clone of the response channel for this request.
    pub fn responder(&self) -> Sender<Response> {
        self.responder.clone()
    }

    /// Creation time passed in at construction.
    pub fn creation_time(&self) -> u64 {
        self.creation_time
    }

    /// Sets the sequence state. Setting `Error` also decrements the group's
    /// expected choice count so the group can still complete.
    pub fn set_state(&self, state: SequenceState) {
        if matches!(state, SequenceState::Error) {
            get_mut_group!(self).n_choices -= 1;
        }
        *self.state.write().unwrap() = state;
    }

    /// The current sequence state.
    pub fn getstate(&self) -> SequenceState {
        *self.state.read().unwrap()
    }

    /// Decides whether generation should stop after sampling `tok`.
    ///
    /// Checks, in order: EOS tokens, external cancelation, request stop
    /// tokens, the request's `max_len`, the model's `max_model_len`, and
    /// finally stop strings in the accumulated completion bytes.
    /// Returns `None` if generation should continue.
    pub fn is_done(
        &self,
        tok: u32,
        eos_tok: Option<&[u32]>,
        max_model_len: usize,
    ) -> Option<StopReason> {
        let is_eos = match eos_tok {
            Some(eos_tok) => eos_tok.iter().any(|t| *t == tok),
            None => false,
        };
        if is_eos {
            Some(StopReason::Eos)
        } else if matches!(
            &*self.state.read().unwrap(),
            SequenceState::Done(StopReason::Canceled)
        ) {
            Some(StopReason::Canceled)
        } else if self.stop_tokens.contains(&tok) {
            Some(StopReason::StopTok(tok))
        } else if self.max_len.is_some()
            && self.tokens.len().saturating_sub(self.prompt_len) == self.max_len.unwrap()
        {
            // add_token was already called
            Some(StopReason::Length(self.max_len.unwrap()))
        } else if self.tokens.len().saturating_sub(self.prompt_len) == max_model_len {
            Some(StopReason::ModelLength(max_model_len))
        } else {
            if !self.stop_strings.is_empty() {
                for (idx, s) in self.stop_strings.iter().enumerate() {
                    if let Some(pos) = galil_seiferas::gs_find(&self.completion_bytes, s.as_bytes())
                    {
                        return Some(StopReason::StopString {
                            stop_string_idx: idx,
                            completion_bytes_pos: pos,
                        });
                    }
                }
            }
            None
        }
    }

    /// Logprobs of all generated tokens so far.
    pub fn logprobs(&self) -> &[Logprobs] {
        &self.logprobs
    }

    /// Whether the request asked for logprobs.
    pub fn return_logprobs(&self) -> bool {
        self.return_logprobs
    }

    /// Number of prompt tokens.
    pub fn prompt_tokens(&self) -> usize {
        self.prompt_len
    }

    /// The request's stop strings.
    pub fn stop_strings(&self) -> &[String] {
        &self.stop_strings
    }

    /// Returns the delta between the last two decoded sequences
    pub fn get_delta(
        &mut self,
    ) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
        let new_decoded = self.peek_delta();
        // Only advance the stream index when a delta was actually produced.
        if matches!(new_decoded, Ok(Some(_))) {
            self.stream_idx = self.completion_bytes.len();
        }
        new_decoded
    }

    /// Peeks at the delta between the last two decoded sequences, but does not advance the stream index.
    pub fn peek_delta(&self) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
        let is_first = self.stream_idx == 0;
        let new_decoded = String::from_utf8_lossy(&self.completion_bytes[self.stream_idx..]);
        // Check if the sequence ends with valid utf8, if not skip it as it probably is a multi token sequence
        if new_decoded.ends_with('�') {
            return Ok(None);
        }

        // The first token usually starts with a space. We don't want to add that to the delta.
        // Since we're using the completion_bytes, we need to take care of that ourselves.
        // Had we used HF's Tokenizer, it would have taken care of that for us.
        if is_first {
            return Ok(Some(new_decoded.trim_start().to_string()));
        }
        Ok(Some(new_decoded.to_string()))
    }

    /// Creation timestamp (ms since the Unix epoch).
    pub fn timestamp(&self) -> u128 {
        self.timestamp
    }

    /// Timestamp of when the prompt pass completed, if it has.
    pub fn prompt_timestamp(&self) -> Option<u128> {
        self.prompt_timestamp
    }

    /// Folds this sequence's timing and token counts into the shared group
    /// statistics. Called whenever a choice/chunk is added to the group.
    fn update_time_info(&self) {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("Time travel has occurred!")
            .as_millis();

        if let Some(ts) = self.prompt_timestamp {
            get_mut_group!(self).total_completion_time += now - ts;
            get_mut_group!(self).total_prompt_time += ts - self.timestamp;
        }

        get_mut_group!(self).total_time += now - self.timestamp;

        get_mut_group!(self).total_prompt_toks = self.prompt_len;
        get_mut_group!(self).total_toks = self.len();
    }

    /// Adds a finished image choice to the group and updates timing stats.
    pub fn add_image_choice_to_group(&self, choice: ImageChoice) {
        get_mut_group!(self).image_choices.push(choice);
        self.update_time_info();
    }

    /// Adds a finished chat choice to the group and updates timing stats.
    pub fn add_choice_to_group(&self, choice: Choice) {
        get_mut_group!(self).choices.push(choice);
        self.update_time_info();
    }

    /// Adds raw logits chunks (with the token stream) to the group and
    /// updates timing stats.
    pub fn add_raw_choice_to_group(&self, logit_chunks: Vec<Tensor>) {
        get_mut_group!(self)
            .raw_choices
            .push((logit_chunks, self.tokens.clone()));
        self.update_time_info();
    }

    /// Adds a finished completion choice to the group (wrapping its text in
    /// the request's prefix/suffix) and updates timing stats.
    pub fn add_completion_choice_to_group(&self, mut choice: CompletionChoice) {
        choice.text = format!(
            "{}{}{}",
            self.prefix.as_deref().unwrap_or(""),
            choice.text,
            self.suffix.as_deref().unwrap_or("")
        );
        get_mut_group!(self)
            .completion_choices
            .push((self.cumulative_logprob, choice));
        self.update_time_info();
    }

    /// Index of this choice within the request's response.
    pub fn get_response_index(&self) -> usize {
        self.response_index
    }

    /// Locks and returns the shared sequence group.
    pub fn get_mut_group(&self) -> MutexGuard<'_, SequenceGroup> {
        get_mut_group!(self)
    }

    /// Adds a streaming chat chunk to the group and updates timing stats.
    pub fn add_streaming_chunk_choice_to_group(&self, chunk: ChunkChoice) {
        get_mut_group!(self).chat_streaming_chunks.push(chunk);
        self.update_time_info();
    }

    /// Adds a streaming completion chunk to the group.
    pub fn add_streaming_completion_chunk_choice_to_group(&self, chunk: CompletionChunkChoice) {
        get_mut_group!(self).completion_streaming_chunks.push(chunk);
    }

    /// The adapter names requested for this sequence, if any.
    pub fn get_adapters(&self) -> Option<Vec<String>> {
        self.adapters.clone()
    }

    /// Takes ownership of the input images, leaving `None` behind.
    pub fn take_images(&mut self) -> Option<Vec<image::DynamicImage>> {
        self.input_images.take()
    }

    /// Clones the input images, leaving them in place.
    pub fn clone_images(&mut self) -> Option<Vec<image::DynamicImage>> {
        self.input_images.clone()
    }

    /// Borrows the input images, if any.
    pub fn images(&self) -> Option<&[image::DynamicImage]> {
        self.input_images.as_deref()
    }

    /// True iff there is at least one input image.
    pub fn has_images(&self) -> bool {
        self.input_images
            .as_ref()
            .is_some_and(|images| !images.is_empty())
    }

    /// Response format requested for image generation, if any.
    pub fn image_gen_response_format(&self) -> Option<ImageGenerationResponseFormat> {
        self.image_gen_response_format
    }

    /// How this sequence steps through the pipeline.
    pub fn sequence_stepping_type(&self) -> &SeqStepType {
        &self.sequence_stepping_type
    }

    /// Diffusion generation parameters, if this is a diffusion request.
    pub fn get_diffusion_diffusion_params(&self) -> Option<DiffusionGenerationParams> {
        self.diffusion_params.clone()
    }

    /// EOS tokens for this sequence.
    pub fn eos_tokens(&self) -> &[u32] {
        &self.eos_tokens
    }
}
824
/// Shared state for all choices of a single request: accumulated choices,
/// streaming chunks, and aggregate timing/token statistics.
pub struct SequenceGroup {
    n_choices: usize, // The target number of choices to return. Can be decreased if an error is thrown.
    best_of: Option<usize>, // Top n seqs based on cumulative logprobs.
    pub total_prompt_toks: usize,
    pub total_toks: usize,
    // All times below are in milliseconds (see `Sequence::update_time_info`).
    pub total_prompt_time: u128,
    pub total_time: u128,
    pub total_completion_time: u128,
    choices: Vec<Choice>,
    image_choices: Vec<ImageChoice>,
    // (logits chunks, token stream) pairs for raw-logits requests.
    raw_choices: Vec<(Vec<Tensor>, Vec<u32>)>,
    // (cumulative logprob, choice) pairs; logprob is used for `best_of`.
    completion_choices: Vec<(f32, CompletionChoice)>,
    pub chat_streaming_chunks: Vec<ChunkChoice>,
    pub completion_streaming_chunks: Vec<CompletionChunkChoice>,
    pub is_streaming: bool,
    pub is_chat: bool,
}
842
843impl SequenceGroup {
844    pub fn new(
845        n_choices: usize,
846        is_streaming: bool,
847        is_chat: bool,
848        best_of: Option<usize>,
849    ) -> Self {
850        Self {
851            choices: Vec::new(),
852            image_choices: Vec::new(),
853            raw_choices: Vec::new(),
854            completion_choices: Vec::new(),
855            n_choices,
856            total_prompt_toks: 0,
857            total_toks: 0,
858            total_prompt_time: 0,
859            total_time: 0,
860            total_completion_time: 0,
861            chat_streaming_chunks: Vec::new(),
862            completion_streaming_chunks: Vec::new(),
863            is_streaming,
864            is_chat,
865            best_of,
866        }
867    }
868
869    pub fn get_choices(&self) -> &[Choice] {
870        &self.choices
871    }
872
873    /// This may apply the best_of.
874    pub fn get_completion_choices(&self) -> Vec<CompletionChoice> {
875        if let Some(best_of) = self.best_of {
876            let mut choices = self.completion_choices.clone();
877            // Sort by descending logprobs
878            choices.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("No ordering."));
879            choices
880                .into_iter()
881                .take(best_of)
882                .map(|(_, x)| x)
883                .collect::<Vec<_>>()
884        } else {
885            self.completion_choices
886                .clone()
887                .into_iter()
888                .map(|(_, x)| x)
889                .collect::<Vec<_>>()
890        }
891    }
892
893    pub fn get_image_choices(&self) -> &[ImageChoice] {
894        &self.image_choices
895    }
896
897    pub fn get_usage(&self) -> Usage {
898        #[allow(clippy::cast_precision_loss)]
899        Usage {
900            completion_tokens: self.total_toks - self.total_prompt_toks,
901            prompt_tokens: self.total_prompt_toks,
902            total_tokens: self.total_toks,
903            avg_tok_per_sec: (self.total_toks as f32 / self.total_time as f32) * 1000.,
904            avg_prompt_tok_per_sec: (self.total_prompt_toks as f32 / self.total_prompt_time as f32)
905                * 1000.,
906            avg_compl_tok_per_sec: ((self.total_toks - self.total_prompt_toks) as f32
907                / self.total_completion_time as f32)
908                * 1000.,
909            total_time_sec: self.total_time as f32 / 1000.,
910            total_completion_time_sec: self.total_completion_time as f32 / 1000.,
911            total_prompt_time_sec: self.total_prompt_time as f32 / 1000.,
912        }
913    }
914
915    pub async fn maybe_send_chat_done_response(
916        &self,
917        response: ChatCompletionResponse,
918        sender: Sender<Response>,
919    ) -> Result<(), SendError<Response>> {
920        if self.choices.len() == self.n_choices {
921            sender.send(Response::Done(response)).await?;
922        }
923
924        Ok(())
925    }
926
927    pub async fn maybe_send_raw_done_response(
928        &self,
929        sender: Sender<Response>,
930    ) -> Result<(), SendError<Response>> {
931        if self.raw_choices.len() == self.n_choices {
932            assert_eq!(self.raw_choices.len(), 1);
933            let (logits_chunks, tokens) = self.raw_choices[0].clone();
934            sender
935                .send(Response::Raw {
936                    logits_chunks,
937                    tokens,
938                })
939                .await?;
940        }
941
942        Ok(())
943    }
944
945    pub async fn maybe_send_image_gen_response(
946        &self,
947        response: ImageGenerationResponse,
948        sender: Sender<Response>,
949    ) -> Result<(), SendError<Response>> {
950        if self.image_choices.len() == self.n_choices {
951            sender.send(Response::ImageGeneration(response)).await?;
952        }
953
954        Ok(())
955    }
956
957    pub async fn maybe_send_streaming_response(
958        &mut self,
959        seq: &Sequence,
960        model: String,
961        usage_opt: Option<Usage>,
962    ) -> Result<(), Box<SendError<Response>>> {
963        if self.chat_streaming_chunks.len() == self.n_choices && self.is_streaming {
964            let mut swap_streaming_chunks = vec![];
965
966            std::mem::swap(&mut swap_streaming_chunks, &mut self.chat_streaming_chunks);
967
968            seq.responder()
969                .send(Response::Chunk(ChatCompletionChunkResponse {
970                    id: seq.id.to_string(),
971                    choices: swap_streaming_chunks,
972                    created: seq.timestamp,
973                    model: model.clone(),
974                    system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
975                    object: "chat.completion.chunk".to_string(),
976                    usage: usage_opt,
977                }))
978                .await?;
979        } else if self.completion_streaming_chunks.len() == self.n_choices && self.is_streaming {
980            let mut swap_streaming_chunks = vec![];
981
982            std::mem::swap(
983                &mut swap_streaming_chunks,
984                &mut self.completion_streaming_chunks,
985            );
986
987            seq.responder()
988                .send(Response::CompletionChunk(CompletionChunkResponse {
989                    id: seq.id.to_string(),
990                    choices: swap_streaming_chunks,
991                    created: seq.timestamp,
992                    model: model.clone(),
993                    system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
994                    object: "text_completion".to_string(),
995                }))
996                .await?;
997        }
998        Ok(())
999    }
1000
1001    pub async fn maybe_send_completion_done_response(
1002        &self,
1003        response: CompletionResponse,
1004        sender: Sender<Response>,
1005    ) -> Result<(), Box<SendError<Response>>> {
1006        if self.completion_choices.len() == self.n_choices {
1007            sender.send(Response::CompletionDone(response)).await?;
1008        }
1009        Ok(())
1010    }
1011}