// mistralrs_core/sequence.rs

use crate::{
    get_mut_arcmutex, get_mut_group,
    paged_attention::PhysicalTokenBlock,
    pipeline::{text_models_inputs_processor::PagedAttentionMeta, LayerCaches},
    response::{ChatCompletionChunkResponse, Choice, ChunkChoice, Response, SYSTEM_FINGERPRINT},
    sampler::{Logprobs, Sampler},
    AudioInput, ChatCompletionResponse, Usage,
};
use crate::{
    paged_attention::{BlockEngineSequence, LogicalTokenBlock},
    pipeline::{DiffusionGenerationParams, KvCache},
    response::CompletionChoice,
    tools::ToolCallingMatcher,
    CompletionChunkChoice, CompletionChunkResponse, CompletionResponse, ImageChoice,
    ImageGenerationResponse, ImageGenerationResponseFormat,
};
use candle_core::Tensor;
use std::{
    fmt::Display,
    hash::{DefaultHasher, Hash, Hasher},
    sync::{Arc, RwLock},
    time::{SystemTime, UNIX_EPOCH},
};
use tokio::sync::{
    mpsc::{error::SendError, Sender},
    Mutex, MutexGuard,
};

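/// Why a sequence stopped generating. The `Display` impl yields the finish-reason
/// string reported in responses (e.g. "stop", "length", "tool_calls").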
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum StopReason {
    Eos,
    StopTok(u32),
    Length(usize),
    ModelLength(usize),
    StopString {
        stop_string_idx: usize,
        completion_bytes_pos: usize,
    },
    Canceled,
    GeneratedImage,
    GeneratedSpeech,
    ToolCalls,
}

impl Display for StopReason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            StopReason::Eos => write!(f, "stop"),
            StopReason::Length(_) | StopReason::ModelLength(_) => write!(f, "length"),
            StopReason::StopTok(_) | StopReason::StopString { .. } => write!(f, "stop"),
            StopReason::Canceled => write!(f, "canceled"),
            StopReason::GeneratedImage => write!(f, "generated_image"),
            StopReason::GeneratedSpeech => write!(f, "generated_speech"),
            StopReason::ToolCalls => write!(f, "tool_calls"),
        }
    }
}

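/// Lifecycle and scheduling state of a sequence.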
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum SequenceState {
    Done(StopReason),
    RunningPrompt,
    RunningCompletion,
    Waiting,
    Error,
    RunningPrefillPrompt,
    // For PagedAttention:
    FinishedAborted,
    FinishedIgnored,
    Swapped,
}

pub enum SequenceRecognizer {
    Llguidance(Box<llguidance::Matcher>),
    None,
}

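/// Backend-specific per-sequence bookkeeping. Only PagedAttention needs extra state:
/// the logical token blocks and, for prefix-cached prefills, already-allocated physical blocks.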
enum SequenceCustomMetadata {
    PagedAttention {
        logical_token_blocks: Vec<LogicalTokenBlock>,
        physical_blocks_prefill: Option<Vec<Arc<PhysicalTokenBlock>>>,
        block_size: usize,
    },
    None,
}

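/// How many new logical blocks (0 or 1) to reserve when appending one more token.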
macro_rules! blocks_to_add_new_tok {
    ($logical_token_blocks:expr) => {{
        let last = $logical_token_blocks.last();
        if !last.is_some_and(|last| last.is_full() || last.is_empty()) {
            // If we have space
            0
        } else {
            1
        }
    }};
}

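/// Append a token id to the last logical block, creating a block if none exists and
/// pushing a fresh (empty) block once the last one becomes full.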
pub(crate) fn util_append_token_to_blocks(
    tok: usize,
    logical_token_blocks: &mut Vec<LogicalTokenBlock>,
    block_size: usize,
) {
    let last = logical_token_blocks.last_mut();
    match last {
        Some(last) => {
            last.append_token_id(tok);
        }
        None => {
            logical_token_blocks.push(LogicalTokenBlock::new(block_size));
            logical_token_blocks
                .last_mut()
                .unwrap()
                .append_token_id(tok);
        }
    }
    if logical_token_blocks.last().as_ref().unwrap().is_full() {
        logical_token_blocks.push(LogicalTokenBlock::new(block_size));
    }
}

impl SequenceCustomMetadata {
    fn append_token_to_blocks(&mut self, tok: usize) {
        match self {
            Self::PagedAttention {
                logical_token_blocks,
                physical_blocks_prefill: _,
                block_size,
            } => {
                util_append_token_to_blocks(tok, logical_token_blocks, *block_size);
            }
            Self::None => (),
        }
    }

    fn pop_token_from_blocks(&mut self) {
        match self {
            Self::PagedAttention {
                logical_token_blocks,
                physical_blocks_prefill: _,
                block_size: _,
            } => {
                let last = logical_token_blocks.last_mut().unwrap();
                last.pop_token();
            }
            Self::None => (),
        }
    }

    fn append_tokens_to_blocks(&mut self, toks: Vec<usize>) {
        for tok in toks {
            self.append_token_to_blocks(tok);
        }
    }

    fn remove_tokens_from_blocks(&mut self, n: usize) {
        for _ in 0..n {
            self.pop_token_from_blocks();
        }
    }
}

#[derive(Clone, Copy)]
pub enum SeqStepType {
    PromptAndDecode,
    OneShot,
}

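/// Image inputs plus one hash per image. The hashes are retained even when older images
/// are trimmed, since the prefix cacher needs them later.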
pub struct SequenceImages {
    images: Vec<image::DynamicImage>,
    hashes: Vec<u64>,
}

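/// Audio inputs plus one hash per clip (hashed over the samples and sample rate),
/// used analogously to `SequenceImages`.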
#[derive(Clone)]
pub struct SequenceAudios {
    audios: Vec<AudioInput>,
    hashes: Vec<u64>,
}

impl SequenceAudios {
    fn new(input_audios: Vec<AudioInput>) -> Self {
        let hashes = input_audios.iter().map(|a| {
            let mut hasher = DefaultHasher::new();
            for s in &a.samples {
                s.to_bits().hash(&mut hasher);
            }
            a.sample_rate.hash(&mut hasher);
            hasher.finish()
        });
        Self {
            hashes: hashes.collect(),
            audios: input_audios,
        }
    }

    fn clone_audios(&self) -> Vec<AudioInput> {
        self.audios.clone()
    }

    fn audios(&self) -> &[AudioInput] {
        &self.audios
    }

    fn audios_mut(&mut self) -> &mut Vec<AudioInput> {
        &mut self.audios
    }

    fn hashes(&self) -> &[u64] {
        &self.hashes
    }

    fn keep_num_audios(&mut self, audios_to_keep: usize) {
        if self.audios.len() > audios_to_keep {
            let start = self.audios.len() - audios_to_keep;
            self.audios = self.audios[start..].to_vec();
            // Do not do this because we need all the hashes later in the prefix cacher.
            // self.hashes = self.hashes[start..].to_vec();
        }
    }
}

impl SequenceImages {
    fn new(input_images: Vec<image::DynamicImage>) -> Self {
        let hashes = input_images.iter().map(|x| {
            let mut hasher = DefaultHasher::new();
            x.as_bytes().hash(&mut hasher);
            hasher.finish()
        });
        Self {
            hashes: hashes.collect(),
            images: input_images,
        }
    }

    fn clone_images(&self) -> Vec<image::DynamicImage> {
        self.images.clone()
    }

    fn images(&self) -> &[image::DynamicImage] {
        &self.images
    }

    fn images_mut(&mut self) -> &mut Vec<image::DynamicImage> {
        &mut self.images
    }

    fn hashes(&self) -> &[u64] {
        &self.hashes
    }

    fn keep_num_images(&mut self, images_to_keep: usize) {
        if self.images.len() > images_to_keep {
            let start = self.images.len() - images_to_keep;
            self.images = self.images[start..].to_vec();
            // Do not do this because we need all the hashes later in the prefix cacher.
            // self.hashes = self.hashes[start..].to_vec();
        }
    }
}

/// Holds all multimodal (vision/audio/diffusion) data for a `Sequence`.
pub struct MultimodalData {
    pub input_images: Option<SequenceImages>,
    pub input_audios: Option<SequenceAudios>,
    pub cached_pixel_values: Option<Tensor>,
    pub cached_img_thw: Option<Tensor>,
    pub cached_vid_thw: Option<Tensor>,
    pub has_changed_prompt: bool,
    pub image_gen_response_format: Option<ImageGenerationResponseFormat>,
    pub diffusion_params: Option<DiffusionGenerationParams>,
}

impl MultimodalData {
    pub fn new(
        input_images: Option<Vec<image::DynamicImage>>,
        input_audios: Option<Vec<AudioInput>>,
        image_gen_response_format: Option<ImageGenerationResponseFormat>,
        diffusion_params: Option<DiffusionGenerationParams>,
    ) -> Self {
        MultimodalData {
            input_images: input_images.map(SequenceImages::new),
            input_audios: input_audios.map(SequenceAudios::new),
            cached_pixel_values: None,
            cached_img_thw: None,
            cached_vid_thw: None,
            has_changed_prompt: false,
            image_gen_response_format,
            diffusion_params,
        }
    }

    pub fn take_images(&mut self) -> Option<Vec<image::DynamicImage>> {
        if self.has_changed_prompt {
            if let Some(input_images) = self.input_images.as_mut() {
                let mut images = Vec::new();
                std::mem::swap(&mut images, input_images.images_mut());
                Some(images)
            } else {
                None
            }
        } else {
            self.input_images.as_ref().map(|imgs| imgs.clone_images())
        }
    }

    pub fn clone_images(&self) -> Option<Vec<image::DynamicImage>> {
        self.input_images.as_ref().map(|imgs| imgs.clone_images())
    }

    pub fn images(&self) -> Option<&[image::DynamicImage]> {
        self.input_images.as_ref().map(|imgs| imgs.images())
    }

    pub fn image_hashes(&self) -> Option<&[u64]> {
        self.input_images.as_ref().map(|imgs| imgs.hashes())
    }

    pub fn has_images(&self) -> bool {
        self.input_images
            .as_ref()
            .is_some_and(|imgs| !imgs.images().is_empty())
    }

    pub fn take_audios(&mut self) -> Option<Vec<AudioInput>> {
        if self.has_changed_prompt {
            if let Some(input_audios) = self.input_audios.as_mut() {
                let mut audios = Vec::new();
                std::mem::swap(&mut audios, input_audios.audios_mut());
                Some(audios)
            } else {
                None
            }
        } else {
            self.input_audios.as_ref().map(|a| a.clone_audios())
        }
    }

    pub fn clone_audios(&self) -> Option<Vec<AudioInput>> {
        self.input_audios.as_ref().map(|a| a.clone_audios())
    }

    pub fn audios(&self) -> Option<&[AudioInput]> {
        self.input_audios.as_ref().map(|a| a.audios())
    }

    pub fn audio_hashes(&self) -> Option<&[u64]> {
        self.input_audios.as_ref().map(|a| a.hashes())
    }

    pub fn has_audios(&self) -> bool {
        self.input_audios
            .as_ref()
            .is_some_and(|a| !a.audios().is_empty())
    }

    pub fn keep_num_audios(&mut self, audios_to_keep: usize) {
        if let Some(auds) = self.input_audios.as_mut() {
            auds.keep_num_audios(audios_to_keep)
        }
    }

    pub fn keep_num_images(&mut self, images_to_keep: usize) {
        if let Some(imgs) = self.input_images.as_mut() {
            imgs.keep_num_images(images_to_keep)
        }
    }

    pub fn image_gen_response_format(&self) -> Option<ImageGenerationResponseFormat> {
        self.image_gen_response_format
    }

    pub fn diffusion_params(&self) -> Option<DiffusionGenerationParams> {
        self.diffusion_params.clone()
    }
}

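/// A single generation request as tracked by the engine: the prompt and generated tokens,
/// sampling configuration, per-layer KV caches, multimodal inputs, and the channel used to
/// send responses back to the caller.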
pub struct Sequence {
    // Metadata, const
    id: usize,
    prompt_len: usize,
    max_len: Option<usize>,
    timestamp: u128,
    sampler: Arc<Sampler>,
    stop_tokens: Vec<u32>,
    stop_strings: Vec<String>,
    return_logprobs: bool,
    responder: Sender<Response>,
    response_index: usize,
    creation_time: u64,
    prompt: String,
    sequence_stepping_type: SeqStepType,
    pub(crate) return_raw_logits: bool,
    token_offset: usize,
    eos_tokens: Vec<u32>,

    // Multimodal data (images, audio, diffusion settings, pixel caches)
    pub multimodal: MultimodalData,

    // Completion requests
    suffix: Option<String>,
    prefix: Option<String>,

    // Speculative
    is_tmp: bool,

    // Prefix caching
    prefill_prompt_toks: Option<Vec<u32>>,

    // Cache
    normal_cache: Vec<Option<KvCache>>,
    normal_draft_cache: Vec<Option<KvCache>>,
    scaling_cache: Option<Tensor>,
    cache: LayerCaches,
    draft_cache: LayerCaches,
    xlora_cache: Option<LayerCaches>,

    // Preallocated KV cache (k,v)
    seq_preallocated_cache: Option<(Tensor, Tensor)>,

    // Mutables
    tokens: Vec<u32>,
    logprobs: Vec<Logprobs>,
    cumulative_logprob: f32,
    last_logprob: f32,
    last_completion_bytes_len: usize,
    last_is_done: Option<StopReason>,
    completion_bytes: Vec<u8>,
    stream_idx: usize,
    pub recognizer: SequenceRecognizer,
    scheduling_urgency: usize, // The number of passes since scheduling
    waitlisted_count: usize, // Used in PagedAttention to alert the user when a sequence repeatedly cannot be scheduled

    // GPU things
    pub prompt_tok_per_sec: f32,
    pub prompt_timestamp: Option<u128>,
    pub total_prompt_time: Option<u128>,
    group: Arc<Mutex<SequenceGroup>>,
    state: RwLock<SequenceState>,

    // Custom backend metadata
    custom_metadata: SequenceCustomMetadata,

    // Tool calls
    pub tools: Option<Arc<ToolCallingMatcher>>,
}

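// PagedAttention integration: expose the data the block engine needs to allocate and
// track blocks for this sequence.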
impl BlockEngineSequence for Sequence {
    fn blocks_to_add_new_tok(&self) -> usize {
        match &self.custom_metadata {
            SequenceCustomMetadata::PagedAttention {
                logical_token_blocks,
                physical_blocks_prefill: _,
                block_size: _,
            } => {
                blocks_to_add_new_tok!(logical_token_blocks)
            }
            SequenceCustomMetadata::None => unreachable!(),
        }
    }

    fn get_id(&self) -> usize {
        self.id
    }

    fn logical_token_blocks(&self) -> &[LogicalTokenBlock] {
        match &self.custom_metadata {
            SequenceCustomMetadata::PagedAttention {
                logical_token_blocks,
                physical_blocks_prefill: _,
                block_size: _,
            } => logical_token_blocks,
            SequenceCustomMetadata::None => unreachable!(),
        }
    }

    fn take_physical_blocks_prefill(&mut self) -> Option<Vec<Arc<PhysicalTokenBlock>>> {
        match &mut self.custom_metadata {
            SequenceCustomMetadata::PagedAttention {
                logical_token_blocks: _,
                physical_blocks_prefill,
                block_size: _,
            } => physical_blocks_prefill.take(),
            SequenceCustomMetadata::None => None,
        }
    }

    fn increment_waitlist_count(&mut self) -> usize {
        let prev = self.waitlisted_count;
        self.waitlisted_count += 1;
        prev
    }
}

impl Sequence {
    #[allow(clippy::too_many_arguments)]
    pub fn new_waiting(
        tokens: Vec<u32>,
        prompt: String,
        id: usize,
        timestamp: u128,
        layers: usize,
        responder: Sender<Response>,
        sampler: Sampler,
        stop_tokens: Vec<u32>,
        stop_strings: Vec<String>,
        max_len: Option<usize>,
        return_logprobs: bool,
        is_xlora: bool,
        group: Arc<Mutex<SequenceGroup>>,
        response_index: usize,
        creation_time: u64,
        recognizer: SequenceRecognizer,
        suffix: Option<String>,
        prefix: Option<String>,
        input_images: Option<Vec<image::DynamicImage>>,
        input_audios: Option<Vec<AudioInput>>,
        // Paged attention
        block_size: Option<usize>,
        //
        tools: Option<Arc<ToolCallingMatcher>>,
        image_gen_response_format: Option<ImageGenerationResponseFormat>,
        sequence_stepping_type: SeqStepType,
        diffusion_params: Option<DiffusionGenerationParams>,
        // Preallocated KV cache (k,v)
        seq_preallocated_cache: Option<(Tensor, Tensor)>,
        //
        return_raw_logits: bool,
        eos_tokens: Vec<u32>,
    ) -> Self {
        let prompt_len = tokens.len();
        let mut custom_metadata = if let Some(block_size) = block_size {
            SequenceCustomMetadata::PagedAttention {
                logical_token_blocks: Vec::new(),
                physical_blocks_prefill: None,
                block_size,
            }
        } else {
            SequenceCustomMetadata::None
        };
        custom_metadata
            .append_tokens_to_blocks(tokens.iter().map(|x| *x as usize).collect::<Vec<_>>());
        Self {
            tokens,
            prompt,
            logprobs: Vec::new(),
            prompt_len,
            id,
            timestamp,
            state: RwLock::new(SequenceState::Waiting),
            normal_cache: vec![None; layers],
            normal_draft_cache: vec![None; layers],
            cache: vec![None; layers],
            draft_cache: vec![None; layers],
            xlora_cache: if is_xlora {
                Some(vec![None; layers])
            } else {
                None
            },
            seq_preallocated_cache,
            responder,
            sampler: sampler.into(),
            stop_tokens,
            stop_strings,
            max_len,
            return_logprobs,
            prompt_tok_per_sec: 0.,
            prompt_timestamp: None,
            group,
            scaling_cache: None,
            response_index,
            creation_time,
            recognizer,
            prefill_prompt_toks: None,
            suffix,
            prefix,
            cumulative_logprob: 0.,
            completion_bytes: Vec::new(),
            stream_idx: 0,
            last_completion_bytes_len: 0,
            last_logprob: 0.0,
            last_is_done: None,
            is_tmp: false,
            scheduling_urgency: 0,
            // Multimodal data
            multimodal: MultimodalData::new(
                input_images,
                input_audios,
                image_gen_response_format,
                diffusion_params,
            ),
            custom_metadata,
            tools,
            sequence_stepping_type,
            return_raw_logits,
            token_offset: 0,
            eos_tokens,
            total_prompt_time: None,
            waitlisted_count: 0,
        }
    }

    pub fn add_urgency(mut self) -> Self {
        self.scheduling_urgency += 1;
        self
    }

    pub fn reset_urgency(mut self) -> Self {
        self.scheduling_urgency = 0;
        self
    }

    /// Simple metric: (scheduling urgency) + log2(length).
    /// Takes into account urgency (scales linearly) and length (scales logarithmically).
    /// Scheduling urgency is the number of scheduling passes in which this sequence was not scheduled.
    pub fn compute_priority(&self) -> f64 {
        #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
        (self.scheduling_urgency as f64) + (self.len() as f64).log2()
    }

    pub fn prefill_v2_normal(
        mut self,
        cache: Vec<Option<KvCache>>,
        toks: Vec<u32>,
        offset: usize,
    ) -> Self {
        self.normal_cache = cache;
        self.prefill_prompt_toks = Some(toks);
        self.set_state(SequenceState::RunningPrefillPrompt);
        self.token_offset = offset;
        self
    }

    pub fn prefill_v2_paged(
        mut self,
        logical_blocks: Vec<LogicalTokenBlock>,
        physical_blocks: Vec<Arc<PhysicalTokenBlock>>,
        toks: Vec<u32>,
        offset: usize,
    ) -> Self {
        self.prefill_prompt_toks = Some(toks);
        self.set_state(SequenceState::RunningPrefillPrompt);
        self.token_offset = offset;

        if let SequenceCustomMetadata::PagedAttention {
            logical_token_blocks,
            physical_blocks_prefill,
            block_size: _,
        } = &mut self.custom_metadata
        {
            *logical_token_blocks = logical_blocks;
            *physical_blocks_prefill = Some(physical_blocks);
        }

        self
    }

    /// The number of tokens in this sequence. If a KV cache is present, the length is
    /// derived from the cached sequence dimension instead of the token list.
    pub fn len(&self) -> usize {
        if let Some(toks) = &self.prefill_prompt_toks {
            return toks.len();
        }
        if self.is_tmp {
            return self.tokens.len();
        }
        // Use the X-LoRA cache first because of non-granular caching
        if self.xlora_cache.as_ref().is_some_and(|c| c[0].is_some()) {
            self.xlora_cache.as_ref().unwrap()[0]
                .as_ref()
                .unwrap()
                .0
                .dims()[2]
                + 1
        } else if let Some((_, x)) = &self.cache[0] {
            x.dims()[2] + 1
        } else {
            self.tokens.len()
        }
    }

    pub fn id(&self) -> &usize {
        &self.id
    }

    pub fn is_running(&self) -> bool {
        matches!(
            *self.state.read().unwrap(),
            SequenceState::RunningCompletion | SequenceState::RunningPrompt // | SequenceState::RunningPrefillPrompt
        )
    }

    pub fn is_completion(&self) -> bool {
        matches!(
            *self.state.read().unwrap(),
            SequenceState::RunningCompletion
        )
    }

    pub fn is_prompt(&self) -> bool {
        matches!(
            *self.state.read().unwrap(),
            SequenceState::RunningPrompt | SequenceState::RunningPrefillPrompt
        )
    }

    pub fn is_waiting(&self) -> bool {
        matches!(*self.state.read().unwrap(), SequenceState::Waiting)
    }

    pub fn is_finished_paged_attn(&self) -> bool {
        matches!(
            *self.state.read().unwrap(),
            SequenceState::FinishedAborted
                | SequenceState::FinishedIgnored
                | SequenceState::Done(_)
        )
    }

    pub fn get_toks(&self) -> &[u32] {
        if let Some(toks) = &self.prefill_prompt_toks {
            return toks;
        }
        &self.tokens
    }

    pub fn get_initial_prompt(&self) -> &str {
        &self.prompt
    }

    pub fn set_initial_prompt(&mut self, new: String) {
        self.prompt = new;
    }

    pub fn token_offset(&self) -> usize {
        self.token_offset
    }

    /// Replace the tokens, resetting `prompt_len` and reallocating PagedAttention blocks if needed.
    pub(crate) fn set_toks_and_reallocate(
        &mut self,
        toks: Vec<u32>,
        paged_attn_metadata: Option<&mut PagedAttentionMeta>,
    ) {
        self.tokens.clone_from(&toks);
        self.prompt_len = self.tokens.len();
        // Handle possible block engine
        match &mut self.custom_metadata {
            SequenceCustomMetadata::PagedAttention {
                logical_token_blocks,
                physical_blocks_prefill: _,
                block_size: _,
            } => {
                logical_token_blocks.clear();
            }
            SequenceCustomMetadata::None => (),
        }
        self.custom_metadata
            .append_tokens_to_blocks(toks.iter().map(|x| *x as usize).collect::<Vec<_>>());

        if let Some(metadata) = paged_attn_metadata {
            // Free and then reallocate as appropriate
            get_mut_arcmutex!(metadata.block_engine).free_sequence(*self.id());
            get_mut_arcmutex!(metadata.block_engine).allocate(self);
        }
    }

    pub fn completion_bytes(&self) -> &[u8] {
        &self.completion_bytes
    }

    pub fn preallocated_cache(&self) -> Option<&(Tensor, Tensor)> {
        self.seq_preallocated_cache.as_ref()
    }

    pub fn normal_cache(&mut self) -> &mut Vec<Option<KvCache>> {
        &mut self.normal_cache
    }

    pub fn normal_draft_cache(&mut self) -> &mut Vec<Option<KvCache>> {
        &mut self.normal_draft_cache
    }

    pub fn cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
        &mut self.cache
    }

    pub fn draft_cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
        &mut self.draft_cache
    }

    pub fn xlora_cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
        self.xlora_cache.as_mut().expect("No X-LoRA cache.")
    }

    pub fn scaling_cache(&mut self) -> &mut Option<Tensor> {
        &mut self.scaling_cache
    }

    pub fn is_xlora(&self) -> bool {
        self.xlora_cache.is_some()
    }

    pub fn sampler(&mut self) -> Arc<Sampler> {
        self.sampler.clone()
    }

    /// Add some prefill tokens. Only meant for internal speculative decoding usage.
    pub fn set_prefill_toks(&mut self, toks: Vec<u32>) {
        self.prefill_prompt_toks = Some(toks)
    }

    /// Remove the prefill tokens.
    pub fn reset_prefill_toks(&mut self) {
        self.prefill_prompt_toks = None
    }

    /// Internal API to add one raw token.
    pub(crate) fn add_tmp_tok(&mut self, tok: u32) {
        self.is_tmp = true;
        self.tokens.push(tok);
        // Handle possible block engine
        self.custom_metadata.append_token_to_blocks(tok as usize);
    }

    /// Internal API to remove n raw tokens.
    pub(crate) fn remove_tmp_tok(&mut self, n: usize) {
        self.is_tmp = false;
        self.tokens.truncate(self.tokens.len() - n);
        // Handle possible block engine
        self.custom_metadata.remove_tokens_from_blocks(n);
    }

    pub fn add_token(
        &mut self,
        tok: Logprobs,
        completion_bytes: Vec<u8>,
        is_done: &Option<StopReason>,
    ) {
        let stopped_by_token = matches!(
            is_done,
            Some(StopReason::Eos) | Some(StopReason::StopTok(_))
        );
        if !stopped_by_token {
            // The completion bytes are used to check for stop strings, and as the response buffer.
            // We don't need to add stop tokens to the completion bytes to check for stop strings.
            // By not adding them here, we avoid having to delete these tokens from the output.
            self.completion_bytes.extend_from_slice(&completion_bytes);
            self.last_completion_bytes_len = completion_bytes.len();
        }
        self.last_logprob = tok.logprob;
        self.last_is_done = *is_done;

        self.custom_metadata
            .append_token_to_blocks(tok.token as usize);

        self.cumulative_logprob += tok.logprob;
        self.tokens.push(tok.token);
        self.logprobs.push(tok);
        self.reset_prefill_toks();
    }

    pub fn responder(&self) -> Sender<Response> {
        self.responder.clone()
    }

    pub fn creation_time(&self) -> u64 {
        self.creation_time
    }

    pub fn set_state(&self, state: SequenceState) {
        if matches!(state, SequenceState::Error) {
            get_mut_group!(self).n_choices = get_mut_group!(self).n_choices.saturating_sub(1);
        }
        *self.state.write().unwrap() = state;
    }

    pub fn getstate(&self) -> SequenceState {
        *self.state.read().unwrap()
    }

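    /// Decide whether the sequence has finished after sampling `tok`, checking EOS and stop
    /// tokens, cancellation, per-request and model length limits, and stop strings in the
    /// completion bytes.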
    pub fn is_done(
        &self,
        tok: u32,
        eos_tok: Option<&[u32]>,
        max_model_len: usize,
    ) -> Option<StopReason> {
        let is_eos = match eos_tok {
            Some(eos_tok) => eos_tok.contains(&tok),
            None => false,
        };
        if is_eos {
            Some(StopReason::Eos)
        } else if matches!(
            &*self.state.read().unwrap(),
            SequenceState::Done(StopReason::Canceled)
        ) {
            Some(StopReason::Canceled)
        } else if self.stop_tokens.contains(&tok) {
            Some(StopReason::StopTok(tok))
        } else if self.max_len.is_some()
            && self.tokens.len().saturating_sub(self.prompt_len) == self.max_len.unwrap()
        {
            // add_token was already called
            Some(StopReason::Length(self.max_len.unwrap()))
        } else if self.tokens.len().saturating_sub(self.prompt_len) == max_model_len {
            Some(StopReason::ModelLength(max_model_len))
        } else {
            if !self.stop_strings.is_empty() {
                for (idx, s) in self.stop_strings.iter().enumerate() {
                    if let Some(pos) = galil_seiferas::gs_find(&self.completion_bytes, s.as_bytes())
                    {
                        return Some(StopReason::StopString {
                            stop_string_idx: idx,
                            completion_bytes_pos: pos,
                        });
                    }
                }
            }
            None
        }
    }

    pub fn logprobs(&self) -> &[Logprobs] {
        &self.logprobs
    }

    pub fn return_logprobs(&self) -> bool {
        self.return_logprobs
    }

    pub fn prompt_tokens(&self) -> usize {
        self.prompt_len
    }

    pub fn stop_strings(&self) -> &[String] {
        &self.stop_strings
    }

    /// Returns the newly decoded text since the last call (the streaming delta), advancing the stream index.
    pub fn get_delta(
        &mut self,
    ) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
        let new_decoded = self.peek_delta();
        if matches!(new_decoded, Ok(Some(_))) {
            self.stream_idx = self.completion_bytes.len();
        }
        new_decoded
    }

    /// Peeks at the streaming delta without advancing the stream index.
    pub fn peek_delta(&self) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
        let is_first = self.stream_idx == 0;
        let new_decoded = String::from_utf8_lossy(&self.completion_bytes[self.stream_idx..]);
        // Check whether the buffer ends with valid UTF-8; if not, skip it, as it is probably a
        // character split across multiple tokens.
        if new_decoded.ends_with('�') {
            return Ok(None);
        }

        // The first token usually starts with a space. We don't want to add that to the delta.
        // Since we're using the completion_bytes, we need to take care of that ourselves.
        // Had we used HF's Tokenizer, it would have taken care of that for us.
        if is_first {
            return Ok(Some(new_decoded.trim_start().to_string()));
        }
        Ok(Some(new_decoded.to_string()))
    }

    pub fn timestamp(&self) -> u128 {
        self.timestamp
    }

    pub fn prompt_timestamp(&self) -> Option<u128> {
        self.prompt_timestamp
    }

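    /// Propagate this sequence's timing and token counts into the shared `SequenceGroup` totals.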
    fn update_time_info(&self) {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("Time travel has occurred!")
            .as_millis();

        if let Some(ts) = self.prompt_timestamp {
            get_mut_group!(self).total_completion_time = now - ts;
            get_mut_group!(self).total_prompt_time = self.total_prompt_time.unwrap();
        }

        get_mut_group!(self).total_time = now - self.timestamp;

        get_mut_group!(self).total_prompt_toks = self.prompt_len;
        get_mut_group!(self).total_toks = self.len();
    }

    pub fn add_image_choice_to_group(&self, choice: ImageChoice) {
        get_mut_group!(self).image_choices.push(choice);
    }

    pub fn add_speech_pcm_to_group(&self, pcm: Arc<Vec<f32>>, rate: usize, channels: usize) {
        get_mut_group!(self).speech_pcms.push((pcm, rate, channels));
    }

    pub fn add_choice_to_group(&self, choice: Choice) {
        get_mut_group!(self).choices.push(choice);
        self.update_time_info();
    }

    pub fn add_raw_choice_to_group(&self, logit_chunks: Vec<Tensor>) {
        get_mut_group!(self)
            .raw_choices
            .push((logit_chunks, self.tokens.clone()));
        self.update_time_info();
    }

    pub fn add_completion_choice_to_group(&self, mut choice: CompletionChoice) {
        choice.text = format!(
            "{}{}{}",
            self.prefix.as_deref().unwrap_or(""),
            choice.text,
            self.suffix.as_deref().unwrap_or("")
        );
        get_mut_group!(self)
            .completion_choices
            .push((self.cumulative_logprob, choice));
        self.update_time_info();
    }

    pub fn get_response_index(&self) -> usize {
        self.response_index
    }

    pub fn get_mut_group(&self) -> MutexGuard<'_, SequenceGroup> {
        get_mut_group!(self)
    }

    pub fn add_streaming_chunk_choice_to_group(&self, chunk: ChunkChoice) {
        get_mut_group!(self).chat_streaming_chunks.push(chunk);
        self.update_time_info();
    }

    pub fn add_streaming_completion_chunk_choice_to_group(&self, chunk: CompletionChunkChoice) {
        get_mut_group!(self).completion_streaming_chunks.push(chunk);
        self.update_time_info();
    }

    pub fn take_images(&mut self) -> Option<Vec<image::DynamicImage>> {
        self.multimodal.take_images()
    }

    pub fn clone_images(&self) -> Option<Vec<image::DynamicImage>> {
        self.multimodal.clone_images()
    }

    pub fn images(&self) -> Option<&[image::DynamicImage]> {
        self.multimodal.images()
    }

    pub fn image_hashes(&self) -> Option<&[u64]> {
        self.multimodal.image_hashes()
    }

    pub fn has_images(&self) -> bool {
        self.multimodal.has_images()
    }

    pub fn take_audios(&mut self) -> Option<Vec<AudioInput>> {
        self.multimodal.take_audios()
    }

    pub fn clone_audios(&self) -> Option<Vec<AudioInput>> {
        self.multimodal.clone_audios()
    }

    pub fn audios(&self) -> Option<&[AudioInput]> {
        self.multimodal.audios()
    }

    pub fn audio_hashes(&self) -> Option<&[u64]> {
        self.multimodal.audio_hashes()
    }

    pub fn has_audios(&self) -> bool {
        self.multimodal.has_audios()
    }

    /// Keep only the last `audios_to_keep` audios.
    pub fn keep_num_audios(&mut self, audios_to_keep: usize) {
        self.multimodal.keep_num_audios(audios_to_keep)
    }

    /// Keep only the last `images_to_keep` images.
    pub fn keep_num_images(&mut self, images_to_keep: usize) {
        self.multimodal.keep_num_images(images_to_keep)
    }

    pub fn image_gen_response_format(&self) -> Option<ImageGenerationResponseFormat> {
        self.multimodal.image_gen_response_format()
    }

    pub fn sequence_stepping_type(&self) -> &SeqStepType {
        &self.sequence_stepping_type
    }

    pub fn get_diffusion_diffusion_params(&self) -> Option<DiffusionGenerationParams> {
        self.multimodal.diffusion_params()
    }

    pub fn eos_tokens(&self) -> &[u32] {
        &self.eos_tokens
    }
}

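/// Aggregated state for all sequences (choices) belonging to one request: collected outputs,
/// streaming chunks, and timing/usage totals.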
pub struct SequenceGroup {
    n_choices: usize, // The target number of choices to return. Can be decreased if an error is thrown.
    best_of: Option<usize>, // Top n seqs based on cumulative logprobs.
    pub total_prompt_toks: usize,
    pub total_toks: usize,
    pub total_prompt_time: u128,
    pub total_time: u128,
    pub total_completion_time: u128,
    choices: Vec<Choice>,
    image_choices: Vec<ImageChoice>,
    speech_pcms: Vec<(Arc<Vec<f32>>, usize, usize)>, // (pcm, rate, channels)
    raw_choices: Vec<(Vec<Tensor>, Vec<u32>)>,
    completion_choices: Vec<(f32, CompletionChoice)>,
    pub chat_streaming_chunks: Vec<ChunkChoice>,
    pub completion_streaming_chunks: Vec<CompletionChunkChoice>,
    pub is_streaming: bool,
    pub is_chat: bool,
}

impl SequenceGroup {
    pub fn new(
        n_choices: usize,
        is_streaming: bool,
        is_chat: bool,
        best_of: Option<usize>,
    ) -> Self {
        Self {
            choices: Vec::new(),
            image_choices: Vec::new(),
            speech_pcms: Vec::new(),
            raw_choices: Vec::new(),
            completion_choices: Vec::new(),
            n_choices,
            total_prompt_toks: 0,
            total_toks: 0,
            total_prompt_time: 0,
            total_time: 0,
            total_completion_time: 0,
            chat_streaming_chunks: Vec::new(),
            completion_streaming_chunks: Vec::new(),
            is_streaming,
            is_chat,
            best_of,
        }
    }

    pub fn get_choices(&self) -> &[Choice] {
        &self.choices
    }

    /// Returns the completion choices; if `best_of` is set, only the top `best_of` by cumulative logprob are kept.
    pub fn get_completion_choices(&self) -> Vec<CompletionChoice> {
        if let Some(best_of) = self.best_of {
            let mut choices = self.completion_choices.clone();
            // Sort by descending logprobs
            choices.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("No ordering."));
            choices
                .into_iter()
                .take(best_of)
                .map(|(_, x)| x)
                .collect::<Vec<_>>()
        } else {
            self.completion_choices
                .clone()
                .into_iter()
                .map(|(_, x)| x)
                .collect::<Vec<_>>()
        }
    }

    pub fn get_image_choices(&self) -> &[ImageChoice] {
        &self.image_choices
    }

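    /// Build the `Usage` statistics from the accumulated totals; times are tracked in
    /// milliseconds, hence the factors of 1000.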
    pub fn get_usage(&self) -> Usage {
        #[allow(clippy::cast_precision_loss)]
        Usage {
            completion_tokens: self.total_toks - self.total_prompt_toks,
            prompt_tokens: self.total_prompt_toks,
            total_tokens: self.total_toks,
            avg_tok_per_sec: (self.total_toks as f32 / self.total_time as f32) * 1000.,
            avg_prompt_tok_per_sec: (self.total_prompt_toks as f32 / self.total_prompt_time as f32)
                * 1000.,
            avg_compl_tok_per_sec: ((self.total_toks - self.total_prompt_toks) as f32
                / self.total_completion_time as f32)
                * 1000.,
            total_time_sec: self.total_time as f32 / 1000.,
            total_completion_time_sec: self.total_completion_time as f32 / 1000.,
            total_prompt_time_sec: self.total_prompt_time as f32 / 1000.,
        }
    }

    pub async fn maybe_send_chat_done_response(
        &self,
        response: ChatCompletionResponse,
        sender: Sender<Response>,
    ) -> Result<(), SendError<Response>> {
        if self.choices.len() == self.n_choices {
            sender.send(Response::Done(response)).await?;
        }

        Ok(())
    }

    pub async fn maybe_send_raw_done_response(
        &self,
        sender: Sender<Response>,
    ) -> Result<(), SendError<Response>> {
        if self.raw_choices.len() == self.n_choices {
            assert_eq!(self.raw_choices.len(), 1);
            let (logits_chunks, tokens) = self.raw_choices[0].clone();
            sender
                .send(Response::Raw {
                    logits_chunks,
                    tokens,
                })
                .await?;
        }

        Ok(())
    }

    pub async fn maybe_send_image_gen_response(
        &self,
        response: ImageGenerationResponse,
        sender: Sender<Response>,
    ) -> Result<(), SendError<Response>> {
        if self.image_choices.len() == self.n_choices {
            sender.send(Response::ImageGeneration(response)).await?;
        }

        Ok(())
    }

    pub async fn maybe_send_speech_response(
        &self,
        sender: Sender<Response>,
    ) -> Result<(), SendError<Response>> {
        assert_eq!(self.speech_pcms.len(), 1);

        let (pcm, rate, channels) = self.speech_pcms[0].clone();
        sender
            .send(Response::Speech {
                pcm,
                rate,
                channels,
            })
            .await?;

        Ok(())
    }

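    /// Once every choice has produced its pending streaming chunk, flush the accumulated chat
    /// or completion chunks to the client as a single chunk response.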
    pub async fn maybe_send_streaming_response(
        &mut self,
        seq: &Sequence,
        model: String,
        usage_opt: Option<Usage>,
    ) -> Result<(), Box<SendError<Response>>> {
        if self.chat_streaming_chunks.len() == self.n_choices && self.is_streaming {
            let mut swap_streaming_chunks = vec![];

            std::mem::swap(&mut swap_streaming_chunks, &mut self.chat_streaming_chunks);

            seq.responder()
                .send(Response::Chunk(ChatCompletionChunkResponse {
                    id: seq.id.to_string(),
                    choices: swap_streaming_chunks,
                    created: seq.timestamp,
                    model: model.clone(),
                    system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                    object: "chat.completion.chunk".to_string(),
                    usage: usage_opt,
                }))
                .await?;
        } else if self.completion_streaming_chunks.len() == self.n_choices && self.is_streaming {
            let mut swap_streaming_chunks = vec![];

            std::mem::swap(
                &mut swap_streaming_chunks,
                &mut self.completion_streaming_chunks,
            );

            seq.responder()
                .send(Response::CompletionChunk(CompletionChunkResponse {
                    id: seq.id.to_string(),
                    choices: swap_streaming_chunks,
                    created: seq.timestamp,
                    model: model.clone(),
                    system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                    object: "text_completion".to_string(),
                }))
                .await?;
        }
        Ok(())
    }

    pub async fn maybe_send_completion_done_response(
        &self,
        response: CompletionResponse,
        sender: Sender<Response>,
    ) -> Result<(), Box<SendError<Response>>> {
        if self.completion_choices.len() == self.n_choices {
            sender.send(Response::CompletionDone(response)).await?;
        }
        Ok(())
    }
}