1use crate::{
2 get_mut_arcmutex, get_mut_group,
3 paged_attention::PhysicalTokenBlock,
4 pipeline::{text_models_inputs_processor::PagedAttentionMeta, LayerCaches},
5 response::{ChatCompletionChunkResponse, Choice, ChunkChoice, Response, SYSTEM_FINGERPRINT},
6 sampler::{Logprobs, Sampler},
7 AudioInput, ChatCompletionResponse, Usage,
8};
9use crate::{
10 paged_attention::{BlockEngineSequence, LogicalTokenBlock},
11 pipeline::{DiffusionGenerationParams, KvCache},
12 response::CompletionChoice,
13 tools::ToolCallingMatcher,
14 CompletionChunkChoice, CompletionChunkResponse, CompletionResponse, ImageChoice,
15 ImageGenerationResponse, ImageGenerationResponseFormat,
16};
17use candle_core::Tensor;
18use std::{
19 fmt::Display,
20 hash::{DefaultHasher, Hash, Hasher},
21 sync::{Arc, RwLock},
22 time::{SystemTime, UNIX_EPOCH},
23};
24use tokio::sync::{
25 mpsc::{error::SendError, Sender},
26 Mutex, MutexGuard,
27};
28
/// Why a sequence stopped generating.
///
/// All payloads are plain integers, so the type is `Eq` as well as `PartialEq`.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum StopReason {
    /// The sampled token was one of the end-of-sequence tokens.
    Eos,
    /// The sampled token (carried here) was one of the sequence's stop tokens.
    StopTok(u32),
    /// The per-request generation limit (carried here) was reached.
    Length(usize),
    /// The model-wide length limit (carried here) was reached.
    ModelLength(usize),
    /// One of the stop strings was found in the completion bytes.
    StopString {
        /// Index of the matching entry in the sequence's stop-string list.
        stop_string_idx: usize,
        /// Byte offset of the match inside the completion bytes.
        completion_bytes_pos: usize,
    },
    /// The sequence was canceled.
    Canceled,
    /// Rendered as `"generated_image"` when displayed.
    GeneratedImage,
    /// Rendered as `"generated_speech"` when displayed.
    GeneratedSpeech,
    /// Rendered as `"tool_calls"` when displayed.
    ToolCalls,
}
44
45impl Display for StopReason {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 match self {
48 StopReason::Eos => write!(f, "stop"),
49 StopReason::Length(_) | StopReason::ModelLength(_) => write!(f, "length"),
50 StopReason::StopTok(_) | StopReason::StopString { .. } => write!(f, "stop"),
51 StopReason::Canceled => write!(f, "canceled"),
52 StopReason::GeneratedImage => write!(f, "generated_image"),
53 StopReason::GeneratedSpeech => write!(f, "generated_speech"),
54 StopReason::ToolCalls => write!(f, "tool_calls"),
55 }
56 }
57}
58
/// Lifecycle state of a `Sequence`.
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum SequenceState {
    /// Finished, carrying the reason generation stopped.
    Done(StopReason),
    /// Reported as both "running" and "prompt" by the `Sequence` predicates.
    RunningPrompt,
    /// Reported as both "running" and "completion" by the `Sequence` predicates.
    RunningCompletion,
    /// Initial state assigned by `Sequence::new_waiting`.
    Waiting,
    /// Entering this state through `Sequence::set_state` decrements the group's
    /// expected number of choices.
    Error,
    /// Set by the `prefill_v2_*` builders; counted as "prompt" but not as "running".
    RunningPrefillPrompt,
    /// Counted as finished by `Sequence::is_finished_paged_attn`.
    FinishedAborted,
    /// Counted as finished by `Sequence::is_finished_paged_attn`.
    FinishedIgnored,
    /// NOTE(review): not constructed or matched in the visible part of this file;
    /// meaning must be confirmed against the scheduler.
    Swapped,
}
72
/// Optional per-sequence recognizer stored on a `Sequence`.
pub enum SequenceRecognizer {
    /// A boxed `llguidance` matcher.
    /// NOTE(review): presumably used for grammar-constrained sampling — confirm at the call sites.
    Llguidance(Box<llguidance::Matcher>),
    /// No recognizer is attached.
    None,
}
77
/// Backend-specific bookkeeping attached to a `Sequence`.
enum SequenceCustomMetadata {
    /// Block bookkeeping used when the sequence was created with a block size.
    PagedAttention {
        /// Logical blocks mirroring the sequence's tokens.
        logical_token_blocks: Vec<LogicalTokenBlock>,
        /// Physical blocks handed over by `prefill_v2_paged`; taken (and cleared) by
        /// `take_physical_blocks_prefill`.
        physical_blocks_prefill: Option<Vec<Arc<PhysicalTokenBlock>>>,
        /// Size passed to `LogicalTokenBlock::new` whenever a new block is created.
        block_size: usize,
    },
    /// No extra bookkeeping; every block operation is a no-op.
    None,
}
86
/// Evaluates to the number of blocks (0 or 1) that must be added before another token
/// can be stored.
///
/// The result is `1` only when a last block exists and it is either full or empty;
/// an empty block list, or a partially filled last block, yields `0`.
macro_rules! blocks_to_add_new_tok {
    ($logical_token_blocks:expr) => {{
        match $logical_token_blocks.last() {
            Some(tail) if tail.is_full() || tail.is_empty() => 1,
            _ => 0,
        }
    }};
}
98
99pub(crate) fn util_append_token_to_blocks(
100 tok: usize,
101 logical_token_blocks: &mut Vec<LogicalTokenBlock>,
102 block_size: usize,
103) {
104 let last = logical_token_blocks.last_mut();
105 match last {
106 Some(last) => {
107 last.append_token_id(tok);
108 }
109 None => {
110 logical_token_blocks.push(LogicalTokenBlock::new(block_size));
111 logical_token_blocks
113 .last_mut()
114 .expect("just pushed a block, vector cannot be empty")
115 .append_token_id(tok);
116 }
117 }
118 if logical_token_blocks
121 .last()
122 .expect("logical_token_blocks should not be empty after appending")
123 .is_full()
124 {
125 logical_token_blocks.push(LogicalTokenBlock::new(block_size));
126 }
127}
128
129impl SequenceCustomMetadata {
130 fn append_token_to_blocks(&mut self, tok: usize) {
131 match self {
132 Self::PagedAttention {
133 logical_token_blocks,
134 physical_blocks_prefill: _,
135 block_size,
136 } => {
137 util_append_token_to_blocks(tok, logical_token_blocks, *block_size);
138 }
139 Self::None => (),
140 }
141 }
142
143 fn pop_token_from_blocks(&mut self) {
144 match self {
145 Self::PagedAttention {
146 logical_token_blocks,
147 physical_blocks_prefill: _,
148 block_size: _,
149 } => {
150 let last = logical_token_blocks.last_mut().unwrap();
151 last.pop_token();
152 }
153 Self::None => (),
154 }
155 }
156
157 fn append_tokens_to_blocks(&mut self, toks: Vec<usize>) {
158 for tok in toks {
159 self.append_token_to_blocks(tok);
160 }
161 }
162
163 fn remove_tokens_from_blocks(&mut self, n: usize) {
164 for _ in 0..n {
165 self.pop_token_from_blocks();
166 }
167 }
168}
169
/// Stepping mode stored on a `Sequence` and exposed through
/// `Sequence::sequence_stepping_type`.
#[derive(Clone, Copy)]
pub enum SeqStepType {
    /// NOTE(review): presumably a prompt pass followed by token-by-token decoding —
    /// confirm against the pipeline that consumes this value.
    PromptAndDecode,
    /// NOTE(review): presumably a single forward pass produces the whole result —
    /// confirm against the pipeline that consumes this value.
    OneShot,
}
175
/// Input images of a sequence together with one hash per image.
pub struct SequenceImages {
    /// The images themselves; may be shortened by `keep_num_images` or moved out
    /// through `images_mut`.
    images: Vec<image::DynamicImage>,
    /// Hash of each image's bytes, computed once in `new`.
    hashes: Vec<u64>,
}
180
/// Input audio clips of a sequence together with one hash per clip.
#[derive(Clone)]
pub struct SequenceAudios {
    /// The clips themselves; may be shortened by `keep_num_audios` or moved out
    /// through `audios_mut`.
    audios: Vec<AudioInput>,
    /// Hash of each clip's samples and sample rate, computed once in `new`.
    hashes: Vec<u64>,
}
186
187impl SequenceAudios {
188 fn new(input_audios: Vec<AudioInput>) -> Self {
189 let hashes = input_audios.iter().map(|a| {
190 let mut hasher = DefaultHasher::new();
191 for s in &a.samples {
192 s.to_bits().hash(&mut hasher);
193 }
194 a.sample_rate.hash(&mut hasher);
195 hasher.finish()
196 });
197 Self {
198 hashes: hashes.collect(),
199 audios: input_audios,
200 }
201 }
202
203 fn clone_audios(&self) -> Vec<AudioInput> {
204 self.audios.clone()
205 }
206
207 fn audios(&self) -> &[AudioInput] {
208 &self.audios
209 }
210
211 fn audios_mut(&mut self) -> &mut Vec<AudioInput> {
212 &mut self.audios
213 }
214
215 fn hashes(&self) -> &[u64] {
216 &self.hashes
217 }
218
219 fn keep_num_audios(&mut self, audios_to_keep: usize) {
220 if self.audios.len() > audios_to_keep {
221 let start = self.audios.len() - audios_to_keep;
222 self.audios = self.audios[start..].to_vec();
223 }
226 }
227}
228
229impl SequenceImages {
230 fn new(input_images: Vec<image::DynamicImage>) -> Self {
231 let hashes = input_images.iter().map(|x| {
232 let mut hasher = DefaultHasher::new();
233 x.as_bytes().hash(&mut hasher);
234 hasher.finish()
235 });
236 Self {
237 hashes: hashes.collect(),
238 images: input_images,
239 }
240 }
241
242 fn clone_images(&self) -> Vec<image::DynamicImage> {
243 self.images.clone()
244 }
245
246 fn images(&self) -> &[image::DynamicImage] {
247 &self.images
248 }
249
250 fn images_mut(&mut self) -> &mut Vec<image::DynamicImage> {
251 &mut self.images
252 }
253
254 fn hashes(&self) -> &[u64] {
255 &self.hashes
256 }
257
258 fn keep_num_images(&mut self, images_to_keep: usize) {
259 if self.images.len() > images_to_keep {
260 let start = self.images.len() - images_to_keep;
261 self.images = self.images[start..].to_vec();
262 }
265 }
266}
267
/// Multimodal inputs and generation settings carried by a `Sequence`.
pub struct MultimodalData {
    /// Input images with their hashes, if any were supplied.
    pub input_images: Option<SequenceImages>,
    /// Input audio clips with their hashes, if any were supplied.
    pub input_audios: Option<SequenceAudios>,
    /// Starts as `None`; filled in outside this file.
    /// NOTE(review): presumably cached preprocessed pixel data — confirm.
    pub cached_pixel_values: Option<Tensor>,
    /// Starts as `None`; filled in outside this file.
    pub cached_img_thw: Option<Tensor>,
    /// Starts as `None`; filled in outside this file.
    pub cached_vid_thw: Option<Tensor>,
    /// When `true`, `take_images` / `take_audios` move the inputs out instead of cloning.
    pub has_changed_prompt: bool,
    /// Requested response format for image generation, if any.
    pub image_gen_response_format: Option<ImageGenerationResponseFormat>,
    /// Diffusion generation parameters, if any.
    pub diffusion_params: Option<DiffusionGenerationParams>,
}
279
280impl MultimodalData {
281 pub fn new(
282 input_images: Option<Vec<image::DynamicImage>>,
283 input_audios: Option<Vec<AudioInput>>,
284 image_gen_response_format: Option<ImageGenerationResponseFormat>,
285 diffusion_params: Option<DiffusionGenerationParams>,
286 ) -> Self {
287 MultimodalData {
288 input_images: input_images.map(SequenceImages::new),
289 input_audios: input_audios.map(SequenceAudios::new),
290 cached_pixel_values: None,
291 cached_img_thw: None,
292 cached_vid_thw: None,
293 has_changed_prompt: false,
294 image_gen_response_format,
295 diffusion_params,
296 }
297 }
298
299 pub fn take_images(&mut self) -> Option<Vec<image::DynamicImage>> {
300 if self.has_changed_prompt {
301 if let Some(input_images) = self.input_images.as_mut() {
302 let mut images = Vec::new();
303 std::mem::swap(&mut images, input_images.images_mut());
304 Some(images)
305 } else {
306 None
307 }
308 } else {
309 self.input_images.as_ref().map(|imgs| imgs.clone_images())
310 }
311 }
312
313 pub fn clone_images(&self) -> Option<Vec<image::DynamicImage>> {
314 self.input_images.as_ref().map(|imgs| imgs.clone_images())
315 }
316
317 pub fn images(&self) -> Option<&[image::DynamicImage]> {
318 self.input_images.as_ref().map(|imgs| imgs.images())
319 }
320
321 pub fn image_hashes(&self) -> Option<&[u64]> {
322 self.input_images.as_ref().map(|imgs| imgs.hashes())
323 }
324
325 pub fn has_images(&self) -> bool {
326 self.input_images
327 .as_ref()
328 .is_some_and(|imgs| !imgs.images().is_empty())
329 }
330
331 pub fn take_audios(&mut self) -> Option<Vec<AudioInput>> {
332 if self.has_changed_prompt {
333 if let Some(input_audios) = self.input_audios.as_mut() {
334 let mut audios = Vec::new();
335 std::mem::swap(&mut audios, input_audios.audios_mut());
336 Some(audios)
337 } else {
338 None
339 }
340 } else {
341 self.input_audios.as_ref().map(|imgs| imgs.clone_audios())
342 }
343 }
344
345 pub fn clone_audios(&self) -> Option<Vec<AudioInput>> {
346 self.input_audios.as_ref().map(|a| a.clone_audios())
347 }
348
349 pub fn audios(&self) -> Option<&[AudioInput]> {
350 self.input_audios.as_ref().map(|a| a.audios())
351 }
352
353 pub fn audio_hashes(&self) -> Option<&[u64]> {
354 self.input_audios.as_ref().map(|a| a.hashes())
355 }
356
357 pub fn has_audios(&self) -> bool {
358 self.input_audios
359 .as_ref()
360 .is_some_and(|a| !a.audios().is_empty())
361 }
362
363 pub fn keep_num_audios(&mut self, audios_to_keep: usize) {
364 if let Some(auds) = self.input_audios.as_mut() {
365 auds.keep_num_audios(audios_to_keep)
366 }
367 }
368
369 pub fn keep_num_images(&mut self, images_to_keep: usize) {
370 if let Some(imgs) = self.input_images.as_mut() {
371 imgs.keep_num_images(images_to_keep)
372 }
373 }
374
375 pub fn image_gen_response_format(&self) -> Option<ImageGenerationResponseFormat> {
376 self.image_gen_response_format
377 }
378
379 pub fn diffusion_params(&self) -> Option<DiffusionGenerationParams> {
380 self.diffusion_params.clone()
381 }
382}
383
/// A single generation request tracked by the engine: its tokens, caches, sampling
/// configuration, streaming state and a handle to the group it reports to.
pub struct Sequence {
    /// Identifier returned by `id()` / `get_id()`.
    id: usize,
    /// Number of prompt tokens; set at construction and by `set_toks_and_reallocate`.
    prompt_len: usize,
    /// Optional limit on generated tokens, checked in `is_done`.
    max_len: Option<usize>,
    /// Creation timestamp; subtracted from "now" in milliseconds to get the total time.
    timestamp: u128,
    /// Shared sampler handed out by `sampler()`.
    sampler: Arc<Sampler>,
    /// Tokens that end generation with `StopReason::StopTok`.
    stop_tokens: Vec<u32>,
    /// Strings that end generation with `StopReason::StopString`.
    stop_strings: Vec<String>,
    /// Whether the caller asked for log-probabilities.
    return_logprobs: bool,
    /// Channel on which responses for this sequence are sent.
    responder: Sender<Response>,
    /// Value returned by `get_response_index()`.
    response_index: usize,
    /// Value returned by `creation_time()`.
    creation_time: u64,
    /// The initial prompt text.
    prompt: String,
    /// Stepping mode returned by `sequence_stepping_type()`.
    sequence_stepping_type: SeqStepType,
    /// Whether raw logits were requested.
    pub(crate) return_raw_logits: bool,
    /// Offset set by the `prefill_v2_*` builders; `0` otherwise.
    token_offset: usize,
    /// End-of-sequence tokens returned by `eos_tokens()`.
    eos_tokens: Vec<u32>,

