1use crate::{
2 get_mut_group,
3 pipeline::{text_models_inputs_processor::PagedAttentionMeta, LayerCaches},
4 response::{ChatCompletionChunkResponse, Choice, ChunkChoice, Response, SYSTEM_FINGERPRINT},
5 sampler::{Logprobs, Sampler},
6 ChatCompletionResponse, Usage,
7};
8use crate::{
9 paged_attention::{BlockEngineSequence, LogicalTokenBlock},
10 pipeline::{DiffusionGenerationParams, KvCache},
11 response::CompletionChoice,
12 tools::ToolCallingMatcher,
13 CompletionChunkChoice, CompletionChunkResponse, CompletionResponse, ImageChoice,
14 ImageGenerationResponse, ImageGenerationResponseFormat,
15};
16use candle_core::Tensor;
17use std::{
18 fmt::Display,
19 sync::{Arc, RwLock},
20 time::{SystemTime, UNIX_EPOCH},
21};
22use tokio::sync::{
23 mpsc::{error::SendError, Sender},
24 Mutex, MutexGuard,
25};
26
/// Why a sequence stopped generating.
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum StopReason {
    /// The sampled token matched one of the end-of-sequence ids.
    Eos,
    /// The sampled token is one of the request's stop tokens (the token is carried).
    StopTok(u32),
    /// The per-request generation limit was reached (the limit is carried).
    Length(usize),
    /// The model-length limit was reached (the limit is carried).
    ModelLength(usize),
    /// A stop string was found in the generated bytes.
    StopString {
        /// Index of the matching entry in the sequence's stop-string list.
        stop_string_idx: usize,
        /// Byte offset of the match within the completion bytes.
        completion_bytes_pos: usize,
    },
    /// Generation was canceled.
    Canceled,
    /// An image was produced (image-generation requests).
    GeneratedImage,
}

