mistralrs_core/
sequence.rs

1use crate::{
2    get_mut_arcmutex, get_mut_group,
3    paged_attention::PhysicalTokenBlock,
4    pipeline::{text_models_inputs_processor::PagedAttentionMeta, LayerCaches},
5    response::{ChatCompletionChunkResponse, Choice, ChunkChoice, Response, SYSTEM_FINGERPRINT},
6    sampler::{Logprobs, Sampler},
7    AudioInput, ChatCompletionResponse, Usage,
8};
9use crate::{
10    paged_attention::{BlockEngineSequence, LogicalTokenBlock},
11    pipeline::{DiffusionGenerationParams, KvCache},
12    response::CompletionChoice,
13    tools::ToolCallingMatcher,
14    CompletionChunkChoice, CompletionChunkResponse, CompletionResponse, ImageChoice,
15    ImageGenerationResponse, ImageGenerationResponseFormat,
16};
17use candle_core::Tensor;
18use std::{
19    fmt::Display,
20    hash::{DefaultHasher, Hash, Hasher},
21    sync::{Arc, RwLock},
22    time::{SystemTime, UNIX_EPOCH},
23};
24use tokio::sync::{
25    mpsc::{error::SendError, Sender},
26    Mutex, MutexGuard,
27};
28
/// Why a sequence stopped producing output.
///
/// The [`Display`] form is a short finish-reason label; several variants share
/// the same label (see the impl below).
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum StopReason {
    /// The sampled token was one of the end-of-sequence tokens.
    Eos,
    /// The sampled token was one of the request's stop tokens (the token id is carried).
    StopTok(u32),
    /// The per-request completion limit was reached (the limit is carried).
    Length(usize),
    /// The model-level length limit was reached (the limit is carried).
    ModelLength(usize),
    /// A stop string was found in the completion bytes.
    StopString {
        /// Index of the matching entry in the request's stop-string list.
        stop_string_idx: usize,
        /// Byte offset of the match inside the completion bytes.
        completion_bytes_pos: usize,
    },
    Canceled,
    GeneratedImage,
    GeneratedSpeech,
    ToolCalls,
}