    /// Image/audio inputs and image-generation settings.
    pub multimodal: MultimodalData,

    /// Text appended to completion choices.
    suffix: Option<String>,
    /// Text prepended to completion choices.
    prefix: Option<String>,

    /// Set while temporary tokens are present (`add_tmp_tok` / `remove_tmp_tok`);
    /// makes `len()` report the raw token count.
    is_tmp: bool,

    /// When set, `get_toks()` and `len()` report these tokens instead of `tokens`.
    prefill_prompt_toks: Option<Vec<u32>>,

    /// Per-layer KV cache, one slot per layer.
    normal_cache: Vec<Option<KvCache>>,
    /// Per-layer KV cache for the draft model, one slot per layer.
    normal_draft_cache: Vec<Option<KvCache>>,
    /// Cache exposed through `scaling_cache()`; starts as `None`.
    scaling_cache: Option<Tensor>,
    /// Per-layer tensor-pair cache; layer 0 is consulted by `len()`.
    cache: LayerCaches,
    /// Per-layer tensor-pair cache for the draft model.
    draft_cache: LayerCaches,
    /// Per-layer cache that only exists for X-LoRA sequences.
    xlora_cache: Option<LayerCaches>,
    /// Index set through `set_mamba_state_idx`; starts as `None`.
    mamba_state_idx: Option<usize>,