impl Display for StopReason {
    /// Renders the `finish_reason`-style label for this stop reason.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Eos | Self::StopTok(_) | Self::StopString { .. } => "stop",
            Self::Length(_) | Self::ModelLength(_) => "length",
            Self::Canceled => "canceled",
            Self::GeneratedImage => "generated-image",
        };
        f.write_str(label)
    }
}
52
53#[derive(Clone, Copy, PartialEq, Debug)]
54pub enum SequenceState {
55 Done(StopReason),
56 RunningPrompt,
57 RunningCompletion,
58 Waiting,
59 Error,
60 RunningPrefillPrompt,
61 FinishedAborted,
63 FinishedIgnored,
64 Swapped,
65}
66
67pub enum SequenceRecognizer {
68 Llguidance(Box<llguidance::Matcher>),
69 None,
70}
71
72enum SequenceCustomMetadata {
73 PagedAttention {
74 logical_token_blocks: Vec<LogicalTokenBlock>,
75 block_size: usize,
76 },
77 None,
78}
79
/// Evaluates to how many new logical blocks are needed before the next token is appended.
///
/// Yields `1` when the last block exists and is either full or empty, and `0` otherwise
/// (including when there are no blocks at all).
macro_rules! blocks_to_add_new_tok {
    ($logical_token_blocks:expr) => {{
        match $logical_token_blocks.last() {
            Some(tail) if tail.is_full() || tail.is_empty() => 1,
            _ => 0,
        }
    }};
}
91
92impl SequenceCustomMetadata {
93 fn append_token_to_blocks(&mut self, tok: usize) {
94 match self {
95 Self::PagedAttention {
96 logical_token_blocks,
97 block_size,
98 } => {
99 let last = logical_token_blocks.last_mut();
100 match last {
101 Some(last) => {
102 last.append_token_id(tok);
103 }
104 None => {
105 logical_token_blocks.push(LogicalTokenBlock::new(*block_size));
106 logical_token_blocks
107 .last_mut()
108 .unwrap()
109 .append_token_id(tok);
110 }
111 }
112 if logical_token_blocks.last().as_ref().unwrap().is_full() {
113 logical_token_blocks.push(LogicalTokenBlock::new(*block_size));
114 }
115 }
116 Self::None => (),
117 }
118 }
119
120 fn pop_token_from_blocks(&mut self) {
121 match self {
122 Self::PagedAttention {
123 logical_token_blocks,
124 block_size: _,
125 } => {
126 let last = logical_token_blocks.last_mut().unwrap();
127 last.pop_token();
128 }
129 Self::None => (),
130 }
131 }
132
133 fn append_tokens_to_blocks(&mut self, toks: Vec<usize>) {
134 for tok in toks {
135 self.append_token_to_blocks(tok);
136 }
137 }
138
139 fn remove_tokens_from_blocks(&mut self, n: usize) {
140 for _ in 0..n {
141 self.pop_token_from_blocks();
142 }
143 }
144}
145
/// How a sequence is stepped by the engine.
#[derive(Clone, Copy)]
pub enum SeqStepType {
    /// A prompt pass followed by token-by-token decoding.
    PromptAndDecode,
    /// A single step produces the whole result.
    OneShot,
}
151
152pub struct Sequence {
153 id: usize,
155 prompt_len: usize,
156 max_len: Option<usize>,
157 timestamp: u128,
158 sampler: Arc<Sampler>,
159 stop_tokens: Vec<u32>,
160 stop_strings: Vec<String>,
161 return_logprobs: bool,
162 responder: Sender<Response>,
163 response_index: usize,
164 creation_time: u64,
165 prompt: String,
166 sequence_stepping_type: SeqStepType,
167 pub(crate) return_raw_logits: bool,
168 token_offset: usize,
169 eos_tokens: Vec<u32>,
170
171 image_gen_response_format: Option<ImageGenerationResponseFormat>,
173 diffusion_params: Option<DiffusionGenerationParams>,
174
175 suffix: Option<String>,
177 prefix: Option<String>,
178
179 is_tmp: bool,
181
182 prefill_prompt_toks: Option<Vec<u32>>,
184
185 normal_cache: Vec<Option<KvCache>>,
187 normal_draft_cache: Vec<Option<KvCache>>,
188 scaling_cache: Option<Tensor>,
189 cache: LayerCaches,
190 draft_cache: LayerCaches,
191 xlora_cache: Option<LayerCaches>,
192
193 seq_preallocated_cache: Option<(Tensor, Tensor)>,
195
196 tokens: Vec<u32>,
198 logprobs: Vec<Logprobs>,
199 cumulative_logprob: f32,
200 last_logprob: f32,
201 last_completion_bytes_len: usize,
202 last_is_done: Option<StopReason>,
203 completion_bytes: Vec<u8>,
204 stream_idx: usize,
205 pub recognizer: SequenceRecognizer,
206 scheduling_urgency: usize, input_images: Option<Vec<image::DynamicImage>>,
208 pub cached_pixel_values: Option<Tensor>,
209 pub cached_img_thw: Option<Tensor>,
210 pub cached_vid_thw: Option<Tensor>,
211 pub has_changed_prompt: bool,
212
213 pub prompt_tok_per_sec: f32,
215 pub prompt_timestamp: Option<u128>,
216 pub total_prompt_time: Option<u128>,
217 group: Arc<Mutex<SequenceGroup>>,
218 state: RwLock<SequenceState>,
219
220 custom_metadata: SequenceCustomMetadata,
222
223 pub tools: Option<Arc<ToolCallingMatcher>>,
225}
226
227impl BlockEngineSequence for Sequence {
228 fn blocks_to_add_new_tok(&self) -> usize {
229 match &self.custom_metadata {
230 SequenceCustomMetadata::PagedAttention {
231 logical_token_blocks,
232 block_size: _,
233 } => {
234 blocks_to_add_new_tok!(logical_token_blocks)
235 }
236 SequenceCustomMetadata::None => unreachable!(),
237 }
238 }
239
240 fn get_id(&self) -> usize {
241 self.id
242 }
243
244 fn get_logical_token_blocks(&self) -> usize {
245 match &self.custom_metadata {
246 SequenceCustomMetadata::PagedAttention {
247 logical_token_blocks,
248 block_size: _,
249 } => logical_token_blocks.len(),
250 SequenceCustomMetadata::None => unreachable!(),
251 }
252 }
253}
254
255impl Sequence {
256 #[allow(clippy::too_many_arguments)]
257 pub fn new_waiting(
258 tokens: Vec<u32>,
259 prompt: String,
260 id: usize,
261 timestamp: u128,
262 layers: usize,
263 responder: Sender<Response>,
264 sampler: Sampler,
265 stop_tokens: Vec<u32>,
266 stop_strings: Vec<String>,
267 max_len: Option<usize>,
268 return_logprobs: bool,
269 is_xlora: bool,
270 group: Arc<Mutex<SequenceGroup>>,
271 response_index: usize,
272 creation_time: u64,
273 recognizer: SequenceRecognizer,
274 suffix: Option<String>,
275 prefix: Option<String>,
276 input_images: Option<Vec<image::DynamicImage>>,
277 block_size: Option<usize>,
279 tools: Option<Arc<ToolCallingMatcher>>,
281 image_gen_response_format: Option<ImageGenerationResponseFormat>,
282 sequence_stepping_type: SeqStepType,
283 diffusion_params: Option<DiffusionGenerationParams>,
284 seq_preallocated_cache: Option<(Tensor, Tensor)>,
286 return_raw_logits: bool,
288 eos_tokens: Vec<u32>,
289 ) -> Self {
290 let prompt_len = tokens.len();
291 let mut custom_metadata = if let Some(block_size) = block_size {
292 SequenceCustomMetadata::PagedAttention {
293 logical_token_blocks: Vec::new(),
294 block_size,
295 }
296 } else {
297 SequenceCustomMetadata::None
298 };
299 custom_metadata
300 .append_tokens_to_blocks(tokens.iter().map(|x| *x as usize).collect::<Vec<_>>());
301 Self {
302 tokens,
303 prompt,
304 logprobs: Vec::new(),
305 prompt_len,
306 id,
307 timestamp,
308 state: RwLock::new(SequenceState::Waiting),
309 normal_cache: vec![None; layers],
310 normal_draft_cache: vec![None; layers],
311 cache: vec![None; layers],
312 draft_cache: vec![None; layers],
313 xlora_cache: if is_xlora {
314 Some(vec![None; layers])
315 } else {
316 None
317 },
318 seq_preallocated_cache,
319 responder,
320 sampler: sampler.into(),
321 stop_tokens,
322 stop_strings,
323 max_len,
324 return_logprobs,
325 prompt_tok_per_sec: 0.,
326 prompt_timestamp: None,
327 group,
328 scaling_cache: None,
329 response_index,
330 creation_time,
331 recognizer,
332 prefill_prompt_toks: None,
333 suffix,
334 prefix,
335 cumulative_logprob: 0.,
336 completion_bytes: Vec::new(),
337 stream_idx: 0,
338 last_completion_bytes_len: 0,
339 last_logprob: 0.0,
340 last_is_done: None,
341 is_tmp: false,
342 scheduling_urgency: 0,
343 input_images,
344 custom_metadata,
345 tools,
346 image_gen_response_format,
347 sequence_stepping_type,
348 diffusion_params,
349 cached_pixel_values: None,
350 cached_img_thw: None,
351 cached_vid_thw: None,
352 return_raw_logits,
353 token_offset: 0,
354 eos_tokens,
355 has_changed_prompt: false,
356 total_prompt_time: None,
357 }
358 }
359
360 pub fn add_urgency(mut self) -> Self {
361 self.scheduling_urgency += 1;
362 self
363 }
364
365 pub fn reset_urgency(mut self) -> Self {
366 self.scheduling_urgency = 0;
367 self
368 }
369
370 pub fn compute_priority(&self) -> f64 {
374 #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
375 (self.scheduling_urgency as f64) + (self.len() as f64).log2()
376 }
377
378 pub fn prefill(
379 mut self,
380 cache: LayerCaches,
381 xlora_cache: Option<LayerCaches>,
382 toks: Vec<u32>,
383 ) -> Self {
384 self.cache = cache;
385 self.xlora_cache = xlora_cache;
386 self.prefill_prompt_toks = Some(toks);
387 self.set_state(SequenceState::RunningPrefillPrompt);
388 self
389 }
390
391 pub fn prefill_v2(
392 mut self,
393 cache: Vec<Option<KvCache>>,
394 toks: Vec<u32>,
395 offset: usize,
396 ) -> Self {
397 self.normal_cache = cache;
398 self.prefill_prompt_toks = Some(toks);
399 self.set_state(SequenceState::RunningPrefillPrompt);
400 self.token_offset = offset;
401 self
402 }
403
404 pub fn len(&self) -> usize {
406 if let Some(toks) = &self.prefill_prompt_toks {
407 return toks.len();
408 }
409 if self.is_tmp {
410 return self.tokens.len();
411 }
412 if self.xlora_cache.as_ref().is_some_and(|c| c[0].is_some()) {
414 self.xlora_cache.as_ref().unwrap()[0]
415 .as_ref()
416 .unwrap()
417 .0
418 .dims()[2]
419 + 1
420 } else if let Some((_, x)) = &self.cache[0] {
421 x.dims()[2] + 1
422 } else {
423 self.tokens.len()
424 }
425 }
426
427 pub fn id(&self) -> &usize {
428 &self.id
429 }
430
431 pub fn is_running(&self) -> bool {
432 matches!(
433 *self.state.read().unwrap(),
434 SequenceState::RunningCompletion | SequenceState::RunningPrompt )
436 }
437
438 pub fn is_completion(&self) -> bool {
439 matches!(
440 *self.state.read().unwrap(),
441 SequenceState::RunningCompletion
442 )
443 }
444
445 pub fn is_prompt(&self) -> bool {
446 matches!(
447 *self.state.read().unwrap(),
448 SequenceState::RunningPrompt | SequenceState::RunningPrefillPrompt
449 )
450 }
451
452 pub fn is_waiting(&self) -> bool {
453 matches!(*self.state.read().unwrap(), SequenceState::Waiting)
454 }
455
456 pub fn is_finished_paged_attn(&self) -> bool {
457 matches!(
458 *self.state.read().unwrap(),
459 SequenceState::FinishedAborted
460 | SequenceState::FinishedIgnored
461 | SequenceState::Done(_)
462 )
463 }
464
465 pub fn get_toks(&self) -> &[u32] {
466 if let Some(toks) = &self.prefill_prompt_toks {
467 return toks;
468 }
469 &self.tokens
470 }
471
472 pub fn get_initial_prompt(&self) -> &str {
473 &self.prompt
474 }
475
476 pub fn set_initial_prompt(&mut self, new: String) {
477 self.prompt = new;
478 }
479
480 pub fn token_offset(&self) -> usize {
481 self.token_offset
482 }
483
484 pub(crate) fn set_toks_and_reallocate(
486 &mut self,
487 toks: Vec<u32>,
488 paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
489 ) {
490 self.tokens.clone_from(&toks);
491 self.prompt_len = self.tokens.len();
492 match &mut self.custom_metadata {
494 SequenceCustomMetadata::PagedAttention {
495 logical_token_blocks,
496 block_size: _,
497 } => {
498 logical_token_blocks.clear();
499 }
500 SequenceCustomMetadata::None => (),
501 }
502 self.custom_metadata
503 .append_tokens_to_blocks(toks.iter().map(|x| *x as usize).collect::<Vec<_>>());
504
505 if let Some(metadata) = paged_attn_metadata {
506 metadata.block_engine.free_sequence(*self.id());
508 metadata.block_engine.allocate(self);
509 }
510 }
511
512 pub fn completion_bytes(&self) -> &[u8] {
513 &self.completion_bytes
514 }
515
516 pub fn preallocated_cache(&self) -> Option<&(Tensor, Tensor)> {
517 self.seq_preallocated_cache.as_ref()
518 }
519
520 pub fn normal_cache(&mut self) -> &mut Vec<Option<KvCache>> {
521 &mut self.normal_cache
522 }
523
524 pub fn normal_draft_cache(&mut self) -> &mut Vec<Option<KvCache>> {
525 &mut self.normal_draft_cache
526 }
527
528 pub fn cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
529 &mut self.cache
530 }
531
532 pub fn draft_cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
533 &mut self.draft_cache
534 }
535
536 pub fn xlora_cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
537 self.xlora_cache.as_mut().expect("No X-LoRA cache.")
538 }
539
540 pub fn scaling_cache(&mut self) -> &mut Option<Tensor> {
541 &mut self.scaling_cache
542 }
543
544 pub fn is_xlora(&self) -> bool {
545 self.xlora_cache.is_some()
546 }
547
548 pub fn sampler(&mut self) -> Arc<Sampler> {
549 self.sampler.clone()
550 }
551
552 pub fn set_prefill_toks(&mut self, toks: Vec<u32>) {
554 self.prefill_prompt_toks = Some(toks)
555 }
556
557 pub fn reset_prefill_toks(&mut self) {
559 self.prefill_prompt_toks = None
560 }
561
562 pub(crate) fn add_tmp_tok(&mut self, tok: u32) {
564 self.is_tmp = true;
565 self.tokens.push(tok);
566 self.custom_metadata.append_token_to_blocks(tok as usize);
568 }
569
570 pub(crate) fn remove_tmp_tok(&mut self, n: usize) {
572 self.is_tmp = false;
573 self.tokens.truncate(self.tokens.len() - n);
574 self.custom_metadata.remove_tokens_from_blocks(n);
576 }
577
578 pub fn add_token(
579 &mut self,
580 tok: Logprobs,
581 completion_bytes: Vec<u8>,
582 is_done: &Option<StopReason>,
583 ) {
584 let stopped_by_token = matches!(
585 is_done,
586 Some(StopReason::Eos) | Some(StopReason::StopTok(_))
587 );
588 if !stopped_by_token {
589 self.completion_bytes.extend_from_slice(&completion_bytes);
593 self.last_completion_bytes_len = completion_bytes.len();
594 }
595 self.last_logprob = tok.logprob;
596 self.last_is_done = *is_done;
597
598 self.custom_metadata
599 .append_token_to_blocks(tok.token as usize);
600
601 self.cumulative_logprob += tok.logprob;
602 self.tokens.push(tok.token);
603 self.logprobs.push(tok);
604 self.reset_prefill_toks();
605 }
606
607 pub fn responder(&self) -> Sender<Response> {
608 self.responder.clone()
609 }
610
611 pub fn creation_time(&self) -> u64 {
612 self.creation_time
613 }
614
615 pub fn set_state(&self, state: SequenceState) {
616 if matches!(state, SequenceState::Error) {
617 get_mut_group!(self).n_choices -= 1;
618 }
619 *self.state.write().unwrap() = state;
620 }
621
622 pub fn getstate(&self) -> SequenceState {
623 *self.state.read().unwrap()
624 }
625
626 pub fn is_done(
627 &self,
628 tok: u32,
629 eos_tok: Option<&[u32]>,
630 max_model_len: usize,
631 ) -> Option<StopReason> {
632 let is_eos = match eos_tok {
633 Some(eos_tok) => eos_tok.iter().any(|t| *t == tok),
634 None => false,
635 };
636 if is_eos {
637 Some(StopReason::Eos)
638 } else if matches!(
639 &*self.state.read().unwrap(),
640 SequenceState::Done(StopReason::Canceled)
641 ) {
642 Some(StopReason::Canceled)
643 } else if self.stop_tokens.contains(&tok) {
644 Some(StopReason::StopTok(tok))
645 } else if self.max_len.is_some()
646 && self.tokens.len().saturating_sub(self.prompt_len) == self.max_len.unwrap()
647 {
648 Some(StopReason::Length(self.max_len.unwrap()))
650 } else if self.tokens.len().saturating_sub(self.prompt_len) == max_model_len {
651 Some(StopReason::ModelLength(max_model_len))
652 } else {
653 if !self.stop_strings.is_empty() {
654 for (idx, s) in self.stop_strings.iter().enumerate() {
655 if let Some(pos) = galil_seiferas::gs_find(&self.completion_bytes, s.as_bytes())
656 {
657 return Some(StopReason::StopString {
658 stop_string_idx: idx,
659 completion_bytes_pos: pos,
660 });
661 }
662 }
663 }
664 None
665 }
666 }
667
668 pub fn logprobs(&self) -> &[Logprobs] {
669 &self.logprobs
670 }
671
672 pub fn return_logprobs(&self) -> bool {
673 self.return_logprobs
674 }
675
676 pub fn prompt_tokens(&self) -> usize {
677 self.prompt_len
678 }
679
680 pub fn stop_strings(&self) -> &[String] {
681 &self.stop_strings
682 }
683
684 pub fn get_delta(
686 &mut self,
687 ) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
688 let new_decoded = self.peek_delta();
689 if matches!(new_decoded, Ok(Some(_))) {
690 self.stream_idx = self.completion_bytes.len();
691 }
692 new_decoded
693 }
694
695 pub fn peek_delta(&self) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
697 let is_first = self.stream_idx == 0;
698 let new_decoded = String::from_utf8_lossy(&self.completion_bytes[self.stream_idx..]);
699 if new_decoded.ends_with('�') {
701 return Ok(None);
702 }
703
704 if is_first {
708 return Ok(Some(new_decoded.trim_start().to_string()));
709 }
710 Ok(Some(new_decoded.to_string()))
711 }
712
713 pub fn timestamp(&self) -> u128 {
714 self.timestamp
715 }
716
717 pub fn prompt_timestamp(&self) -> Option<u128> {
718 self.prompt_timestamp
719 }
720
721 fn update_time_info(&self) {
722 let now = SystemTime::now()
723 .duration_since(UNIX_EPOCH)
724 .expect("Time travel has occurred!")
725 .as_millis();
726
727 if let Some(ts) = self.prompt_timestamp {
728 get_mut_group!(self).total_completion_time = now - ts;
729 get_mut_group!(self).total_prompt_time = self.total_prompt_time.unwrap();
730 }
731
732 get_mut_group!(self).total_time = now - self.timestamp;
733
734 get_mut_group!(self).total_prompt_toks = self.prompt_len;
735 get_mut_group!(self).total_toks = self.len();
736 }
737
738 pub fn add_image_choice_to_group(&self, choice: ImageChoice) {
739 get_mut_group!(self).image_choices.push(choice);
740 self.update_time_info();
741 }
742
743 pub fn add_choice_to_group(&self, choice: Choice) {
744 get_mut_group!(self).choices.push(choice);
745 self.update_time_info();
746 }
747
748 pub fn add_raw_choice_to_group(&self, logit_chunks: Vec<Tensor>) {
749 get_mut_group!(self)
750 .raw_choices
751 .push((logit_chunks, self.tokens.clone()));
752 self.update_time_info();
753 }
754
755 pub fn add_completion_choice_to_group(&self, mut choice: CompletionChoice) {
756 choice.text = format!(
757 "{}{}{}",
758 self.prefix.as_deref().unwrap_or(""),
759 choice.text,
760 self.suffix.as_deref().unwrap_or("")
761 );
762 get_mut_group!(self)
763 .completion_choices
764 .push((self.cumulative_logprob, choice));
765 self.update_time_info();
766 }
767
768 pub fn get_response_index(&self) -> usize {
769 self.response_index
770 }
771
772 pub fn get_mut_group(&self) -> MutexGuard<'_, SequenceGroup> {
773 get_mut_group!(self)
774 }
775
776 pub fn add_streaming_chunk_choice_to_group(&self, chunk: ChunkChoice) {
777 get_mut_group!(self).chat_streaming_chunks.push(chunk);
778 self.update_time_info();
779 }
780
781 pub fn add_streaming_completion_chunk_choice_to_group(&self, chunk: CompletionChunkChoice) {
782 get_mut_group!(self).completion_streaming_chunks.push(chunk);
783 self.update_time_info();
784 }
785
786 pub fn take_images(&mut self) -> Option<Vec<image::DynamicImage>> {
787 if self.has_changed_prompt {
789 self.input_images.take()
791 } else {
792 self.input_images.clone()
794 }
795 }
796
797 pub fn clone_images(&mut self) -> Option<Vec<image::DynamicImage>> {
798 self.input_images.clone()
799 }
800
801 pub fn images(&self) -> Option<&[image::DynamicImage]> {
802 self.input_images.as_deref()
803 }
804
805 pub fn has_images(&self) -> bool {
806 self.input_images
807 .as_ref()
808 .is_some_and(|images| !images.is_empty())
809 }
810
811 pub fn image_gen_response_format(&self) -> Option<ImageGenerationResponseFormat> {
812 self.image_gen_response_format
813 }
814
815 pub fn sequence_stepping_type(&self) -> &SeqStepType {
816 &self.sequence_stepping_type
817 }
818
819 pub fn get_diffusion_diffusion_params(&self) -> Option<DiffusionGenerationParams> {
820 self.diffusion_params.clone()
821 }
822
823 pub fn eos_tokens(&self) -> &[u32] {
824 &self.eos_tokens
825 }
826}
827
828pub struct SequenceGroup {
829 n_choices: usize, best_of: Option<usize>, pub total_prompt_toks: usize,
832 pub total_toks: usize,
833 pub total_prompt_time: u128,
834 pub total_time: u128,
835 pub total_completion_time: u128,
836 choices: Vec<Choice>,
837 image_choices: Vec<ImageChoice>,
838 raw_choices: Vec<(Vec<Tensor>, Vec<u32>)>,
839 completion_choices: Vec<(f32, CompletionChoice)>,
840 pub chat_streaming_chunks: Vec<ChunkChoice>,
841 pub completion_streaming_chunks: Vec<CompletionChunkChoice>,
842 pub is_streaming: bool,
843 pub is_chat: bool,
844}
845
846impl SequenceGroup {
847 pub fn new(
848 n_choices: usize,
849 is_streaming: bool,
850 is_chat: bool,
851 best_of: Option<usize>,
852 ) -> Self {
853 Self {
854 choices: Vec::new(),
855 image_choices: Vec::new(),
856 raw_choices: Vec::new(),
857 completion_choices: Vec::new(),
858 n_choices,
859 total_prompt_toks: 0,
860 total_toks: 0,
861 total_prompt_time: 0,
862 total_time: 0,
863 total_completion_time: 0,
864 chat_streaming_chunks: Vec::new(),
865 completion_streaming_chunks: Vec::new(),
866 is_streaming,
867 is_chat,
868 best_of,
869 }
870 }
871
872 pub fn get_choices(&self) -> &[Choice] {
873 &self.choices
874 }
875
876 pub fn get_completion_choices(&self) -> Vec<CompletionChoice> {
878 if let Some(best_of) = self.best_of {
879 let mut choices = self.completion_choices.clone();
880 choices.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("No ordering."));
882 choices
883 .into_iter()
884 .take(best_of)
885 .map(|(_, x)| x)
886 .collect::<Vec<_>>()
887 } else {
888 self.completion_choices
889 .clone()
890 .into_iter()
891 .map(|(_, x)| x)
892 .collect::<Vec<_>>()
893 }
894 }
895
896 pub fn get_image_choices(&self) -> &[ImageChoice] {
897 &self.image_choices
898 }
899
900 pub fn get_usage(&self) -> Usage {
901 #[allow(clippy::cast_precision_loss)]
902 Usage {
903 completion_tokens: self.total_toks - self.total_prompt_toks,
904 prompt_tokens: self.total_prompt_toks,
905 total_tokens: self.total_toks,
906 avg_tok_per_sec: (self.total_toks as f32 / self.total_time as f32) * 1000.,
907 avg_prompt_tok_per_sec: (self.total_prompt_toks as f32 / self.total_prompt_time as f32)
908 * 1000.,
909 avg_compl_tok_per_sec: ((self.total_toks - self.total_prompt_toks) as f32
910 / self.total_completion_time as f32)
911 * 1000.,
912 total_time_sec: self.total_time as f32 / 1000.,
913 total_completion_time_sec: self.total_completion_time as f32 / 1000.,
914 total_prompt_time_sec: self.total_prompt_time as f32 / 1000.,
915 }
916 }
917
918 pub async fn maybe_send_chat_done_response(
919 &self,
920 response: ChatCompletionResponse,
921 sender: Sender<Response>,
922 ) -> Result<(), SendError<Response>> {
923 if self.choices.len() == self.n_choices {
924 sender.send(Response::Done(response)).await?;
925 }
926
927 Ok(())
928 }
929
930 pub async fn maybe_send_raw_done_response(
931 &self,
932 sender: Sender<Response>,
933 ) -> Result<(), SendError<Response>> {
934 if self.raw_choices.len() == self.n_choices {
935 assert_eq!(self.raw_choices.len(), 1);
936 let (logits_chunks, tokens) = self.raw_choices[0].clone();
937 sender
938 .send(Response::Raw {
939 logits_chunks,
940 tokens,
941 })
942 .await?;
943 }
944
945 Ok(())
946 }
947
948 pub async fn maybe_send_image_gen_response(
949 &self,
950 response: ImageGenerationResponse,
951 sender: Sender<Response>,
952 ) -> Result<(), SendError<Response>> {
953 if self.image_choices.len() == self.n_choices {
954 sender.send(Response::ImageGeneration(response)).await?;
955 }
956
957 Ok(())
958 }
959
960 pub async fn maybe_send_streaming_response(
961 &mut self,
962 seq: &Sequence,
963 model: String,
964 usage_opt: Option<Usage>,
965 ) -> Result<(), Box<SendError<Response>>> {
966 if self.chat_streaming_chunks.len() == self.n_choices && self.is_streaming {
967 let mut swap_streaming_chunks = vec![];
968
969 std::mem::swap(&mut swap_streaming_chunks, &mut self.chat_streaming_chunks);
970
971 seq.responder()
972 .send(Response::Chunk(ChatCompletionChunkResponse {
973 id: seq.id.to_string(),
974 choices: swap_streaming_chunks,
975 created: seq.timestamp,
976 model: model.clone(),
977 system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
978 object: "chat.completion.chunk".to_string(),
979 usage: usage_opt,
980 }))
981 .await?;
982 } else if self.completion_streaming_chunks.len() == self.n_choices && self.is_streaming {
983 let mut swap_streaming_chunks = vec![];
984
985 std::mem::swap(
986 &mut swap_streaming_chunks,
987 &mut self.completion_streaming_chunks,
988 );
989
990 seq.responder()
991 .send(Response::CompletionChunk(CompletionChunkResponse {
992 id: seq.id.to_string(),
993 choices: swap_streaming_chunks,
994 created: seq.timestamp,
995 model: model.clone(),
996 system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
997 object: "text_completion".to_string(),
998 }))
999 .await?;
1000 }
1001 Ok(())
1002 }
1003
1004 pub async fn maybe_send_completion_done_response(
1005 &self,
1006 response: CompletionResponse,
1007 sender: Sender<Response>,
1008 ) -> Result<(), Box<SendError<Response>>> {
1009 if self.completion_choices.len() == self.n_choices {
1010 sender.send(Response::CompletionDone(response)).await?;
1011 }
1012 Ok(())
1013 }
1014}