impl Display for StopReason {
    /// Writes the finish-reason label: `"stop"`, `"length"`, `"canceled"`,
    /// `"generated_image"`, `"generated_speech"` or `"tool_calls"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            // EOS, stop tokens and stop strings are all reported as a plain "stop".
            StopReason::Eos | StopReason::StopTok(_) | StopReason::StopString { .. } => "stop",
            StopReason::Length(_) | StopReason::ModelLength(_) => "length",
            StopReason::Canceled => "canceled",
            StopReason::GeneratedImage => "generated_image",
            StopReason::GeneratedSpeech => "generated_speech",
            StopReason::ToolCalls => "tool_calls",
        };
        f.write_str(label)
    }
}
58
59#[derive(Clone, Copy, PartialEq, Debug)]
60pub enum SequenceState {
61    Done(StopReason),
62    RunningPrompt,
63    RunningCompletion,
64    Waiting,
65    Error,
66    RunningPrefillPrompt,
67    // For PagedAttention:
68    FinishedAborted,
69    FinishedIgnored,
70    Swapped,
71}
72
73pub enum SequenceRecognizer {
74    Llguidance(Box<llguidance::Matcher>),
75    None,
76}
77
78enum SequenceCustomMetadata {
79    PagedAttention {
80        logical_token_blocks: Vec<LogicalTokenBlock>,
81        physical_blocks_prefill: Option<Vec<Arc<PhysicalTokenBlock>>>,
82        block_size: usize,
83    },
84    None,
85}
86
/// Evaluates to `1` when the last logical block is either full or empty, and to
/// `0` otherwise (last block partially filled, or no blocks at all).
macro_rules! blocks_to_add_new_tok {
    ($logical_token_blocks:expr) => {{
        match $logical_token_blocks.last() {
            Some(tail) if tail.is_full() || tail.is_empty() => 1,
            // Either there is still room in the last block, or there is no block yet.
            _ => 0,
        }
    }};
}
98
99pub(crate) fn util_append_token_to_blocks(
100    tok: usize,
101    logical_token_blocks: &mut Vec<LogicalTokenBlock>,
102    block_size: usize,
103) {
104    let last = logical_token_blocks.last_mut();
105    match last {
106        Some(last) => {
107            last.append_token_id(tok);
108        }
109        None => {
110            logical_token_blocks.push(LogicalTokenBlock::new(block_size));
111            // SAFETY: We just pushed a block, so last_mut() will return Some
112            logical_token_blocks
113                .last_mut()
114                .expect("just pushed a block, vector cannot be empty")
115                .append_token_id(tok);
116        }
117    }
118    // SAFETY: At this point, we either had a block or just created one above,
119    // so the vector is guaranteed to be non-empty.
120    if logical_token_blocks
121        .last()
122        .expect("logical_token_blocks should not be empty after appending")
123        .is_full()
124    {
125        logical_token_blocks.push(LogicalTokenBlock::new(block_size));
126    }
127}
128
129impl SequenceCustomMetadata {
130    fn append_token_to_blocks(&mut self, tok: usize) {
131        match self {
132            Self::PagedAttention {
133                logical_token_blocks,
134                physical_blocks_prefill: _,
135                block_size,
136            } => {
137                util_append_token_to_blocks(tok, logical_token_blocks, *block_size);
138            }
139            Self::None => (),
140        }
141    }
142
143    fn pop_token_from_blocks(&mut self) {
144        match self {
145            Self::PagedAttention {
146                logical_token_blocks,
147                physical_blocks_prefill: _,
148                block_size: _,
149            } => {
150                let last = logical_token_blocks.last_mut().unwrap();
151                last.pop_token();
152            }
153            Self::None => (),
154        }
155    }
156
157    fn append_tokens_to_blocks(&mut self, toks: Vec<usize>) {
158        for tok in toks {
159            self.append_token_to_blocks(tok);
160        }
161    }
162
163    fn remove_tokens_from_blocks(&mut self, n: usize) {
164        for _ in 0..n {
165            self.pop_token_from_blocks();
166        }
167    }
168}
169
/// How a sequence is stepped by the engine.
// NOTE(review): variant semantics are inferred from the names only; confirm against the pipeline code.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SeqStepType {
    /// Presumably: a prompt pass followed by repeated decode steps.
    PromptAndDecode,
    /// Presumably: the whole result is produced in a single step.
    OneShot,
}
175
176pub struct SequenceImages {
177    images: Vec<image::DynamicImage>,
178    hashes: Vec<u64>,
179}
180
181#[derive(Clone)]
182pub struct SequenceAudios {
183    audios: Vec<AudioInput>,
184    hashes: Vec<u64>,
185}
186
187impl SequenceAudios {
188    fn new(input_audios: Vec<AudioInput>) -> Self {
189        let hashes = input_audios.iter().map(|a| {
190            let mut hasher = DefaultHasher::new();
191            for s in &a.samples {
192                s.to_bits().hash(&mut hasher);
193            }
194            a.sample_rate.hash(&mut hasher);
195            hasher.finish()
196        });
197        Self {
198            hashes: hashes.collect(),
199            audios: input_audios,
200        }
201    }
202
203    fn clone_audios(&self) -> Vec<AudioInput> {
204        self.audios.clone()
205    }
206
207    fn audios(&self) -> &[AudioInput] {
208        &self.audios
209    }
210
211    fn audios_mut(&mut self) -> &mut Vec<AudioInput> {
212        &mut self.audios
213    }
214
215    fn hashes(&self) -> &[u64] {
216        &self.hashes
217    }
218
219    fn keep_num_audios(&mut self, audios_to_keep: usize) {
220        if self.audios.len() > audios_to_keep {
221            let start = self.audios.len() - audios_to_keep;
222            self.audios = self.audios[start..].to_vec();
223            // Do not do this because we need all the hashes later in the prefix cacher.
224            // self.hashes = self.hashes[start..].to_vec();
225        }
226    }
227}
228
229impl SequenceImages {
230    fn new(input_images: Vec<image::DynamicImage>) -> Self {
231        let hashes = input_images.iter().map(|x| {
232            let mut hasher = DefaultHasher::new();
233            x.as_bytes().hash(&mut hasher);
234            hasher.finish()
235        });
236        Self {
237            hashes: hashes.collect(),
238            images: input_images,
239        }
240    }
241
242    fn clone_images(&self) -> Vec<image::DynamicImage> {
243        self.images.clone()
244    }
245
246    fn images(&self) -> &[image::DynamicImage] {
247        &self.images
248    }
249
250    fn images_mut(&mut self) -> &mut Vec<image::DynamicImage> {
251        &mut self.images
252    }
253
254    fn hashes(&self) -> &[u64] {
255        &self.hashes
256    }
257
258    fn keep_num_images(&mut self, images_to_keep: usize) {
259        if self.images.len() > images_to_keep {
260            let start = self.images.len() - images_to_keep;
261            self.images = self.images[start..].to_vec();
262            // Do not do this because we need all the hashes later in the prefix cacher.
263            // self.hashes = self.hashes[start..].to_vec();
264        }
265    }
266}
267
268// Holds all multimodal (vision/diffusion) data for a Sequence.
269pub struct MultimodalData {
270    pub input_images: Option<SequenceImages>,
271    pub input_audios: Option<SequenceAudios>,
272    pub cached_pixel_values: Option<Tensor>,
273    pub cached_img_thw: Option<Tensor>,
274    pub cached_vid_thw: Option<Tensor>,
275    pub has_changed_prompt: bool,
276    pub image_gen_response_format: Option<ImageGenerationResponseFormat>,
277    pub diffusion_params: Option<DiffusionGenerationParams>,
278}
279
280impl MultimodalData {
281    pub fn new(
282        input_images: Option<Vec<image::DynamicImage>>,
283        input_audios: Option<Vec<AudioInput>>,
284        image_gen_response_format: Option<ImageGenerationResponseFormat>,
285        diffusion_params: Option<DiffusionGenerationParams>,
286    ) -> Self {
287        MultimodalData {
288            input_images: input_images.map(SequenceImages::new),
289            input_audios: input_audios.map(SequenceAudios::new),
290            cached_pixel_values: None,
291            cached_img_thw: None,
292            cached_vid_thw: None,
293            has_changed_prompt: false,
294            image_gen_response_format,
295            diffusion_params,
296        }
297    }
298
299    pub fn take_images(&mut self) -> Option<Vec<image::DynamicImage>> {
300        if self.has_changed_prompt {
301            if let Some(input_images) = self.input_images.as_mut() {
302                let mut images = Vec::new();
303                std::mem::swap(&mut images, input_images.images_mut());
304                Some(images)
305            } else {
306                None
307            }
308        } else {
309            self.input_images.as_ref().map(|imgs| imgs.clone_images())
310        }
311    }
312
313    pub fn clone_images(&self) -> Option<Vec<image::DynamicImage>> {
314        self.input_images.as_ref().map(|imgs| imgs.clone_images())
315    }
316
317    pub fn images(&self) -> Option<&[image::DynamicImage]> {
318        self.input_images.as_ref().map(|imgs| imgs.images())
319    }
320
321    pub fn image_hashes(&self) -> Option<&[u64]> {
322        self.input_images.as_ref().map(|imgs| imgs.hashes())
323    }
324
325    pub fn has_images(&self) -> bool {
326        self.input_images
327            .as_ref()
328            .is_some_and(|imgs| !imgs.images().is_empty())
329    }
330
331    pub fn take_audios(&mut self) -> Option<Vec<AudioInput>> {
332        if self.has_changed_prompt {
333            if let Some(input_audios) = self.input_audios.as_mut() {
334                let mut audios = Vec::new();
335                std::mem::swap(&mut audios, input_audios.audios_mut());
336                Some(audios)
337            } else {
338                None
339            }
340        } else {
341            self.input_audios.as_ref().map(|imgs| imgs.clone_audios())
342        }
343    }
344
345    pub fn clone_audios(&self) -> Option<Vec<AudioInput>> {
346        self.input_audios.as_ref().map(|a| a.clone_audios())
347    }
348
349    pub fn audios(&self) -> Option<&[AudioInput]> {
350        self.input_audios.as_ref().map(|a| a.audios())
351    }
352
353    pub fn audio_hashes(&self) -> Option<&[u64]> {
354        self.input_audios.as_ref().map(|a| a.hashes())
355    }
356
357    pub fn has_audios(&self) -> bool {
358        self.input_audios
359            .as_ref()
360            .is_some_and(|a| !a.audios().is_empty())
361    }
362
363    pub fn keep_num_audios(&mut self, audios_to_keep: usize) {
364        if let Some(auds) = self.input_audios.as_mut() {
365            auds.keep_num_audios(audios_to_keep)
366        }
367    }
368
369    pub fn keep_num_images(&mut self, images_to_keep: usize) {
370        if let Some(imgs) = self.input_images.as_mut() {
371            imgs.keep_num_images(images_to_keep)
372        }
373    }
374
375    pub fn image_gen_response_format(&self) -> Option<ImageGenerationResponseFormat> {
376        self.image_gen_response_format
377    }
378
379    pub fn diffusion_params(&self) -> Option<DiffusionGenerationParams> {
380        self.diffusion_params.clone()
381    }
382}
383
384pub struct Sequence {
385    // Metadata, const
386    id: usize,
387    prompt_len: usize,
388    max_len: Option<usize>,
389    timestamp: u128,
390    sampler: Arc<Sampler>,
391    stop_tokens: Vec<u32>,
392    stop_strings: Vec<String>,
393    return_logprobs: bool,
394    responder: Sender<Response>,
395    response_index: usize,
396    creation_time: u64,
397    prompt: String,
398    sequence_stepping_type: SeqStepType,
399    pub(crate) return_raw_logits: bool,
400    token_offset: usize,
401    eos_tokens: Vec<u32>,
402
403    // Multimodal data (images, diffusion settings, pixel caches)
404    pub multimodal: MultimodalData,
405
406    // Completion requests
407    suffix: Option<String>,
408    prefix: Option<String>,
409
410    // Speculative
411    is_tmp: bool,
412
413    // Prefix caching
414    prefill_prompt_toks: Option<Vec<u32>>,
415
416    // Cache
417    normal_cache: Vec<Option<KvCache>>,
418    normal_draft_cache: Vec<Option<KvCache>>,
419    scaling_cache: Option<Tensor>,
420    cache: LayerCaches,
421    draft_cache: LayerCaches,
422    xlora_cache: Option<LayerCaches>,
423    /// For hybrid models: index into the Mamba state pool
424    mamba_state_idx: Option<usize>,
425
426    // Preallocated KV cache (k,v)
427    seq_preallocated_cache: Option<(Tensor, Tensor)>,
428
429    // Mutables
430    tokens: Vec<u32>,
431    logprobs: Vec<Logprobs>,
432    cumulative_logprob: f32,
433    last_logprob: f32,
434    last_completion_bytes_len: usize,
435    last_is_done: Option<StopReason>,
436    completion_bytes: Vec<u8>,
437    stream_idx: usize,
438    pub recognizer: SequenceRecognizer,
439    scheduling_urgency: usize, // The number of passes since scheduling
440    waitlisted_count: usize, // Used in PagedAttention to alert the user when a sequence repeatedly cannot be scheduled
441
442    // GPU things
443    pub prompt_tok_per_sec: f32,
444    pub prompt_timestamp: Option<u128>,
445    pub total_prompt_time: Option<u128>,
446    group: Arc<Mutex<SequenceGroup>>,
447    state: RwLock<SequenceState>,
448
449    // Custom backend metadata
450    custom_metadata: SequenceCustomMetadata,
451
452    // Tool calls
453    pub tools: Option<Arc<ToolCallingMatcher>>,
454}
455
456impl BlockEngineSequence for Sequence {
457    fn blocks_to_add_new_tok(&self) -> usize {
458        match &self.custom_metadata {
459            SequenceCustomMetadata::PagedAttention {
460                logical_token_blocks,
461                physical_blocks_prefill: _,
462                block_size: _,
463            } => {
464                blocks_to_add_new_tok!(logical_token_blocks)
465            }
466            SequenceCustomMetadata::None => unreachable!(),
467        }
468    }
469
470    fn get_id(&self) -> usize {
471        self.id
472    }
473
474    fn logical_token_blocks(&self) -> &[LogicalTokenBlock] {
475        match &self.custom_metadata {
476            SequenceCustomMetadata::PagedAttention {
477                logical_token_blocks,
478                physical_blocks_prefill: _,
479                block_size: _,
480            } => logical_token_blocks,
481            SequenceCustomMetadata::None => unreachable!(),
482        }
483    }
484
485    fn take_physical_blocks_prefill(&mut self) -> Option<Vec<Arc<PhysicalTokenBlock>>> {
486        match &mut self.custom_metadata {
487            SequenceCustomMetadata::PagedAttention {
488                logical_token_blocks: _,
489                physical_blocks_prefill,
490                block_size: _,
491            } => physical_blocks_prefill.take(),
492            SequenceCustomMetadata::None => None,
493        }
494    }
495
496    fn increment_waitlist_count(&mut self) -> usize {
497        let prev = self.waitlisted_count;
498        self.waitlisted_count += 1;
499        prev
500    }
501}
502
503impl Sequence {
504    #[allow(clippy::too_many_arguments)]
505    pub fn new_waiting(
506        tokens: Vec<u32>,
507        prompt: String,
508        id: usize,
509        timestamp: u128,
510        layers: usize,
511        responder: Sender<Response>,
512        sampler: Sampler,
513        stop_tokens: Vec<u32>,
514        stop_strings: Vec<String>,
515        max_len: Option<usize>,
516        return_logprobs: bool,
517        is_xlora: bool,
518        group: Arc<Mutex<SequenceGroup>>,
519        response_index: usize,
520        creation_time: u64,
521        recognizer: SequenceRecognizer,
522        suffix: Option<String>,
523        prefix: Option<String>,
524        input_images: Option<Vec<image::DynamicImage>>,
525        input_audios: Option<Vec<AudioInput>>,
526        // Paged attention
527        block_size: Option<usize>,
528        //
529        tools: Option<Arc<ToolCallingMatcher>>,
530        image_gen_response_format: Option<ImageGenerationResponseFormat>,
531        sequence_stepping_type: SeqStepType,
532        diffusion_params: Option<DiffusionGenerationParams>,
533        // Preallocated KV cache (k,v)
534        seq_preallocated_cache: Option<(Tensor, Tensor)>,
535        //
536        return_raw_logits: bool,
537        eos_tokens: Vec<u32>,
538    ) -> Self {
539        let prompt_len = tokens.len();
540        let mut custom_metadata = if let Some(block_size) = block_size {
541            SequenceCustomMetadata::PagedAttention {
542                logical_token_blocks: Vec::new(),
543                physical_blocks_prefill: None,
544                block_size,
545            }
546        } else {
547            SequenceCustomMetadata::None
548        };
549        custom_metadata
550            .append_tokens_to_blocks(tokens.iter().map(|x| *x as usize).collect::<Vec<_>>());
551        Self {
552            tokens,
553            prompt,
554            logprobs: Vec::new(),
555            prompt_len,
556            id,
557            timestamp,
558            state: RwLock::new(SequenceState::Waiting),
559            normal_cache: vec![None; layers],
560            normal_draft_cache: vec![None; layers],
561            cache: vec![None; layers],
562            draft_cache: vec![None; layers],
563            xlora_cache: if is_xlora {
564                Some(vec![None; layers])
565            } else {
566                None
567            },
568            mamba_state_idx: None,
569            seq_preallocated_cache,
570            responder,
571            sampler: sampler.into(),
572            stop_tokens,
573            stop_strings,
574            max_len,
575            return_logprobs,
576            prompt_tok_per_sec: 0.,
577            prompt_timestamp: None,
578            group,
579            scaling_cache: None,
580            response_index,
581            creation_time,
582            recognizer,
583            prefill_prompt_toks: None,
584            suffix,
585            prefix,
586            cumulative_logprob: 0.,
587            completion_bytes: Vec::new(),
588            stream_idx: 0,
589            last_completion_bytes_len: 0,
590            last_logprob: 0.0,
591            last_is_done: None,
592            is_tmp: false,
593            scheduling_urgency: 0,
594            // Multimodal data
595            multimodal: MultimodalData::new(
596                input_images,
597                input_audios,
598                image_gen_response_format,
599                diffusion_params,
600            ),
601            custom_metadata,
602            tools,
603            sequence_stepping_type,
604            return_raw_logits,
605            token_offset: 0,
606            eos_tokens,
607            total_prompt_time: None,
608            waitlisted_count: 0,
609        }
610    }
611
612    pub fn add_urgency(mut self) -> Self {
613        self.scheduling_urgency += 1;
614        self
615    }
616
617    pub fn reset_urgency(mut self) -> Self {
618        self.scheduling_urgency = 0;
619        self
620    }
621
622    /// Simple metric: (scheduling urgency) + log2(length)
623    /// Takes into account: urgency (scales linear) and length (scales logarithmic)
624    /// Scaling urgency is the number of scheduling passes where we have not been scheduled.
625    pub fn compute_priority(&self) -> f64 {
626        #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
627        (self.scheduling_urgency as f64) + (self.len() as f64).log2()
628    }
629
630    pub fn prefill_v2_normal(
631        mut self,
632        cache: Vec<Option<KvCache>>,
633        toks: Vec<u32>,
634        offset: usize,
635    ) -> Self {
636        self.normal_cache = cache;
637        self.prefill_prompt_toks = Some(toks);
638        self.set_state(SequenceState::RunningPrefillPrompt);
639        self.token_offset = offset;
640        self
641    }
642
643    pub fn prefill_v2_paged(
644        mut self,
645        logical_blocks: Vec<LogicalTokenBlock>,
646        physical_blocks: Vec<Arc<PhysicalTokenBlock>>,
647        toks: Vec<u32>,
648        offset: usize,
649    ) -> Self {
650        self.prefill_prompt_toks = Some(toks);
651        self.set_state(SequenceState::RunningPrefillPrompt);
652        self.token_offset = offset;
653
654        if let SequenceCustomMetadata::PagedAttention {
655            logical_token_blocks,
656            physical_blocks_prefill,
657            block_size: _,
658        } = &mut self.custom_metadata
659        {
660            *logical_token_blocks = logical_blocks;
661            *physical_blocks_prefill = Some(physical_blocks);
662        }
663
664        self
665    }
666
667    /// This is the number of tokens. If the KV cache is Some, then it will use that.
668    pub fn len(&self) -> usize {
669        if let Some(toks) = &self.prefill_prompt_toks {
670            return toks.len();
671        }
672        if self.is_tmp {
673            return self.tokens.len();
674        }
675        // Use xlora cache first because of non granular
676        if self.xlora_cache.as_ref().is_some_and(|c| c[0].is_some()) {
677            self.xlora_cache.as_ref().unwrap()[0]
678                .as_ref()
679                .unwrap()
680                .0
681                .dims()[2]
682                + 1
683        } else if let Some((_, x)) = &self.cache[0] {
684            x.dims()[2] + 1
685        } else {
686            self.tokens.len()
687        }
688    }
689
690    pub fn id(&self) -> &usize {
691        &self.id
692    }
693
694    pub fn is_running(&self) -> bool {
695        matches!(
696            *self.state.read().unwrap(),
697            SequenceState::RunningCompletion | SequenceState::RunningPrompt // | SequenceState::RunningPrefillPrompt
698        )
699    }
700
701    pub fn is_completion(&self) -> bool {
702        matches!(
703            *self.state.read().unwrap(),
704            SequenceState::RunningCompletion
705        )
706    }
707
708    pub fn is_prompt(&self) -> bool {
709        matches!(
710            *self.state.read().unwrap(),
711            SequenceState::RunningPrompt | SequenceState::RunningPrefillPrompt
712        )
713    }
714
715    pub fn is_waiting(&self) -> bool {
716        matches!(*self.state.read().unwrap(), SequenceState::Waiting)
717    }
718
719    pub fn is_finished_paged_attn(&self) -> bool {
720        matches!(
721            *self.state.read().unwrap(),
722            SequenceState::FinishedAborted
723                | SequenceState::FinishedIgnored
724                | SequenceState::Done(_)
725        )
726    }
727
728    pub fn get_toks(&self) -> &[u32] {
729        if let Some(toks) = &self.prefill_prompt_toks {
730            return toks;
731        }
732        &self.tokens
733    }
734
735    pub fn get_initial_prompt(&self) -> &str {
736        &self.prompt
737    }
738
739    pub fn set_initial_prompt(&mut self, new: String) {
740        self.prompt = new;
741    }
742
743    pub fn token_offset(&self) -> usize {
744        self.token_offset
745    }
746
747    /// This will also set prompt_len
748    pub(crate) fn set_toks_and_reallocate(
749        &mut self,
750        toks: Vec<u32>,
751        paged_attn_metadata: Option<&mut PagedAttentionMeta>,
752    ) {
753        self.tokens.clone_from(&toks);
754        self.prompt_len = self.tokens.len();
755        // Handle possible block engine
756        match &mut self.custom_metadata {
757            SequenceCustomMetadata::PagedAttention {
758                logical_token_blocks,
759                physical_blocks_prefill: _,
760                block_size: _,
761            } => {
762                logical_token_blocks.clear();
763            }
764            SequenceCustomMetadata::None => (),
765        }
766        self.custom_metadata
767            .append_tokens_to_blocks(toks.iter().map(|x| *x as usize).collect::<Vec<_>>());
768
769        if let Some(metadata) = paged_attn_metadata {
770            // Free and then reallocate as appropriate
771            get_mut_arcmutex!(metadata.block_engine).free_sequence(*self.id());
772            get_mut_arcmutex!(metadata.block_engine).allocate(self);
773        }
774    }
775
776    pub fn completion_bytes(&self) -> &[u8] {
777        &self.completion_bytes
778    }
779
780    pub fn preallocated_cache(&self) -> Option<&(Tensor, Tensor)> {
781        self.seq_preallocated_cache.as_ref()
782    }
783
784    pub fn normal_cache(&mut self) -> &mut Vec<Option<KvCache>> {
785        &mut self.normal_cache
786    }
787
788    pub fn normal_draft_cache(&mut self) -> &mut Vec<Option<KvCache>> {
789        &mut self.normal_draft_cache
790    }
791
792    pub fn cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
793        &mut self.cache
794    }
795
796    pub fn draft_cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
797        &mut self.draft_cache
798    }
799
800    pub fn xlora_cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
801        self.xlora_cache.as_mut().expect("No X-LoRA cache.")
802    }
803
804    pub fn scaling_cache(&mut self) -> &mut Option<Tensor> {
805        &mut self.scaling_cache
806    }
807
808    pub fn mamba_state_idx(&self) -> Option<usize> {
809        self.mamba_state_idx
810    }
811
812    pub fn set_mamba_state_idx(&mut self, idx: Option<usize>) {
813        self.mamba_state_idx = idx;
814    }
815
816    pub fn is_xlora(&self) -> bool {
817        self.xlora_cache.is_some()
818    }
819
820    pub fn sampler(&mut self) -> Arc<Sampler> {
821        self.sampler.clone()
822    }
823
824    /// Add a some prefill tokens. Only meant for internal speculative decoding usage.
825    pub fn set_prefill_toks(&mut self, toks: Vec<u32>) {
826        self.prefill_prompt_toks = Some(toks)
827    }
828
829    /// Remove the prefill tokens.
830    pub fn reset_prefill_toks(&mut self) {
831        self.prefill_prompt_toks = None
832    }
833
834    /// Internal api to add one raw token.
835    pub(crate) fn add_tmp_tok(&mut self, tok: u32) {
836        self.is_tmp = true;
837        self.tokens.push(tok);
838        // Handle possible block engine
839        self.custom_metadata.append_token_to_blocks(tok as usize);
840    }
841
842    /// Internal api to remove n raw tokens.
843    pub(crate) fn remove_tmp_tok(&mut self, n: usize) {
844        self.is_tmp = false;
845        self.tokens.truncate(self.tokens.len() - n);
846        // Handle possible block engine
847        self.custom_metadata.remove_tokens_from_blocks(n);
848    }
849
850    pub fn add_token(
851        &mut self,
852        tok: Logprobs,
853        completion_bytes: Vec<u8>,
854        is_done: &Option<StopReason>,
855    ) {
856        let stopped_by_token = matches!(
857            is_done,
858            Some(StopReason::Eos) | Some(StopReason::StopTok(_))
859        );
860        if !stopped_by_token {
861            // Completion bytes is used to check for stop strings, and as the response buffer.
862            // We don't need to add stop tokens to the completion bytes to check for stop strings.
863            // And by not adding it here, we can avoid having to delete these tokens from the output.
864            self.completion_bytes.extend_from_slice(&completion_bytes);
865            self.last_completion_bytes_len = completion_bytes.len();
866        }
867        self.last_logprob = tok.logprob;
868        self.last_is_done = *is_done;
869
870        self.custom_metadata
871            .append_token_to_blocks(tok.token as usize);
872
873        self.cumulative_logprob += tok.logprob;
874        self.tokens.push(tok.token);
875        self.logprobs.push(tok);
876        self.reset_prefill_toks();
877    }
878
879    pub fn responder(&self) -> Sender<Response> {
880        self.responder.clone()
881    }
882
883    pub fn creation_time(&self) -> u64 {
884        self.creation_time
885    }
886
887    pub fn set_state(&self, state: SequenceState) {
888        if matches!(state, SequenceState::Error) {
889            let mut group = get_mut_group!(self);
890            group.n_choices = group.n_choices.saturating_sub(1);
891        }
892        *self.state.write().unwrap() = state;
893    }
894
895    pub fn getstate(&self) -> SequenceState {
896        *self.state.read().unwrap()
897    }
898
899    pub fn is_done(
900        &self,
901        tok: u32,
902        eos_tok: Option<&[u32]>,
903        max_model_len: usize,
904    ) -> Option<StopReason> {
905        let is_eos = match eos_tok {
906            Some(eos_tok) => eos_tok.contains(&tok),
907            None => false,
908        };
909        if is_eos {
910            Some(StopReason::Eos)
911        } else if matches!(
912            &*self.state.read().unwrap(),
913            SequenceState::Done(StopReason::Canceled)
914        ) {
915            Some(StopReason::Canceled)
916        } else if self.stop_tokens.contains(&tok) {
917            Some(StopReason::StopTok(tok))
918        } else if self.max_len.is_some()
919            && self.tokens.len().saturating_sub(self.prompt_len) + 1 >= self.max_len.unwrap()
920        {
921            // add_token will be called after this check
922            Some(StopReason::Length(self.max_len.unwrap()))
923        } else if self.tokens.len().saturating_sub(self.prompt_len) >= max_model_len {
924            Some(StopReason::ModelLength(max_model_len))
925        } else {
926            if !self.stop_strings.is_empty() {
927                for (idx, s) in self.stop_strings.iter().enumerate() {
928                    if let Some(pos) = galil_seiferas::gs_find(&self.completion_bytes, s.as_bytes())
929                    {
930                        return Some(StopReason::StopString {
931                            stop_string_idx: idx,
932                            completion_bytes_pos: pos,
933                        });
934                    }
935                }
936            }
937            None
938        }
939    }
940
941    pub fn logprobs(&self) -> &[Logprobs] {
942        &self.logprobs
943    }
944
945    pub fn return_logprobs(&self) -> bool {
946        self.return_logprobs
947    }
948
949    pub fn prompt_tokens(&self) -> usize {
950        self.prompt_len
951    }
952
953    pub fn stop_strings(&self) -> &[String] {
954        &self.stop_strings
955    }
956
957    /// Returns the delta between the last two decoded sequences
958    pub fn get_delta(
959        &mut self,
960    ) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
961        let new_decoded = self.peek_delta();
962        if matches!(new_decoded, Ok(Some(_))) {
963            self.stream_idx = self.completion_bytes.len();
964        }
965        new_decoded
966    }
967
968    /// Peeks at the delta between the last two decoded sequences, but does not advance the stream index.
969    pub fn peek_delta(&self) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
970        let is_first = self.stream_idx == 0;
971        let new_decoded = String::from_utf8_lossy(&self.completion_bytes[self.stream_idx..]);
972        // Check if the sequence ends with valid utf8, if not skip it as it probably is a multi token sequence
973        if new_decoded.ends_with('�') {
974            return Ok(None);
975        }
976
977        // The first token usually starts with a space. We don't want to add that to the delta.
978        // Since we're using the completion_bytes, we need to take care of that ourselves.
979        // Had we used HF's Tokenizer, it would have taken care of that for us.
980        if is_first {
981            return Ok(Some(new_decoded.trim_start().to_string()));
982        }
983        Ok(Some(new_decoded.to_string()))
984    }
985
    /// Returns the `timestamp` stored on this sequence.
    ///
    /// `update_time_info` subtracts it from the current time in milliseconds since the UNIX
    /// epoch, so it is presumably expressed in the same unit — confirm where it is set.
    pub fn timestamp(&self) -> u128 {
        self.timestamp
    }