    /// Optional preallocated cache tensors supplied at construction.
    seq_preallocated_cache: Option<(Tensor, Tensor)>,

    /// Prompt tokens followed by every generated (or temporary) token.
    tokens: Vec<u32>,
    /// Log-probability record of every generated token.
    logprobs: Vec<Logprobs>,
    /// Sum of the log-probabilities of all generated tokens.
    cumulative_logprob: f32,
    /// Log-probability of the most recently added token.
    last_logprob: f32,
    /// Byte length of the most recent chunk appended to `completion_bytes`.
    last_completion_bytes_len: usize,
    /// Stop decision passed with the most recently added token.
    last_is_done: Option<StopReason>,
    /// Raw bytes of the completion generated so far.
    completion_bytes: Vec<u8>,
    /// Offset into `completion_bytes` up to which text has already been streamed.
    stream_idx: usize,
    /// Optional recognizer attached at construction.
    pub recognizer: SequenceRecognizer,
    /// Counter bumped by `add_urgency` and cleared by `reset_urgency`.
    scheduling_urgency: usize,
    /// Counter bumped by `increment_waitlist_count`.
    waitlisted_count: usize,
    /// Starts at `0.`; written outside this file.
    pub prompt_tok_per_sec: f32,
    /// When set, completion time is measured from this timestamp.
    pub prompt_timestamp: Option<u128>,
    /// Copied into the group's `total_prompt_time` by `update_time_info`.
    pub total_prompt_time: Option<u128>,
    /// Group that collects the choices and timing of this sequence.
    group: Arc<Mutex<SequenceGroup>>,
    /// Current lifecycle state; behind a lock so it can be changed through `&self`.
    state: RwLock<SequenceState>,