989
    /// Returns the `prompt_timestamp` stored on this sequence, or `None` if it is not set.
    ///
    /// `update_time_info` subtracts it from the current time in milliseconds since the UNIX
    /// epoch, so it is presumably expressed in the same unit — confirm where it is set.
    pub fn prompt_timestamp(&self) -> Option<u128> {
        self.prompt_timestamp
    }
993
994    fn update_time_info(&self) {
995        let now = SystemTime::now()
996            .duration_since(UNIX_EPOCH)
997            .expect("Time travel has occurred!")
998            .as_millis();
999
1000        if let Some(ts) = self.prompt_timestamp {
1001            get_mut_group!(self).total_completion_time = now - ts;
1002            get_mut_group!(self).total_prompt_time = self.total_prompt_time.unwrap();
1003        }
1004
1005        get_mut_group!(self).total_time = now - self.timestamp;
1006
1007        get_mut_group!(self).total_prompt_toks = self.prompt_len;
1008        get_mut_group!(self).total_toks = self.len();
1009    }
1010
1011    pub fn add_image_choice_to_group(&self, choice: ImageChoice) {
1012        get_mut_group!(self).image_choices.push(choice);
1013    }
1014
1015    pub fn add_speech_pcm_to_group(&self, pcm: Arc<Vec<f32>>, rate: usize, channels: usize) {
1016        get_mut_group!(self).speech_pcms.push((pcm, rate, channels));
1017    }
1018
1019    pub fn add_choice_to_group(&self, choice: Choice) {
1020        get_mut_group!(self).choices.push(choice);
1021        self.update_time_info();
1022    }
1023
1024    pub fn add_raw_choice_to_group(&self, logit_chunks: Vec<Tensor>) {
1025        get_mut_group!(self)
1026            .raw_choices
1027            .push((logit_chunks, self.tokens.clone()));
1028        self.update_time_info();
1029    }
1030
1031    pub fn add_embedding_choice_to_group(&self, embedding: Vec<f32>) {
1032        get_mut_group!(self).embedding_choices.push(embedding);
1033        self.update_time_info();
1034    }
1035
1036    pub fn add_completion_choice_to_group(&self, mut choice: CompletionChoice) {
1037        choice.text = format!(
1038            "{}{}{}",
1039            self.prefix.as_deref().unwrap_or(""),
1040            choice.text,
1041            self.suffix.as_deref().unwrap_or("")
1042        );
1043        get_mut_group!(self)
1044            .completion_choices
1045            .push((self.cumulative_logprob, choice));
1046        self.update_time_info();
1047    }
1048
    /// Returns the `response_index` stored on this sequence.
    pub fn get_response_index(&self) -> usize {
        self.response_index
    }
1052
    /// Acquires this sequence's group through the `get_mut_group!` macro and returns the guard.
    ///
    /// NOTE(review): the `add_*_to_group` helpers and `update_time_info` acquire the group
    /// through the same macro; presumably the returned guard must be dropped before calling
    /// them — confirm against the macro definition.
    pub fn get_mut_group(&self) -> MutexGuard<'_, SequenceGroup> {
        get_mut_group!(self)
    }
1056
1057    pub fn add_streaming_chunk_choice_to_group(&self, chunk: ChunkChoice) {
1058        get_mut_group!(self).chat_streaming_chunks.push(chunk);
1059        self.update_time_info();
1060    }
1061
1062    pub fn add_streaming_completion_chunk_choice_to_group(&self, chunk: CompletionChunkChoice) {
1063        get_mut_group!(self).completion_streaming_chunks.push(chunk);
1064        self.update_time_info();
1065    }
1066
    /// Forwards to `self.multimodal.take_images()`.
    ///
    /// NOTE(review): presumably this moves the images out of the sequence, leaving none
    /// behind — confirm against the multimodal data type.
    pub fn take_images(&mut self) -> Option<Vec<image::DynamicImage>> {
        self.multimodal.take_images()
    }
1070
    /// Forwards to `self.multimodal.clone_images()`, returning owned images (if any) without
    /// requiring mutable access to the sequence.
    pub fn clone_images(&self) -> Option<Vec<image::DynamicImage>> {
        self.multimodal.clone_images()
    }