    /// Paged-attention block bookkeeping, when enabled.
    custom_metadata: SequenceCustomMetadata,

    /// Optional tool-calling matcher.
    pub tools: Option<Arc<ToolCallingMatcher>>,
}
455
456impl BlockEngineSequence for Sequence {
457 fn blocks_to_add_new_tok(&self) -> usize {
458 match &self.custom_metadata {
459 SequenceCustomMetadata::PagedAttention {
460 logical_token_blocks,
461 physical_blocks_prefill: _,
462 block_size: _,
463 } => {
464 blocks_to_add_new_tok!(logical_token_blocks)
465 }
466 SequenceCustomMetadata::None => unreachable!(),
467 }
468 }
469
470 fn get_id(&self) -> usize {
471 self.id
472 }
473
474 fn logical_token_blocks(&self) -> &[LogicalTokenBlock] {
475 match &self.custom_metadata {
476 SequenceCustomMetadata::PagedAttention {
477 logical_token_blocks,
478 physical_blocks_prefill: _,
479 block_size: _,
480 } => logical_token_blocks,
481 SequenceCustomMetadata::None => unreachable!(),
482 }
483 }
484
485 fn take_physical_blocks_prefill(&mut self) -> Option<Vec<Arc<PhysicalTokenBlock>>> {
486 match &mut self.custom_metadata {
487 SequenceCustomMetadata::PagedAttention {
488 logical_token_blocks: _,
489 physical_blocks_prefill,
490 block_size: _,
491 } => physical_blocks_prefill.take(),
492 SequenceCustomMetadata::None => None,
493 }
494 }
495
496 fn increment_waitlist_count(&mut self) -> usize {
497 let prev = self.waitlisted_count;
498 self.waitlisted_count += 1;
499 prev
500 }
501}
502
503impl Sequence {
504 #[allow(clippy::too_many_arguments)]
505 pub fn new_waiting(
506 tokens: Vec<u32>,
507 prompt: String,
508 id: usize,
509 timestamp: u128,
510 layers: usize,
511 responder: Sender<Response>,
512 sampler: Sampler,
513 stop_tokens: Vec<u32>,
514 stop_strings: Vec<String>,
515 max_len: Option<usize>,
516 return_logprobs: bool,
517 is_xlora: bool,
518 group: Arc<Mutex<SequenceGroup>>,
519 response_index: usize,
520 creation_time: u64,
521 recognizer: SequenceRecognizer,
522 suffix: Option<String>,
523 prefix: Option<String>,
524 input_images: Option<Vec<image::DynamicImage>>,
525 input_audios: Option<Vec<AudioInput>>,
526 block_size: Option<usize>,
528 tools: Option<Arc<ToolCallingMatcher>>,
530 image_gen_response_format: Option<ImageGenerationResponseFormat>,
531 sequence_stepping_type: SeqStepType,
532 diffusion_params: Option<DiffusionGenerationParams>,
533 seq_preallocated_cache: Option<(Tensor, Tensor)>,
535 return_raw_logits: bool,
537 eos_tokens: Vec<u32>,
538 ) -> Self {
539 let prompt_len = tokens.len();
540 let mut custom_metadata = if let Some(block_size) = block_size {
541 SequenceCustomMetadata::PagedAttention {
542 logical_token_blocks: Vec::new(),
543 physical_blocks_prefill: None,
544 block_size,
545 }
546 } else {
547 SequenceCustomMetadata::None
548 };
549 custom_metadata
550 .append_tokens_to_blocks(tokens.iter().map(|x| *x as usize).collect::<Vec<_>>());
551 Self {
552 tokens,
553 prompt,
554 logprobs: Vec::new(),
555 prompt_len,
556 id,
557 timestamp,
558 state: RwLock::new(SequenceState::Waiting),
559 normal_cache: vec![None; layers],
560 normal_draft_cache: vec![None; layers],
561 cache: vec![None; layers],
562 draft_cache: vec![None; layers],
563 xlora_cache: if is_xlora {
564 Some(vec![None; layers])
565 } else {
566 None
567 },
568 mamba_state_idx: None,
569 seq_preallocated_cache,
570 responder,
571 sampler: sampler.into(),
572 stop_tokens,
573 stop_strings,
574 max_len,
575 return_logprobs,
576 prompt_tok_per_sec: 0.,
577 prompt_timestamp: None,
578 group,
579 scaling_cache: None,
580 response_index,
581 creation_time,
582 recognizer,
583 prefill_prompt_toks: None,
584 suffix,
585 prefix,
586 cumulative_logprob: 0.,
587 completion_bytes: Vec::new(),
588 stream_idx: 0,
589 last_completion_bytes_len: 0,
590 last_logprob: 0.0,
591 last_is_done: None,
592 is_tmp: false,
593 scheduling_urgency: 0,
594 multimodal: MultimodalData::new(
596 input_images,
597 input_audios,
598 image_gen_response_format,
599 diffusion_params,
600 ),
601 custom_metadata,
602 tools,
603 sequence_stepping_type,
604 return_raw_logits,
605 token_offset: 0,
606 eos_tokens,
607 total_prompt_time: None,
608 waitlisted_count: 0,
609 }
610 }
611
612 pub fn add_urgency(mut self) -> Self {
613 self.scheduling_urgency += 1;
614 self
615 }
616
617 pub fn reset_urgency(mut self) -> Self {
618 self.scheduling_urgency = 0;
619 self
620 }
621
622 pub fn compute_priority(&self) -> f64 {
626 #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
627 (self.scheduling_urgency as f64) + (self.len() as f64).log2()
628 }
629
630 pub fn prefill_v2_normal(
631 mut self,
632 cache: Vec<Option<KvCache>>,
633 toks: Vec<u32>,
634 offset: usize,
635 ) -> Self {
636 self.normal_cache = cache;
637 self.prefill_prompt_toks = Some(toks);
638 self.set_state(SequenceState::RunningPrefillPrompt);
639 self.token_offset = offset;
640 self
641 }
642
643 pub fn prefill_v2_paged(
644 mut self,
645 logical_blocks: Vec<LogicalTokenBlock>,
646 physical_blocks: Vec<Arc<PhysicalTokenBlock>>,
647 toks: Vec<u32>,
648 offset: usize,
649 ) -> Self {
650 self.prefill_prompt_toks = Some(toks);
651 self.set_state(SequenceState::RunningPrefillPrompt);
652 self.token_offset = offset;
653
654 if let SequenceCustomMetadata::PagedAttention {
655 logical_token_blocks,
656 physical_blocks_prefill,
657 block_size: _,
658 } = &mut self.custom_metadata
659 {
660 *logical_token_blocks = logical_blocks;
661 *physical_blocks_prefill = Some(physical_blocks);
662 }
663
664 self
665 }
666
667 pub fn len(&self) -> usize {
669 if let Some(toks) = &self.prefill_prompt_toks {
670 return toks.len();
671 }
672 if self.is_tmp {
673 return self.tokens.len();
674 }
675 if self.xlora_cache.as_ref().is_some_and(|c| c[0].is_some()) {
677 self.xlora_cache.as_ref().unwrap()[0]
678 .as_ref()
679 .unwrap()
680 .0
681 .dims()[2]
682 + 1
683 } else if let Some((_, x)) = &self.cache[0] {
684 x.dims()[2] + 1
685 } else {
686 self.tokens.len()
687 }
688 }
689
690 pub fn id(&self) -> &usize {
691 &self.id
692 }
693
694 pub fn is_running(&self) -> bool {
695 matches!(
696 *self.state.read().unwrap(),
697 SequenceState::RunningCompletion | SequenceState::RunningPrompt )
699 }
700
701 pub fn is_completion(&self) -> bool {
702 matches!(
703 *self.state.read().unwrap(),
704 SequenceState::RunningCompletion
705 )
706 }
707
708 pub fn is_prompt(&self) -> bool {
709 matches!(
710 *self.state.read().unwrap(),
711 SequenceState::RunningPrompt | SequenceState::RunningPrefillPrompt
712 )
713 }
714
715 pub fn is_waiting(&self) -> bool {
716 matches!(*self.state.read().unwrap(), SequenceState::Waiting)
717 }
718
719 pub fn is_finished_paged_attn(&self) -> bool {
720 matches!(
721 *self.state.read().unwrap(),
722 SequenceState::FinishedAborted
723 | SequenceState::FinishedIgnored
724 | SequenceState::Done(_)
725 )
726 }
727
728 pub fn get_toks(&self) -> &[u32] {
729 if let Some(toks) = &self.prefill_prompt_toks {
730 return toks;
731 }
732 &self.tokens
733 }
734
735 pub fn get_initial_prompt(&self) -> &str {
736 &self.prompt
737 }
738
739 pub fn set_initial_prompt(&mut self, new: String) {
740 self.prompt = new;
741 }
742
743 pub fn token_offset(&self) -> usize {
744 self.token_offset
745 }
746
747 pub(crate) fn set_toks_and_reallocate(
749 &mut self,
750 toks: Vec<u32>,
751 paged_attn_metadata: Option<&mut PagedAttentionMeta>,
752 ) {
753 self.tokens.clone_from(&toks);
754 self.prompt_len = self.tokens.len();
755 match &mut self.custom_metadata {
757 SequenceCustomMetadata::PagedAttention {
758 logical_token_blocks,
759 physical_blocks_prefill: _,
760 block_size: _,
761 } => {
762 logical_token_blocks.clear();
763 }
764 SequenceCustomMetadata::None => (),
765 }
766 self.custom_metadata
767 .append_tokens_to_blocks(toks.iter().map(|x| *x as usize).collect::<Vec<_>>());
768
769 if let Some(metadata) = paged_attn_metadata {
770 get_mut_arcmutex!(metadata.block_engine).free_sequence(*self.id());
772 get_mut_arcmutex!(metadata.block_engine).allocate(self);
773 }
774 }
775
776 pub fn completion_bytes(&self) -> &[u8] {
777 &self.completion_bytes
778 }
779
780 pub fn preallocated_cache(&self) -> Option<&(Tensor, Tensor)> {
781 self.seq_preallocated_cache.as_ref()
782 }
783
784 pub fn normal_cache(&mut self) -> &mut Vec<Option<KvCache>> {
785 &mut self.normal_cache
786 }
787
788 pub fn normal_draft_cache(&mut self) -> &mut Vec<Option<KvCache>> {
789 &mut self.normal_draft_cache
790 }
791
792 pub fn cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
793 &mut self.cache
794 }
795
796 pub fn draft_cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
797 &mut self.draft_cache
798 }
799
800 pub fn xlora_cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
801 self.xlora_cache.as_mut().expect("No X-LoRA cache.")
802 }
803
804 pub fn scaling_cache(&mut self) -> &mut Option<Tensor> {
805 &mut self.scaling_cache
806 }
807
808 pub fn mamba_state_idx(&self) -> Option<usize> {
809 self.mamba_state_idx
810 }
811
812 pub fn set_mamba_state_idx(&mut self, idx: Option<usize>) {
813 self.mamba_state_idx = idx;
814 }
815
816 pub fn is_xlora(&self) -> bool {
817 self.xlora_cache.is_some()
818 }
819
820 pub fn sampler(&mut self) -> Arc<Sampler> {
821 self.sampler.clone()
822 }
823
824 pub fn set_prefill_toks(&mut self, toks: Vec<u32>) {
826 self.prefill_prompt_toks = Some(toks)
827 }
828
829 pub fn reset_prefill_toks(&mut self) {
831 self.prefill_prompt_toks = None
832 }
833
834 pub(crate) fn add_tmp_tok(&mut self, tok: u32) {
836 self.is_tmp = true;
837 self.tokens.push(tok);
838 self.custom_metadata.append_token_to_blocks(tok as usize);
840 }
841
842 pub(crate) fn remove_tmp_tok(&mut self, n: usize) {
844 self.is_tmp = false;
845 self.tokens.truncate(self.tokens.len() - n);
846 self.custom_metadata.remove_tokens_from_blocks(n);
848 }
849
850 pub fn add_token(
851 &mut self,
852 tok: Logprobs,
853 completion_bytes: Vec<u8>,
854 is_done: &Option<StopReason>,
855 ) {
856 let stopped_by_token = matches!(
857 is_done,
858 Some(StopReason::Eos) | Some(StopReason::StopTok(_))
859 );
860 if !stopped_by_token {
861 self.completion_bytes.extend_from_slice(&completion_bytes);
865 self.last_completion_bytes_len = completion_bytes.len();
866 }
867 self.last_logprob = tok.logprob;
868 self.last_is_done = *is_done;
869
870 self.custom_metadata
871 .append_token_to_blocks(tok.token as usize);
872
873 self.cumulative_logprob += tok.logprob;
874 self.tokens.push(tok.token);
875 self.logprobs.push(tok);
876 self.reset_prefill_toks();
877 }
878
879 pub fn responder(&self) -> Sender<Response> {
880 self.responder.clone()
881 }
882
883 pub fn creation_time(&self) -> u64 {
884 self.creation_time
885 }
886
887 pub fn set_state(&self, state: SequenceState) {
888 if matches!(state, SequenceState::Error) {
889 let mut group = get_mut_group!(self);
890 group.n_choices = group.n_choices.saturating_sub(1);
891 }
892 *self.state.write().unwrap() = state;
893 }
894
895 pub fn getstate(&self) -> SequenceState {
896 *self.state.read().unwrap()
897 }
898
899 pub fn is_done(
900 &self,
901 tok: u32,
902 eos_tok: Option<&[u32]>,
903 max_model_len: usize,
904 ) -> Option<StopReason> {
905 let is_eos = match eos_tok {
906 Some(eos_tok) => eos_tok.contains(&tok),
907 None => false,
908 };
909 if is_eos {
910 Some(StopReason::Eos)
911 } else if matches!(
912 &*self.state.read().unwrap(),
913 SequenceState::Done(StopReason::Canceled)
914 ) {
915 Some(StopReason::Canceled)
916 } else if self.stop_tokens.contains(&tok) {
917 Some(StopReason::StopTok(tok))
918 } else if self.max_len.is_some()
919 && self.tokens.len().saturating_sub(self.prompt_len) + 1 >= self.max_len.unwrap()
920 {
921 Some(StopReason::Length(self.max_len.unwrap()))
923 } else if self.tokens.len().saturating_sub(self.prompt_len) >= max_model_len {
924 Some(StopReason::ModelLength(max_model_len))
925 } else {
926 if !self.stop_strings.is_empty() {
927 for (idx, s) in self.stop_strings.iter().enumerate() {
928 if let Some(pos) = galil_seiferas::gs_find(&self.completion_bytes, s.as_bytes())
929 {
930 return Some(StopReason::StopString {
931 stop_string_idx: idx,
932 completion_bytes_pos: pos,
933 });
934 }
935 }
936 }
937 None
938 }
939 }
940
941 pub fn logprobs(&self) -> &[Logprobs] {
942 &self.logprobs
943 }
944
945 pub fn return_logprobs(&self) -> bool {
946 self.return_logprobs
947 }
948
949 pub fn prompt_tokens(&self) -> usize {
950 self.prompt_len
951 }
952
953 pub fn stop_strings(&self) -> &[String] {
954 &self.stop_strings
955 }
956
957 pub fn get_delta(
959 &mut self,
960 ) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
961 let new_decoded = self.peek_delta();
962 if matches!(new_decoded, Ok(Some(_))) {
963 self.stream_idx = self.completion_bytes.len();
964 }
965 new_decoded
966 }
967
968 pub fn peek_delta(&self) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
970 let is_first = self.stream_idx == 0;
971 let new_decoded = String::from_utf8_lossy(&self.completion_bytes[self.stream_idx..]);
972 if new_decoded.ends_with('�') {
974 return Ok(None);
975 }
976
977 if is_first {
981 return Ok(Some(new_decoded.trim_start().to_string()));
982 }
983 Ok(Some(new_decoded.to_string()))
984 }
985
986 pub fn timestamp(&self) -> u128 {
987 self.timestamp
988 }
989
990 pub fn prompt_timestamp(&self) -> Option<u128> {
991 self.prompt_timestamp
992 }
993
994 fn update_time_info(&self) {
995 let now = SystemTime::now()
996 .duration_since(UNIX_EPOCH)
997 .expect("Time travel has occurred!")
998 .as_millis();
999
1000 if let Some(ts) = self.prompt_timestamp {
1001 get_mut_group!(self).total_completion_time = now - ts;
1002 get_mut_group!(self).total_prompt_time = self.total_prompt_time.unwrap();
1003 }
1004
1005 get_mut_group!(self).total_time = now - self.timestamp;
1006
1007 get_mut_group!(self).total_prompt_toks = self.prompt_len;
1008 get_mut_group!(self).total_toks = self.len();
1009 }
1010
1011 pub fn add_image_choice_to_group(&self, choice: ImageChoice) {
1012 get_mut_group!(self).image_choices.push(choice);
1013 }
1014
1015 pub fn add_speech_pcm_to_group(&self, pcm: Arc<Vec<f32>>, rate: usize, channels: usize) {
1016 get_mut_group!(self).speech_pcms.push((pcm, rate, channels));
1017 }
1018
1019 pub fn add_choice_to_group(&self, choice: Choice) {
1020 get_mut_group!(self).choices.push(choice);
1021 self.update_time_info();
1022 }
1023
1024 pub fn add_raw_choice_to_group(&self, logit_chunks: Vec<Tensor>) {
1025 get_mut_group!(self)
1026 .raw_choices
1027 .push((logit_chunks, self.tokens.clone()));
1028 self.update_time_info();
1029 }
1030
1031 pub fn add_embedding_choice_to_group(&self, embedding: Vec<f32>) {
1032 get_mut_group!(self).embedding_choices.push(embedding);
1033 self.update_time_info();
1034 }
1035
1036 pub fn add_completion_choice_to_group(&self, mut choice: CompletionChoice) {
1037 choice.text = format!(
1038 "{}{}{}",
1039 self.prefix.as_deref().unwrap_or(""),
1040 choice.text,
1041 self.suffix.as_deref().unwrap_or("")
1042 );
1043 get_mut_group!(self)
1044 .completion_choices
1045 .push((self.cumulative_logprob, choice));
1046 self.update_time_info();
1047 }
1048
1049 pub fn get_response_index(&self) -> usize {
1050 self.response_index
1051 }
1052
1053 pub fn get_mut_group(&self) -> MutexGuard<'_, SequenceGroup> {
1054 get_mut_group!(self)
1055 }
1056
1057 pub fn add_streaming_chunk_choice_to_group(&self, chunk: ChunkChoice) {
1058 get_mut_group!(self).chat_streaming_chunks.push(chunk);
1059 self.update_time_info();
1060 }
1061
1062 pub fn add_streaming_completion_chunk_choice_to_group(&self, chunk: CompletionChunkChoice) {
1063 get_mut_group!(self).completion_streaming_chunks.push(chunk);
1064 self.update_time_info();
1065 }
1066
1067 pub fn take_images(&mut self) -> Option<Vec<image::DynamicImage>> {
1068 self.multimodal.take_images()
1069 }
1070
1071 pub fn clone_images(&self) -> Option<Vec<image::DynamicImage>> {
1072 self.multimodal.clone_images()
1073 }
1074
1075 pub fn images(&self) -> Option<&[image::DynamicImage]> {
1076 self.multimodal.images()
1077 }
1078
1079 pub fn image_hashes(&self) -> Option<&[u64]> {
1080 self.multimodal.image_hashes()
1081 }
1082
1083 pub fn has_images(&self) -> bool {
1084 self.multimodal.has_images()
1085 }
1086
1087 pub fn take_audios(&mut self) -> Option<Vec<AudioInput>> {
1088 self.multimodal.take_audios()
1089 }
1090
1091 pub fn clone_audios(&self) -> Option<Vec<AudioInput>> {
1092 self.multimodal.clone_audios()
1093 }
1094
1095 pub fn audios(&self) -> Option<&[AudioInput]> {
1096 self.multimodal.audios()
1097 }
1098
1099 pub fn audio_hashes(&self) -> Option<&[u64]> {
1100 self.multimodal.audio_hashes()
1101 }
1102
1103 pub fn has_audios(&self) -> bool {
1104 self.multimodal.has_audios()
1105 }
1106
1107 pub fn keep_num_audios(&mut self, audios_to_keep: usize) {
1109 self.multimodal.keep_num_audios(audios_to_keep)
1110 }
1111
1112 pub fn keep_num_images(&mut self, images_to_keep: usize) {
1114 self.multimodal.keep_num_images(images_to_keep)
1115 }
1116
1117 pub fn image_gen_response_format(&self) -> Option<ImageGenerationResponseFormat> {
1118 self.multimodal.image_gen_response_format()
1119 }
1120
1121 pub fn sequence_stepping_type(&self) -> &SeqStepType {
1122 &self.sequence_stepping_type
1123 }
1124
1125 pub fn get_diffusion_diffusion_params(&self) -> Option<DiffusionGenerationParams> {
1126 self.multimodal.diffusion_params()
1127 }
1128
1129 pub fn eos_tokens(&self) -> &[u32] {
1130 &self.eos_tokens
1131 }
1132}
1133
/// Collects the results and timing of the sequences that belong to one request.
pub struct SequenceGroup {
    /// Number of choices that must be collected before a "done" response is sent;
    /// lowered when a sequence enters the error state.
    n_choices: usize,
    /// When set, completion choices are ranked by log-probability and only this many
    /// are returned.
    best_of: Option<usize>,
    /// Prompt token count copied from the most recently reporting sequence.
    pub total_prompt_toks: usize,
    /// Total token count copied from the most recently reporting sequence.
    pub total_toks: usize,
    /// Prompt time; `get_usage` treats it as milliseconds.
    pub total_prompt_time: u128,
    /// Time since sequence creation; `get_usage` treats it as milliseconds.
    pub total_time: u128,
    /// Time since the prompt timestamp; `get_usage` treats it as milliseconds.
    pub total_completion_time: u128,
    /// Collected chat choices.
    choices: Vec<Choice>,
    /// Collected image choices.
    image_choices: Vec<ImageChoice>,
    /// Collected speech results as `(pcm, rate, channels)`.
    speech_pcms: Vec<(Arc<Vec<f32>>, usize, usize)>,
    /// Collected raw results as `(logit chunks, tokens)`.
    raw_choices: Vec<(Vec<Tensor>, Vec<u32>)>,
    /// Collected embeddings.
    embedding_choices: Vec<Vec<f32>>,
    /// Collected completion choices as `(cumulative log-probability, choice)`.
    completion_choices: Vec<(f32, CompletionChoice)>,
    /// Chat chunks waiting to be streamed.
    pub chat_streaming_chunks: Vec<ChunkChoice>,
    /// Completion chunks waiting to be streamed.
    pub completion_streaming_chunks: Vec<CompletionChunkChoice>,
    /// Whether the request is streamed; non-streaming groups never send chunks.
    pub is_streaming: bool,
    /// Flag supplied at construction; not read in the visible part of this file.
    pub is_chat: bool,
}
1153
1154impl SequenceGroup {
1155 pub fn new(
1156 n_choices: usize,
1157 is_streaming: bool,
1158 is_chat: bool,
1159 best_of: Option<usize>,
1160 ) -> Self {
1161 Self {
1162 choices: Vec::new(),
1163 image_choices: Vec::new(),
1164 speech_pcms: Vec::new(),
1165 raw_choices: Vec::new(),
1166 embedding_choices: Vec::new(),
1167 completion_choices: Vec::new(),
1168 n_choices,
1169 total_prompt_toks: 0,
1170 total_toks: 0,
1171 total_prompt_time: 0,
1172 total_time: 0,
1173 total_completion_time: 0,
1174 chat_streaming_chunks: Vec::new(),
1175 completion_streaming_chunks: Vec::new(),
1176 is_streaming,
1177 is_chat,
1178 best_of,
1179 }
1180 }
1181
1182 pub fn get_choices(&self) -> &[Choice] {
1183 &self.choices
1184 }
1185
1186 pub fn get_completion_choices(&self) -> Vec<CompletionChoice> {
1188 if let Some(best_of) = self.best_of {
1189 let mut choices = self.completion_choices.clone();
1190 choices.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("No ordering."));
1192 choices
1193 .into_iter()
1194 .take(best_of)
1195 .map(|(_, x)| x)
1196 .collect::<Vec<_>>()
1197 } else {
1198 self.completion_choices
1199 .clone()
1200 .into_iter()
1201 .map(|(_, x)| x)
1202 .collect::<Vec<_>>()
1203 }
1204 }
1205
1206 pub fn get_image_choices(&self) -> &[ImageChoice] {
1207 &self.image_choices
1208 }
1209
1210 pub fn get_usage(&self) -> Usage {
1211 #[allow(clippy::cast_precision_loss)]
1212 Usage {
1213 completion_tokens: self.total_toks.saturating_sub(self.total_prompt_toks),
1214 prompt_tokens: self.total_prompt_toks,
1215 total_tokens: self.total_toks,
1216 avg_tok_per_sec: (self.total_toks as f32 / self.total_time as f32) * 1000.,
1217 avg_prompt_tok_per_sec: (self.total_prompt_toks as f32 / self.total_prompt_time as f32)
1218 * 1000.,
1219 avg_compl_tok_per_sec: (self.total_toks.saturating_sub(self.total_prompt_toks) as f32
1220 / self.total_completion_time as f32)
1221 * 1000.,
1222 total_time_sec: self.total_time as f32 / 1000.,
1223 total_completion_time_sec: self.total_completion_time as f32 / 1000.,
1224 total_prompt_time_sec: self.total_prompt_time as f32 / 1000.,
1225 }
1226 }
1227
1228 pub async fn maybe_send_chat_done_response(
1229 &self,
1230 response: ChatCompletionResponse,
1231 sender: Sender<Response>,
1232 ) -> Result<(), SendError<Response>> {
1233 if self.choices.len() == self.n_choices {
1234 sender.send(Response::Done(response)).await?;
1235 }
1236
1237 Ok(())
1238 }
1239
1240 pub async fn maybe_send_raw_done_response(
1241 &self,
1242 sender: Sender<Response>,
1243 ) -> Result<(), SendError<Response>> {
1244 if self.raw_choices.len() == self.n_choices {
1245 assert_eq!(self.raw_choices.len(), 1);
1246 let (logits_chunks, tokens) = self.raw_choices[0].clone();
1247 sender
1248 .send(Response::Raw {
1249 logits_chunks,
1250 tokens,
1251 })
1252 .await?;
1253 }
1254
1255 Ok(())
1256 }
1257
1258 pub async fn maybe_send_embedding_done_response(
1259 &self,
1260 sender: Sender<Response>,
1261 ) -> Result<(), SendError<Response>> {
1262 if self.embedding_choices.len() == self.n_choices {
1263 assert_eq!(self.embedding_choices.len(), 1);
1264 let embeddings = self.embedding_choices[0].clone();
1265 let prompt_tokens = self.total_prompt_toks;
1266 let total_tokens = self.total_toks;
1267 sender
1268 .send(Response::Embeddings {
1269 embeddings,
1270 prompt_tokens,
1271 total_tokens,
1272 })
1273 .await?;
1274 }
1275
1276 Ok(())
1277 }
1278
1279 pub async fn maybe_send_image_gen_response(
1280 &self,
1281 response: ImageGenerationResponse,
1282 sender: Sender<Response>,
1283 ) -> Result<(), SendError<Response>> {
1284 if self.image_choices.len() == self.n_choices {
1285 sender.send(Response::ImageGeneration(response)).await?;
1286 }
1287
1288 Ok(())
1289 }
1290
1291 pub async fn maybe_send_speech_response(
1292 &self,
1293 sender: Sender<Response>,
1294 ) -> Result<(), SendError<Response>> {
1295 assert_eq!(self.speech_pcms.len(), 1);
1296
1297 let (pcm, rate, channels) = self.speech_pcms[0].clone();
1298 sender
1299 .send(Response::Speech {
1300 pcm,
1301 rate,
1302 channels,
1303 })
1304 .await?;
1305
1306 Ok(())
1307 }
1308
1309 pub async fn maybe_send_streaming_response(
1310 &mut self,
1311 seq: &Sequence,
1312 model: String,
1313 usage_opt: Option<Usage>,
1314 ) -> Result<(), Box<SendError<Response>>> {
1315 if self.chat_streaming_chunks.len() == self.n_choices && self.is_streaming {
1316 let mut swap_streaming_chunks = vec![];
1317
1318 std::mem::swap(&mut swap_streaming_chunks, &mut self.chat_streaming_chunks);
1319
1320 seq.responder()
1321 .send(Response::Chunk(ChatCompletionChunkResponse {
1322 id: seq.id.to_string(),
1323 choices: swap_streaming_chunks,
1324 created: seq.timestamp,
1325 model: model.clone(),
1326 system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
1327 object: "chat.completion.chunk".to_string(),
1328 usage: usage_opt,
1329 }))
1330 .await?;
1331 } else if self.completion_streaming_chunks.len() == self.n_choices && self.is_streaming {
1332 let mut swap_streaming_chunks = vec![];
1333
1334 std::mem::swap(
1335 &mut swap_streaming_chunks,
1336 &mut self.completion_streaming_chunks,
1337 );
1338
1339 seq.responder()
1340 .send(Response::CompletionChunk(CompletionChunkResponse {
1341 id: seq.id.to_string(),
1342 choices: swap_streaming_chunks,
1343 created: seq.timestamp,
1344 model: model.clone(),
1345 system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
1346 object: "text_completion".to_string(),
1347 }))
1348 .await?;
1349 }
1350 Ok(())
1351 }
1352
1353 pub async fn maybe_send_completion_done_response(
1354 &self,
1355 response: CompletionResponse,
1356 sender: Sender<Response>,
1357 ) -> Result<(), Box<SendError<Response>>> {
1358 if self.completion_choices.len() == self.n_choices {
1359 sender.send(Response::CompletionDone(response)).await?;
1360 }
1361 Ok(())
1362 }
1363}