1074
    /// Forwards to `self.multimodal.images()`, returning a borrowed slice of images, if any.
    pub fn images(&self) -> Option<&[image::DynamicImage]> {
        self.multimodal.images()
    }
1078
    /// Forwards to `self.multimodal.image_hashes()`, returning a borrowed slice of `u64`
    /// hashes, if any.
    pub fn image_hashes(&self) -> Option<&[u64]> {
        self.multimodal.image_hashes()
    }
1082
    /// Forwards to `self.multimodal.has_images()`.
    pub fn has_images(&self) -> bool {
        self.multimodal.has_images()
    }
1086
    /// Forwards to `self.multimodal.take_audios()`.
    ///
    /// NOTE(review): presumably this moves the audio inputs out of the sequence, leaving none
    /// behind — confirm against the multimodal data type.
    pub fn take_audios(&mut self) -> Option<Vec<AudioInput>> {
        self.multimodal.take_audios()
    }
1090
    /// Forwards to `self.multimodal.clone_audios()`, returning owned audio inputs (if any)
    /// without requiring mutable access to the sequence.
    pub fn clone_audios(&self) -> Option<Vec<AudioInput>> {
        self.multimodal.clone_audios()
    }
1094
    /// Forwards to `self.multimodal.audios()`, returning a borrowed slice of audio inputs,
    /// if any.
    pub fn audios(&self) -> Option<&[AudioInput]> {
        self.multimodal.audios()
    }
1098
    /// Forwards to `self.multimodal.audio_hashes()`, returning a borrowed slice of `u64`
    /// hashes, if any.
    pub fn audio_hashes(&self) -> Option<&[u64]> {
        self.multimodal.audio_hashes()
    }
1102
    /// Forwards to `self.multimodal.has_audios()`.
    pub fn has_audios(&self) -> bool {
        self.multimodal.has_audios()
    }
1106
    /// Keeps only the last `audios_to_keep` audios by forwarding to
    /// `self.multimodal.keep_num_audios`.
    pub fn keep_num_audios(&mut self, audios_to_keep: usize) {
        self.multimodal.keep_num_audios(audios_to_keep)
    }
1111
    /// Keeps only the last `images_to_keep` images by forwarding to
    /// `self.multimodal.keep_num_images`.
    pub fn keep_num_images(&mut self, images_to_keep: usize) {
        self.multimodal.keep_num_images(images_to_keep)
    }
1116
    /// Forwards to `self.multimodal.image_gen_response_format()`, returning the image
    /// generation response format, if any.
    pub fn image_gen_response_format(&self) -> Option<ImageGenerationResponseFormat> {
        self.multimodal.image_gen_response_format()
    }
1120
    /// Returns a reference to the `sequence_stepping_type` stored on this sequence.
    pub fn sequence_stepping_type(&self) -> &SeqStepType {
        &self.sequence_stepping_type
    }
1124
    /// Forwards to `self.multimodal.diffusion_params()`, returning the diffusion generation
    /// parameters, if any.
    pub fn get_diffusion_diffusion_params(&self) -> Option<DiffusionGenerationParams> {
        self.multimodal.diffusion_params()
    }
1128
    /// Returns the `eos_tokens` stored on this sequence as a borrowed slice of token ids.
    pub fn eos_tokens(&self) -> &[u32] {
        &self.eos_tokens
    }
1132}
1133
/// State shared by the sequences of one group: the results collected so far, the number of
/// results expected, and timing/token statistics used to build a `Usage` report.
pub struct SequenceGroup {
    n_choices: usize, // The target number of choices to return. Can be decreased if an error is thrown.
    best_of: Option<usize>, // Top n seqs based on cumulative logprobs.
    /// Prompt length, written by `Sequence::update_time_info` from the sequence's `prompt_len`.
    pub total_prompt_toks: usize,
    /// Total length, written by `Sequence::update_time_info` from the sequence's `len()`.
    pub total_toks: usize,
    /// Prompt time in milliseconds (`get_usage` divides it by 1000 to report seconds).
    pub total_prompt_time: u128,
    /// Milliseconds elapsed since the sequence's `timestamp`, as of the last update.
    pub total_time: u128,
    /// Milliseconds elapsed since the sequence's `prompt_timestamp`, as of the last update.
    pub total_completion_time: u128,
    // Result buffers, one per response kind; each is filled by the matching
    // `Sequence::add_*_to_group` helper.
    choices: Vec<Choice>,
    image_choices: Vec<ImageChoice>,
    speech_pcms: Vec<(Arc<Vec<f32>>, usize, usize)>, // (pcm, rate, channels)
    raw_choices: Vec<(Vec<Tensor>, Vec<u32>)>, // (logit chunks, tokens)
    embedding_choices: Vec<Vec<f32>>,
    completion_choices: Vec<(f32, CompletionChoice)>, // (cumulative logprob, choice)
    /// Chat chunks waiting to be sent; drained by `maybe_send_streaming_response`.
    pub chat_streaming_chunks: Vec<ChunkChoice>,
    /// Completion chunks waiting to be sent; drained by `maybe_send_streaming_response`.
    pub completion_streaming_chunks: Vec<CompletionChunkChoice>,
    /// Streaming chunks are only sent when this is `true`.
    pub is_streaming: bool,
    /// Set from the constructor argument of the same name.
    pub is_chat: bool,
}
1153
1154impl SequenceGroup {
1155    pub fn new(
1156        n_choices: usize,
1157        is_streaming: bool,
1158        is_chat: bool,
1159        best_of: Option<usize>,
1160    ) -> Self {
1161        Self {
1162            choices: Vec::new(),
1163            image_choices: Vec::new(),
1164            speech_pcms: Vec::new(),
1165            raw_choices: Vec::new(),
1166            embedding_choices: Vec::new(),
1167            completion_choices: Vec::new(),
1168            n_choices,
1169            total_prompt_toks: 0,
1170            total_toks: 0,
1171            total_prompt_time: 0,
1172            total_time: 0,
1173            total_completion_time: 0,
1174            chat_streaming_chunks: Vec::new(),
1175            completion_streaming_chunks: Vec::new(),
1176            is_streaming,
1177            is_chat,
1178            best_of,
1179        }
1180    }
1181
    /// Returns the chat choices collected so far as a borrowed slice.
    pub fn get_choices(&self) -> &[Choice] {
        &self.choices
    }
1185
1186    /// This may apply the best_of.
1187    pub fn get_completion_choices(&self) -> Vec<CompletionChoice> {
1188        if let Some(best_of) = self.best_of {
1189            let mut choices = self.completion_choices.clone();
1190            // Sort by descending logprobs
1191            choices.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("No ordering."));
1192            choices
1193                .into_iter()
1194                .take(best_of)
1195                .map(|(_, x)| x)
1196                .collect::<Vec<_>>()
1197        } else {
1198            self.completion_choices
1199                .clone()
1200                .into_iter()
1201                .map(|(_, x)| x)
1202                .collect::<Vec<_>>()
1203        }
1204    }
1205
    /// Returns the image choices collected so far as a borrowed slice.
    pub fn get_image_choices(&self) -> &[ImageChoice] {
        &self.image_choices
    }
1209
1210    pub fn get_usage(&self) -> Usage {
1211        #[allow(clippy::cast_precision_loss)]
1212        Usage {
1213            completion_tokens: self.total_toks.saturating_sub(self.total_prompt_toks),
1214            prompt_tokens: self.total_prompt_toks,
1215            total_tokens: self.total_toks,
1216            avg_tok_per_sec: (self.total_toks as f32 / self.total_time as f32) * 1000.,
1217            avg_prompt_tok_per_sec: (self.total_prompt_toks as f32 / self.total_prompt_time as f32)
1218                * 1000.,
1219            avg_compl_tok_per_sec: (self.total_toks.saturating_sub(self.total_prompt_toks) as f32
1220                / self.total_completion_time as f32)
1221                * 1000.,
1222            total_time_sec: self.total_time as f32 / 1000.,
1223            total_completion_time_sec: self.total_completion_time as f32 / 1000.,
1224            total_prompt_time_sec: self.total_prompt_time as f32 / 1000.,
1225        }
1226    }
1227
1228    pub async fn maybe_send_chat_done_response(
1229        &self,
1230        response: ChatCompletionResponse,
1231        sender: Sender<Response>,
1232    ) -> Result<(), SendError<Response>> {
1233        if self.choices.len() == self.n_choices {
1234            sender.send(Response::Done(response)).await?;
1235        }
1236
1237        Ok(())
1238    }
1239
1240    pub async fn maybe_send_raw_done_response(
1241        &self,
1242        sender: Sender<Response>,
1243    ) -> Result<(), SendError<Response>> {
1244        if self.raw_choices.len() == self.n_choices {
1245            assert_eq!(self.raw_choices.len(), 1);
1246            let (logits_chunks, tokens) = self.raw_choices[0].clone();
1247            sender
1248                .send(Response::Raw {
1249                    logits_chunks,
1250                    tokens,
1251                })
1252                .await?;
1253        }
1254
1255        Ok(())
1256    }
1257
1258    pub async fn maybe_send_embedding_done_response(
1259        &self,
1260        sender: Sender<Response>,
1261    ) -> Result<(), SendError<Response>> {
1262        if self.embedding_choices.len() == self.n_choices {
1263            assert_eq!(self.embedding_choices.len(), 1);
1264            let embeddings = self.embedding_choices[0].clone();
1265            let prompt_tokens = self.total_prompt_toks;
1266            let total_tokens = self.total_toks;
1267            sender
1268                .send(Response::Embeddings {
1269                    embeddings,
1270                    prompt_tokens,
1271                    total_tokens,
1272                })
1273                .await?;
1274        }
1275
1276        Ok(())
1277    }
1278
1279    pub async fn maybe_send_image_gen_response(
1280        &self,
1281        response: ImageGenerationResponse,
1282        sender: Sender<Response>,
1283    ) -> Result<(), SendError<Response>> {
1284        if self.image_choices.len() == self.n_choices {
1285            sender.send(Response::ImageGeneration(response)).await?;
1286        }
1287
1288        Ok(())
1289    }
1290
1291    pub async fn maybe_send_speech_response(
1292        &self,
1293        sender: Sender<Response>,
1294    ) -> Result<(), SendError<Response>> {
1295        assert_eq!(self.speech_pcms.len(), 1);
1296
1297        let (pcm, rate, channels) = self.speech_pcms[0].clone();
1298        sender
1299            .send(Response::Speech {
1300                pcm,
1301                rate,
1302                channels,
1303            })
1304            .await?;
1305
1306        Ok(())
1307    }
1308
1309    pub async fn maybe_send_streaming_response(
1310        &mut self,
1311        seq: &Sequence,
1312        model: String,
1313        usage_opt: Option<Usage>,
1314    ) -> Result<(), Box<SendError<Response>>> {
1315        if self.chat_streaming_chunks.len() == self.n_choices && self.is_streaming {
1316            let mut swap_streaming_chunks = vec![];
1317
1318            std::mem::swap(&mut swap_streaming_chunks, &mut self.chat_streaming_chunks);
1319
1320            seq.responder()
1321                .send(Response::Chunk(ChatCompletionChunkResponse {
1322                    id: seq.id.to_string(),
1323                    choices: swap_streaming_chunks,
1324                    created: seq.timestamp,
1325                    model: model.clone(),
1326                    system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
1327                    object: "chat.completion.chunk".to_string(),
1328                    usage: usage_opt,
1329                }))
1330                .await?;
1331        } else if self.completion_streaming_chunks.len() == self.n_choices && self.is_streaming {
1332            let mut swap_streaming_chunks = vec![];
1333
1334            std::mem::swap(
1335                &mut swap_streaming_chunks,
1336                &mut self.completion_streaming_chunks,
1337            );
1338
1339            seq.responder()
1340                .send(Response::CompletionChunk(CompletionChunkResponse {
1341                    id: seq.id.to_string(),
1342                    choices: swap_streaming_chunks,
1343                    created: seq.timestamp,
1344                    model: model.clone(),
1345                    system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
1346                    object: "text_completion".to_string(),
1347                }))
1348                .await?;
1349        }
1350        Ok(())
1351    }
1352
1353    pub async fn maybe_send_completion_done_response(
1354        &self,
1355        response: CompletionResponse,
1356        sender: Sender<Response>,
1357    ) -> Result<(), Box<SendError<Response>>> {
1358        if self.completion_choices.len() == self.n_choices {
1359            sender.send(Response::CompletionDone(response)).await?;
1360        }
1361        Ok(())
1362    }
1363}