1use crate::{
2 get_mut_group,
3 pipeline::{text_models_inputs_processor::PagedAttentionMeta, LayerCaches},
4 response::{ChatCompletionChunkResponse, Choice, ChunkChoice, Response, SYSTEM_FINGERPRINT},
5 sampler::{Logprobs, Sampler},
6 ChatCompletionResponse, Usage,
7};
8use crate::{
9 paged_attention::{BlockEngineSequence, LogicalTokenBlock},
10 pipeline::{DiffusionGenerationParams, KvCache},
11 response::CompletionChoice,
12 tools::ToolCallingMatcher,
13 CompletionChunkChoice, CompletionChunkResponse, CompletionResponse, ImageChoice,
14 ImageGenerationResponse, ImageGenerationResponseFormat,
15};
16use candle_core::Tensor;
17use std::{
18 fmt::Display,
19 sync::{Arc, RwLock},
20 time::{SystemTime, UNIX_EPOCH},
21};
22use tokio::sync::{
23 mpsc::{error::SendError, Sender},
24 Mutex, MutexGuard,
25};
26
/// Why a sequence stopped generating.
///
/// The [`Display`] impl maps these variants onto OpenAI-style
/// `finish_reason` strings (`"stop"`, `"length"`, …).
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum StopReason {
    /// An end-of-sequence token was generated.
    Eos,
    /// One of the request's stop tokens was generated (the token id).
    StopTok(u32),
    /// The per-request `max_len` limit was reached (the limit value).
    Length(usize),
    /// The model's maximum context length was reached (the limit value).
    ModelLength(usize),
    /// One of the request's stop strings was found in the completion bytes.
    StopString {
        /// Index into the sequence's `stop_strings` list.
        stop_string_idx: usize,
        /// Byte position in `completion_bytes` where the match starts.
        completion_bytes_pos: usize,
    },
    /// The request was canceled before finishing.
    Canceled,
    /// The sequence produced an image rather than text.
    GeneratedImage,
}
40
41impl Display for StopReason {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 match self {
44 StopReason::Eos => write!(f, "stop"),
45 StopReason::Length(_) | StopReason::ModelLength(_) => write!(f, "length"),
46 StopReason::StopTok(_) | StopReason::StopString { .. } => write!(f, "stop"),
47 StopReason::Canceled => write!(f, "canceled"),
48 StopReason::GeneratedImage => write!(f, "generated-image"),
49 }
50 }
51}
52
/// Lifecycle state of a [`Sequence`]; stored behind an `RwLock` on the
/// sequence and mutated via `Sequence::set_state`.
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum SequenceState {
    /// Finished, carrying the reason generation stopped.
    Done(StopReason),
    /// Currently running the prompt (prefill) pass.
    RunningPrompt,
    /// Currently generating completion tokens.
    RunningCompletion,
    /// Queued and not yet scheduled.
    Waiting,
    /// An error occurred; `set_state` decrements the group's choice count.
    Error,
    /// Running a prefill over precomputed prompt tokens (see `Sequence::prefill*`).
    RunningPrefillPrompt,
    /// Finished because it was aborted.
    FinishedAborted,
    /// Finished because it was ignored (no output produced).
    FinishedIgnored,
    /// Swapped out by the scheduler (PagedAttention memory management).
    Swapped,
}
66
/// Optional grammar/constraint recognizer applied while sampling tokens.
pub enum SequenceRecognizer {
    /// Constrained generation driven by an `llguidance` constraint.
    Llguidance(Box<llguidance::Constraint>),
    /// Unconstrained generation.
    None,
}
71
/// Backend-specific bookkeeping attached to a sequence.
enum SequenceCustomMetadata {
    /// PagedAttention: tokens are mirrored into fixed-size logical blocks
    /// so the block engine can allocate/free physical KV-cache blocks.
    PagedAttention {
        /// Logical blocks, each holding up to `block_size` token ids.
        logical_token_blocks: Vec<LogicalTokenBlock>,
        /// Capacity of each logical block, in tokens.
        block_size: usize,
    },
    /// No extra bookkeeping (non-PagedAttention backends).
    None,
}
79
/// Number of new logical blocks required before one more token can be
/// appended to `$logical_token_blocks`.
///
/// Evaluates to `0` when the last block exists and is partially filled
/// (neither full nor empty), and `1` when the last block is full or is a
/// freshly-pushed empty placeholder (which still needs a physical block).
/// NOTE(review): an empty block *list* also evaluates to `0` — presumably
/// unreachable because sequences always start with prompt tokens; confirm.
macro_rules! blocks_to_add_new_tok {
    ($logical_token_blocks:expr) => {{
        let last = $logical_token_blocks.last();
        if !last.is_some_and(|last| last.is_full() || last.is_empty()) {
            // Partially-filled last block: the next token fits, no allocation.
            0
        } else {
            // Full (or empty placeholder) last block: one new block needed.
            1
        }
    }};
}
91
92impl SequenceCustomMetadata {
93 fn append_token_to_blocks(&mut self, tok: usize) {
94 match self {
95 Self::PagedAttention {
96 logical_token_blocks,
97 block_size,
98 } => {
99 let last = logical_token_blocks.last_mut();
100 match last {
101 Some(last) => {
102 last.append_token_id(tok);
103 }
104 None => {
105 logical_token_blocks.push(LogicalTokenBlock::new(*block_size));
106 logical_token_blocks
107 .last_mut()
108 .unwrap()
109 .append_token_id(tok);
110 }
111 }
112 if logical_token_blocks.last().as_ref().unwrap().is_full() {
113 logical_token_blocks.push(LogicalTokenBlock::new(*block_size));
114 }
115 }
116 Self::None => (),
117 }
118 }
119
120 fn pop_token_from_blocks(&mut self) {
121 match self {
122 Self::PagedAttention {
123 logical_token_blocks,
124 block_size: _,
125 } => {
126 let last = logical_token_blocks.last_mut().unwrap();
127 last.pop_token();
128 }
129 Self::None => (),
130 }
131 }
132
133 fn append_tokens_to_blocks(&mut self, toks: Vec<usize>) {
134 for tok in toks {
135 self.append_token_to_blocks(tok);
136 }
137 }
138
139 fn remove_tokens_from_blocks(&mut self, n: usize) {
140 for _ in 0..n {
141 self.pop_token_from_blocks();
142 }
143 }
144}
145
/// How the engine steps a sequence.
#[derive(Clone, Copy)]
pub enum SeqStepType {
    /// Standard autoregressive flow: a prompt pass, then per-token decoding.
    PromptAndDecode,
    /// A single forward pass produces the entire output.
    OneShot,
}
151
/// A single in-flight generation request: its prompt, sampled tokens,
/// per-layer KV caches, streaming state, and the channel used to respond.
///
/// A `Sequence` belongs to a [`SequenceGroup`] (shared through `group`) so
/// multi-choice requests can aggregate results before responding.
pub struct Sequence {
    // Request metadata (constant after construction).
    id: usize,
    prompt_len: usize,
    max_len: Option<usize>,
    timestamp: u128,
    sampler: Arc<Sampler>,
    stop_tokens: Vec<u32>,
    stop_strings: Vec<String>,
    return_logprobs: bool,
    responder: Sender<Response>,
    response_index: usize,
    creation_time: u64,
    prompt: String,
    sequence_stepping_type: SeqStepType,
    pub(crate) return_raw_logits: bool,
    token_offset: usize,
    eos_tokens: Vec<u32>,

    // Image generation (diffusion) request parameters, if any.
    image_gen_response_format: Option<ImageGenerationResponseFormat>,
    diffusion_params: Option<DiffusionGenerationParams>,

    // Text prepended/appended to completion choices.
    suffix: Option<String>,
    prefix: Option<String>,

    // True while speculative "tmp" tokens are pushed (see add_tmp_tok).
    is_tmp: bool,

    // Tokens to run during a prefill step, if any (overrides `tokens` in len/get_toks).
    prefill_prompt_toks: Option<Vec<u32>>,

    // LoRA adapters requested for this sequence.
    adapters: Option<Vec<String>>,

    // Per-layer caches for the various cache backends.
    normal_cache: Vec<Option<KvCache>>,
    normal_draft_cache: Vec<Option<KvCache>>,
    scaling_cache: Option<Tensor>,
    cache: LayerCaches,
    draft_cache: LayerCaches,
    xlora_cache: Option<LayerCaches>,

    // Preallocated (k, v) cache tensors, when the pipeline provides them.
    seq_preallocated_cache: Option<(Tensor, Tensor)>,

    // Generation state: all tokens (prompt + completion) and sampling stats.
    tokens: Vec<u32>,
    logprobs: Vec<Logprobs>,
    cumulative_logprob: f32,
    last_logprob: f32,
    last_completion_bytes_len: usize,
    last_is_done: Option<StopReason>,
    completion_bytes: Vec<u8>,
    // Byte offset into `completion_bytes` already streamed to the client.
    stream_idx: usize,
    pub recognizer: SequenceRecognizer,
    // Bumped each time scheduling passes the sequence over; see compute_priority.
    scheduling_urgency: usize,
    // Multimodal inputs and cached vision preprocessing artifacts.
    input_images: Option<Vec<image::DynamicImage>>,
    pub cached_pixel_values: Option<Tensor>,
    pub cached_img_thw: Option<Tensor>,
    pub cached_vid_thw: Option<Tensor>,

    // Timing/throughput bookkeeping.
    pub prompt_tok_per_sec: f32,
    pub prompt_timestamp: Option<u128>,
    group: Arc<Mutex<SequenceGroup>>,
    state: RwLock<SequenceState>,

    // PagedAttention logical-block bookkeeping (or None).
    custom_metadata: SequenceCustomMetadata,

    // Tool-calling matcher, if tool use is enabled for this request.
    pub tools: Option<Arc<ToolCallingMatcher>>,
}
227
228impl BlockEngineSequence for Sequence {
229 fn blocks_to_add_new_tok(&self) -> usize {
230 match &self.custom_metadata {
231 SequenceCustomMetadata::PagedAttention {
232 logical_token_blocks,
233 block_size: _,
234 } => {
235 blocks_to_add_new_tok!(logical_token_blocks)
236 }
237 SequenceCustomMetadata::None => unreachable!(),
238 }
239 }
240
241 fn get_id(&self) -> usize {
242 self.id
243 }
244
245 fn get_logical_token_blocks(&self) -> usize {
246 match &self.custom_metadata {
247 SequenceCustomMetadata::PagedAttention {
248 logical_token_blocks,
249 block_size: _,
250 } => logical_token_blocks.len(),
251 SequenceCustomMetadata::None => unreachable!(),
252 }
253 }
254}
255
impl Sequence {
    /// Construct a new sequence in the [`SequenceState::Waiting`] state.
    ///
    /// `tokens` are the prompt tokens (also mirrored into logical blocks when
    /// `block_size` is `Some`, i.e. PagedAttention is in use). `layers` sizes
    /// the per-layer caches; `is_xlora` additionally allocates the X-LoRA cache.
    #[allow(clippy::too_many_arguments)]
    pub fn new_waiting(
        tokens: Vec<u32>,
        prompt: String,
        id: usize,
        timestamp: u128,
        layers: usize,
        responder: Sender<Response>,
        sampler: Sampler,
        stop_tokens: Vec<u32>,
        stop_strings: Vec<String>,
        max_len: Option<usize>,
        return_logprobs: bool,
        is_xlora: bool,
        group: Arc<Mutex<SequenceGroup>>,
        response_index: usize,
        creation_time: u64,
        recognizer: SequenceRecognizer,
        suffix: Option<String>,
        prefix: Option<String>,
        adapters: Option<Vec<String>>,
        input_images: Option<Vec<image::DynamicImage>>,
        block_size: Option<usize>,
        tools: Option<Arc<ToolCallingMatcher>>,
        image_gen_response_format: Option<ImageGenerationResponseFormat>,
        sequence_stepping_type: SeqStepType,
        diffusion_params: Option<DiffusionGenerationParams>,
        seq_preallocated_cache: Option<(Tensor, Tensor)>,
        return_raw_logits: bool,
        eos_tokens: Vec<u32>,
    ) -> Self {
        let prompt_len = tokens.len();
        // PagedAttention bookkeeping only exists when a block size was given.
        let mut custom_metadata = if let Some(block_size) = block_size {
            SequenceCustomMetadata::PagedAttention {
                logical_token_blocks: Vec::new(),
                block_size,
            }
        } else {
            SequenceCustomMetadata::None
        };
        custom_metadata
            .append_tokens_to_blocks(tokens.iter().map(|x| *x as usize).collect::<Vec<_>>());
        Self {
            tokens,
            prompt,
            logprobs: Vec::new(),
            prompt_len,
            id,
            timestamp,
            state: RwLock::new(SequenceState::Waiting),
            normal_cache: vec![None; layers],
            normal_draft_cache: vec![None; layers],
            cache: vec![None; layers],
            draft_cache: vec![None; layers],
            xlora_cache: if is_xlora {
                Some(vec![None; layers])
            } else {
                None
            },
            seq_preallocated_cache,
            responder,
            sampler: sampler.into(),
            stop_tokens,
            stop_strings,
            max_len,
            return_logprobs,
            prompt_tok_per_sec: 0.,
            prompt_timestamp: None,
            group,
            scaling_cache: None,
            response_index,
            creation_time,
            recognizer,
            prefill_prompt_toks: None,
            suffix,
            prefix,
            cumulative_logprob: 0.,
            completion_bytes: Vec::new(),
            stream_idx: 0,
            last_completion_bytes_len: 0,
            last_logprob: 0.0,
            last_is_done: None,
            is_tmp: false,
            scheduling_urgency: 0,
            adapters,
            input_images,
            custom_metadata,
            tools,
            image_gen_response_format,
            sequence_stepping_type,
            diffusion_params,
            cached_pixel_values: None,
            cached_img_thw: None,
            cached_vid_thw: None,
            return_raw_logits,
            token_offset: 0,
            eos_tokens,
        }
    }

    /// Increase this sequence's scheduling urgency by one (builder style).
    pub fn add_urgency(mut self) -> Self {
        self.scheduling_urgency += 1;
        self
    }

    /// Reset scheduling urgency to zero (builder style).
    pub fn reset_urgency(mut self) -> Self {
        self.scheduling_urgency = 0;
        self
    }

    /// Scheduling priority: urgency plus log2 of the current length, so
    /// longer (more invested) sequences sort ahead at equal urgency.
    pub fn compute_priority(&self) -> f64 {
        #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
        (self.scheduling_urgency as f64) + (self.len() as f64).log2()
    }

    /// Install prefilled caches and the prompt tokens to run, and mark the
    /// sequence as `RunningPrefillPrompt` (legacy cache path).
    pub fn prefill(
        mut self,
        cache: LayerCaches,
        xlora_cache: Option<LayerCaches>,
        toks: Vec<u32>,
    ) -> Self {
        self.cache = cache;
        self.xlora_cache = xlora_cache;
        self.prefill_prompt_toks = Some(toks);
        self.set_state(SequenceState::RunningPrefillPrompt);
        self
    }

    /// Install a prefilled `KvCache`, the prompt tokens to run, and the token
    /// offset into the full sequence; marks the state `RunningPrefillPrompt`.
    pub fn prefill_v2(
        mut self,
        cache: Vec<Option<KvCache>>,
        toks: Vec<u32>,
        offset: usize,
    ) -> Self {
        self.normal_cache = cache;
        self.prefill_prompt_toks = Some(toks);
        self.set_state(SequenceState::RunningPrefillPrompt);
        self.token_offset = offset;
        self
    }

    /// Effective sequence length, in priority order:
    /// prefill tokens if set, raw token count while `is_tmp`, then the KV
    /// cache length + 1 (X-LoRA cache first, then the legacy cache), and
    /// finally the raw token count.
    pub fn len(&self) -> usize {
        if let Some(toks) = &self.prefill_prompt_toks {
            return toks.len();
        }
        if self.is_tmp {
            return self.tokens.len();
        }
        if self.xlora_cache.as_ref().is_some_and(|c| c[0].is_some()) {
            // Derive length from the cached K tensor's sequence dimension.
            self.xlora_cache.as_ref().unwrap()[0]
                .as_ref()
                .unwrap()
                .0
                .dims()[2]
                + 1
        } else if let Some((_, x)) = &self.cache[0] {
            x.dims()[2] + 1
        } else {
            self.tokens.len()
        }
    }

    /// This sequence's unique id.
    pub fn id(&self) -> &usize {
        &self.id
    }

    /// True while running either the prompt or completion phase.
    pub fn is_running(&self) -> bool {
        matches!(
            *self.state.read().unwrap(),
            SequenceState::RunningCompletion | SequenceState::RunningPrompt
        )
    }

    /// True while generating completion tokens.
    pub fn is_completion(&self) -> bool {
        matches!(
            *self.state.read().unwrap(),
            SequenceState::RunningCompletion
        )
    }

    /// True while running a (possibly prefill) prompt pass.
    pub fn is_prompt(&self) -> bool {
        matches!(
            *self.state.read().unwrap(),
            SequenceState::RunningPrompt | SequenceState::RunningPrefillPrompt
        )
    }

    /// True while queued and unscheduled.
    pub fn is_waiting(&self) -> bool {
        matches!(*self.state.read().unwrap(), SequenceState::Waiting)
    }

    /// True once the sequence has terminated, for PagedAttention cleanup
    /// (aborted, ignored, or done for any reason).
    pub fn is_finished_paged_attn(&self) -> bool {
        matches!(
            *self.state.read().unwrap(),
            SequenceState::FinishedAborted
                | SequenceState::FinishedIgnored
                | SequenceState::Done(_)
        )
    }

    /// Tokens to feed the model: the prefill tokens when set, otherwise all
    /// accumulated tokens.
    pub fn get_toks(&self) -> &[u32] {
        if let Some(toks) = &self.prefill_prompt_toks {
            return toks;
        }
        &self.tokens
    }

    /// The original prompt text.
    pub fn get_initial_prompt(&self) -> &str {
        &self.prompt
    }

    /// Replace the stored prompt text.
    pub fn set_initial_prompt(&mut self, new: String) {
        self.prompt = new;
    }

    /// Token offset set by `prefill_v2` (0 otherwise).
    pub fn token_offset(&self) -> usize {
        self.token_offset
    }

    /// Replace the token list and rebuild derived state: prompt length,
    /// logical blocks, and (when PagedAttention metadata is given) the
    /// physical block allocation for this sequence.
    pub(crate) fn set_toks_and_reallocate(
        &mut self,
        toks: Vec<u32>,
        paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
    ) {
        self.tokens.clone_from(&toks);
        self.prompt_len = self.tokens.len();
        // Rebuild the logical blocks from scratch for the new tokens.
        match &mut self.custom_metadata {
            SequenceCustomMetadata::PagedAttention {
                logical_token_blocks,
                block_size: _,
            } => {
                logical_token_blocks.clear();
            }
            SequenceCustomMetadata::None => (),
        }
        self.custom_metadata
            .append_tokens_to_blocks(toks.iter().map(|x| *x as usize).collect::<Vec<_>>());

        if let Some(metadata) = paged_attn_metadata {
            // Free then re-allocate physical blocks to match the new length.
            metadata.block_engine.free_sequence(*self.id());
            metadata.block_engine.allocate(self);
        }
    }

    /// All completion bytes produced so far (decoded token text, as bytes).
    pub fn completion_bytes(&self) -> &[u8] {
        &self.completion_bytes
    }

    /// The preallocated (k, v) cache tensors, if the pipeline provided them.
    pub fn preallocated_cache(&self) -> Option<&(Tensor, Tensor)> {
        self.seq_preallocated_cache.as_ref()
    }

    /// Mutable access to the per-layer `KvCache`s.
    pub fn normal_cache(&mut self) -> &mut Vec<Option<KvCache>> {
        &mut self.normal_cache
    }

    /// Mutable access to the draft model's per-layer `KvCache`s.
    pub fn normal_draft_cache(&mut self) -> &mut Vec<Option<KvCache>> {
        &mut self.normal_draft_cache
    }

    /// Mutable access to the legacy per-layer (k, v) cache.
    pub fn cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
        &mut self.cache
    }

    /// Mutable access to the legacy draft per-layer (k, v) cache.
    pub fn draft_cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
        &mut self.draft_cache
    }

    /// Mutable access to the X-LoRA cache. Panics if this sequence is not X-LoRA.
    pub fn xlora_cache(&mut self) -> &mut Vec<Option<(Tensor, Tensor)>> {
        self.xlora_cache.as_mut().expect("No X-LoRA cache.")
    }

    /// Mutable access to the X-LoRA scaling cache slot.
    pub fn scaling_cache(&mut self) -> &mut Option<Tensor> {
        &mut self.scaling_cache
    }

    /// Whether this sequence carries an X-LoRA cache.
    pub fn is_xlora(&self) -> bool {
        self.xlora_cache.is_some()
    }

    /// Shared handle to this sequence's sampler.
    pub fn sampler(&mut self) -> Arc<Sampler> {
        self.sampler.clone()
    }

    /// Set the tokens to be consumed by the next prefill step.
    pub fn set_prefill_toks(&mut self, toks: Vec<u32>) {
        self.prefill_prompt_toks = Some(toks)
    }

    /// Clear any pending prefill tokens.
    pub fn reset_prefill_toks(&mut self) {
        self.prefill_prompt_toks = None
    }

    /// Push a temporary (speculative) token; marks the sequence `is_tmp` so
    /// `len()` reflects the raw token count. Undone by `remove_tmp_tok`.
    pub(crate) fn add_tmp_tok(&mut self, tok: u32) {
        self.is_tmp = true;
        self.tokens.push(tok);
        self.custom_metadata.append_token_to_blocks(tok as usize);
    }

    /// Remove the last `n` temporary tokens and clear the `is_tmp` flag.
    pub(crate) fn remove_tmp_tok(&mut self, n: usize) {
        self.is_tmp = false;
        self.tokens.truncate(self.tokens.len() - n);
        self.custom_metadata.remove_tokens_from_blocks(n);
    }

    /// Record a sampled token: its decoded bytes (skipped when the stop was
    /// token-triggered, so stop tokens don't leak into output), its logprob,
    /// and the stop reason; also clears any pending prefill tokens.
    pub fn add_token(
        &mut self,
        tok: Logprobs,
        completion_bytes: Vec<u8>,
        is_done: &Option<StopReason>,
    ) {
        let stopped_by_token = matches!(
            is_done,
            Some(StopReason::Eos) | Some(StopReason::StopTok(_))
        );
        if !stopped_by_token {
            // Only non-stop-token text becomes part of the visible completion.
            self.completion_bytes.extend_from_slice(&completion_bytes);
            self.last_completion_bytes_len = completion_bytes.len();
        }
        self.last_logprob = tok.logprob;
        self.last_is_done = *is_done;

        self.custom_metadata
            .append_token_to_blocks(tok.token as usize);

        self.cumulative_logprob += tok.logprob;
        self.tokens.push(tok.token);
        self.logprobs.push(tok);
        self.reset_prefill_toks();
    }

    /// Clone of the response channel back to the requester.
    pub fn responder(&self) -> Sender<Response> {
        self.responder.clone()
    }

    /// Creation timestamp supplied at construction.
    pub fn creation_time(&self) -> u64 {
        self.creation_time
    }

    /// Set the sequence state. Transitioning to `Error` also decrements the
    /// group's expected choice count so the group can still complete.
    pub fn set_state(&self, state: SequenceState) {
        if matches!(state, SequenceState::Error) {
            get_mut_group!(self).n_choices -= 1;
        }
        *self.state.write().unwrap() = state;
    }

    /// Current sequence state (copied out of the lock).
    pub fn getstate(&self) -> SequenceState {
        *self.state.read().unwrap()
    }

    /// Decide whether generation should stop after sampling `tok`.
    ///
    /// Checks, in precedence order: EOS token, external cancellation, stop
    /// tokens, the request `max_len`, the model's `max_model_len`, and
    /// finally stop strings (searched in the accumulated completion bytes).
    /// Returns `None` when generation should continue.
    pub fn is_done(
        &self,
        tok: u32,
        eos_tok: Option<&[u32]>,
        max_model_len: usize,
    ) -> Option<StopReason> {
        let is_eos = match eos_tok {
            Some(eos_tok) => eos_tok.iter().any(|t| *t == tok),
            None => false,
        };
        if is_eos {
            Some(StopReason::Eos)
        } else if matches!(
            &*self.state.read().unwrap(),
            SequenceState::Done(StopReason::Canceled)
        ) {
            Some(StopReason::Canceled)
        } else if self.stop_tokens.contains(&tok) {
            Some(StopReason::StopTok(tok))
        } else if self.max_len.is_some()
            && self.tokens.len().saturating_sub(self.prompt_len) == self.max_len.unwrap()
        {
            Some(StopReason::Length(self.max_len.unwrap()))
        } else if self.tokens.len().saturating_sub(self.prompt_len) == max_model_len {
            Some(StopReason::ModelLength(max_model_len))
        } else {
            if !self.stop_strings.is_empty() {
                for (idx, s) in self.stop_strings.iter().enumerate() {
                    if let Some(pos) = galil_seiferas::gs_find(&self.completion_bytes, s.as_bytes())
                    {
                        return Some(StopReason::StopString {
                            stop_string_idx: idx,
                            completion_bytes_pos: pos,
                        });
                    }
                }
            }
            None
        }
    }

    /// Logprobs recorded for each sampled token.
    pub fn logprobs(&self) -> &[Logprobs] {
        &self.logprobs
    }

    /// Whether the request asked for logprobs in the response.
    pub fn return_logprobs(&self) -> bool {
        self.return_logprobs
    }

    /// Number of prompt tokens.
    pub fn prompt_tokens(&self) -> usize {
        self.prompt_len
    }

    /// The request's stop strings.
    pub fn stop_strings(&self) -> &[String] {
        &self.stop_strings
    }

    /// Return the newly-decoded completion text since the last delta and
    /// advance the stream cursor. Returns `Ok(None)` while the tail bytes
    /// form an incomplete UTF-8 character (see `peek_delta`).
    pub fn get_delta(
        &mut self,
    ) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
        let new_decoded = self.peek_delta();
        if matches!(new_decoded, Ok(Some(_))) {
            self.stream_idx = self.completion_bytes.len();
        }
        new_decoded
    }

    /// Decode the completion bytes past the stream cursor without advancing.
    ///
    /// Returns `Ok(None)` when the lossy decode ends in U+FFFD, i.e. the tail
    /// is (likely) a partial multi-byte character — wait for more bytes.
    /// The very first delta is `trim_start`-ed to drop leading whitespace.
    pub fn peek_delta(&self) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
        let is_first = self.stream_idx == 0;
        let new_decoded = String::from_utf8_lossy(&self.completion_bytes[self.stream_idx..]);
        if new_decoded.ends_with('�') {
            return Ok(None);
        }

        if is_first {
            return Ok(Some(new_decoded.trim_start().to_string()));
        }
        Ok(Some(new_decoded.to_string()))
    }

    /// Arrival timestamp (milliseconds since the Unix epoch).
    pub fn timestamp(&self) -> u128 {
        self.timestamp
    }

    /// Timestamp when the prompt phase finished, if it has.
    pub fn prompt_timestamp(&self) -> Option<u128> {
        self.prompt_timestamp
    }

    /// Fold this sequence's timing and token counts into the group totals.
    /// NOTE(review): each `get_mut_group!` call re-acquires the group mutex.
    fn update_time_info(&self) {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("Time travel has occurred!")
            .as_millis();

        if let Some(ts) = self.prompt_timestamp {
            get_mut_group!(self).total_completion_time += now - ts;
            get_mut_group!(self).total_prompt_time += ts - self.timestamp;
        }

        get_mut_group!(self).total_time += now - self.timestamp;

        get_mut_group!(self).total_prompt_toks = self.prompt_len;
        get_mut_group!(self).total_toks = self.len();
    }

    /// Push an image-generation choice into the group and update timings.
    pub fn add_image_choice_to_group(&self, choice: ImageChoice) {
        get_mut_group!(self).image_choices.push(choice);
        self.update_time_info();
    }

    /// Push a chat choice into the group and update timings.
    pub fn add_choice_to_group(&self, choice: Choice) {
        get_mut_group!(self).choices.push(choice);
        self.update_time_info();
    }

    /// Push raw logits (with this sequence's tokens) into the group.
    pub fn add_raw_choice_to_group(&self, logit_chunks: Vec<Tensor>) {
        get_mut_group!(self)
            .raw_choices
            .push((logit_chunks, self.tokens.clone()));
        self.update_time_info();
    }

    /// Push a completion choice into the group, wrapping its text with the
    /// request's prefix/suffix and keying it by cumulative logprob (for best_of).
    pub fn add_completion_choice_to_group(&self, mut choice: CompletionChoice) {
        choice.text = format!(
            "{}{}{}",
            self.prefix.as_deref().unwrap_or(""),
            choice.text,
            self.suffix.as_deref().unwrap_or("")
        );
        get_mut_group!(self)
            .completion_choices
            .push((self.cumulative_logprob, choice));
        self.update_time_info();
    }

    /// Index of this sequence's choice within the response.
    pub fn get_response_index(&self) -> usize {
        self.response_index
    }

    /// Lock and return the owning [`SequenceGroup`].
    pub fn get_mut_group(&self) -> MutexGuard<'_, SequenceGroup> {
        get_mut_group!(self)
    }

    /// Push a chat streaming chunk into the group and update timings.
    pub fn add_streaming_chunk_choice_to_group(&self, chunk: ChunkChoice) {
        get_mut_group!(self).chat_streaming_chunks.push(chunk);
        self.update_time_info();
    }

    /// Push a completion streaming chunk into the group (no timing update).
    pub fn add_streaming_completion_chunk_choice_to_group(&self, chunk: CompletionChunkChoice) {
        get_mut_group!(self).completion_streaming_chunks.push(chunk);
    }

    /// Clone of the requested LoRA adapters, if any.
    pub fn get_adapters(&self) -> Option<Vec<String>> {
        self.adapters.clone()
    }

    /// Take ownership of the input images, leaving `None` behind.
    pub fn take_images(&mut self) -> Option<Vec<image::DynamicImage>> {
        self.input_images.take()
    }

    /// Clone the input images, leaving them in place.
    pub fn clone_images(&mut self) -> Option<Vec<image::DynamicImage>> {
        self.input_images.clone()
    }

    /// Borrow the input images, if any.
    pub fn images(&self) -> Option<&[image::DynamicImage]> {
        self.input_images.as_deref()
    }

    /// True when at least one input image is present.
    pub fn has_images(&self) -> bool {
        self.input_images
            .as_ref()
            .is_some_and(|images| !images.is_empty())
    }

    /// Requested response format for image generation, if any.
    pub fn image_gen_response_format(&self) -> Option<ImageGenerationResponseFormat> {
        self.image_gen_response_format
    }

    /// How this sequence is stepped (prompt+decode vs one-shot).
    pub fn sequence_stepping_type(&self) -> &SeqStepType {
        &self.sequence_stepping_type
    }

    /// Clone of the diffusion generation parameters, if any.
    pub fn get_diffusion_diffusion_params(&self) -> Option<DiffusionGenerationParams> {
        self.diffusion_params.clone()
    }

    /// EOS token ids configured for this sequence.
    pub fn eos_tokens(&self) -> &[u32] {
        &self.eos_tokens
    }
}
824
/// Aggregates the results of all sequences spawned for one request
/// (one per choice), plus shared timing/usage counters. Responses are sent
/// only once every expected choice has arrived (see the `maybe_send_*` methods).
pub struct SequenceGroup {
    // Number of choices still expected; decremented when a sequence errors.
    n_choices: usize,
    // `best_of` sampling: keep only the top-N completion choices by logprob.
    best_of: Option<usize>,
    pub total_prompt_toks: usize,
    pub total_toks: usize,
    // All times are in milliseconds.
    pub total_prompt_time: u128,
    pub total_time: u128,
    pub total_completion_time: u128,
    choices: Vec<Choice>,
    image_choices: Vec<ImageChoice>,
    // (logits chunks, token ids) per sequence, for raw-logits responses.
    raw_choices: Vec<(Vec<Tensor>, Vec<u32>)>,
    // (cumulative logprob, choice) pairs, sortable for `best_of`.
    completion_choices: Vec<(f32, CompletionChoice)>,
    pub chat_streaming_chunks: Vec<ChunkChoice>,
    pub completion_streaming_chunks: Vec<CompletionChunkChoice>,
    pub is_streaming: bool,
    pub is_chat: bool,
}
842
impl SequenceGroup {
    /// Create an empty group expecting `n_choices` results.
    pub fn new(
        n_choices: usize,
        is_streaming: bool,
        is_chat: bool,
        best_of: Option<usize>,
    ) -> Self {
        Self {
            choices: Vec::new(),
            image_choices: Vec::new(),
            raw_choices: Vec::new(),
            completion_choices: Vec::new(),
            n_choices,
            total_prompt_toks: 0,
            total_toks: 0,
            total_prompt_time: 0,
            total_time: 0,
            total_completion_time: 0,
            chat_streaming_chunks: Vec::new(),
            completion_streaming_chunks: Vec::new(),
            is_streaming,
            is_chat,
            best_of,
        }
    }

    /// Chat choices collected so far.
    pub fn get_choices(&self) -> &[Choice] {
        &self.choices
    }

    /// Completion choices, honoring `best_of` when set: sorted by cumulative
    /// logprob (descending) and truncated to the top `best_of`.
    /// NOTE(review): `partial_cmp(...).expect("No ordering.")` panics if any
    /// cumulative logprob is NaN.
    pub fn get_completion_choices(&self) -> Vec<CompletionChoice> {
        if let Some(best_of) = self.best_of {
            let mut choices = self.completion_choices.clone();
            choices.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("No ordering."));
            choices
                .into_iter()
                .take(best_of)
                .map(|(_, x)| x)
                .collect::<Vec<_>>()
        } else {
            self.completion_choices
                .clone()
                .into_iter()
                .map(|(_, x)| x)
                .collect::<Vec<_>>()
        }
    }

    /// Image choices collected so far.
    pub fn get_image_choices(&self) -> &[ImageChoice] {
        &self.image_choices
    }

    /// Build the usage/throughput summary from the accumulated counters.
    /// Times are stored in ms; rates are tokens/second.
    /// NOTE(review): zero elapsed time yields inf/NaN rates — confirm callers
    /// only invoke this after timings are populated.
    pub fn get_usage(&self) -> Usage {
        #[allow(clippy::cast_precision_loss)]
        Usage {
            completion_tokens: self.total_toks - self.total_prompt_toks,
            prompt_tokens: self.total_prompt_toks,
            total_tokens: self.total_toks,
            avg_tok_per_sec: (self.total_toks as f32 / self.total_time as f32) * 1000.,
            avg_prompt_tok_per_sec: (self.total_prompt_toks as f32 / self.total_prompt_time as f32)
                * 1000.,
            avg_compl_tok_per_sec: ((self.total_toks - self.total_prompt_toks) as f32
                / self.total_completion_time as f32)
                * 1000.,
            total_time_sec: self.total_time as f32 / 1000.,
            total_completion_time_sec: self.total_completion_time as f32 / 1000.,
            total_prompt_time_sec: self.total_prompt_time as f32 / 1000.,
        }
    }

    /// Send the final chat response once every expected choice has arrived;
    /// otherwise a no-op.
    pub async fn maybe_send_chat_done_response(
        &self,
        response: ChatCompletionResponse,
        sender: Sender<Response>,
    ) -> Result<(), SendError<Response>> {
        if self.choices.len() == self.n_choices {
            sender.send(Response::Done(response)).await?;
        }

        Ok(())
    }

    /// Send the raw-logits response once all choices have arrived.
    /// Raw responses support exactly one choice (asserted).
    pub async fn maybe_send_raw_done_response(
        &self,
        sender: Sender<Response>,
    ) -> Result<(), SendError<Response>> {
        if self.raw_choices.len() == self.n_choices {
            assert_eq!(self.raw_choices.len(), 1);
            let (logits_chunks, tokens) = self.raw_choices[0].clone();
            sender
                .send(Response::Raw {
                    logits_chunks,
                    tokens,
                })
                .await?;
        }

        Ok(())
    }

    /// Send the image-generation response once all choices have arrived.
    pub async fn maybe_send_image_gen_response(
        &self,
        response: ImageGenerationResponse,
        sender: Sender<Response>,
    ) -> Result<(), SendError<Response>> {
        if self.image_choices.len() == self.n_choices {
            sender.send(Response::ImageGeneration(response)).await?;
        }

        Ok(())
    }

    /// If streaming and a full set of chunks (one per choice) is buffered,
    /// flush it as a chat or completion chunk response. The buffered chunks
    /// are swapped out (emptied) before sending.
    pub async fn maybe_send_streaming_response(
        &mut self,
        seq: &Sequence,
        model: String,
        usage_opt: Option<Usage>,
    ) -> Result<(), Box<SendError<Response>>> {
        if self.chat_streaming_chunks.len() == self.n_choices && self.is_streaming {
            let mut swap_streaming_chunks = vec![];

            // Take the buffered chunks, leaving an empty buffer for the next batch.
            std::mem::swap(&mut swap_streaming_chunks, &mut self.chat_streaming_chunks);

            seq.responder()
                .send(Response::Chunk(ChatCompletionChunkResponse {
                    id: seq.id.to_string(),
                    choices: swap_streaming_chunks,
                    created: seq.timestamp,
                    model: model.clone(),
                    system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                    object: "chat.completion.chunk".to_string(),
                    usage: usage_opt,
                }))
                .await?;
        } else if self.completion_streaming_chunks.len() == self.n_choices && self.is_streaming {
            let mut swap_streaming_chunks = vec![];

            std::mem::swap(
                &mut swap_streaming_chunks,
                &mut self.completion_streaming_chunks,
            );

            seq.responder()
                .send(Response::CompletionChunk(CompletionChunkResponse {
                    id: seq.id.to_string(),
                    choices: swap_streaming_chunks,
                    created: seq.timestamp,
                    model: model.clone(),
                    system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                    object: "text_completion".to_string(),
                }))
                .await?;
        }
        Ok(())
    }

    /// Send the final completion response once every expected choice has
    /// arrived; otherwise a no-op.
    pub async fn maybe_send_completion_done_response(
        &self,
        response: CompletionResponse,
        sender: Sender<Response>,
    ) -> Result<(), Box<SendError<Response>>> {
        if self.completion_choices.len() == self.n_choices {
            sender.send(Response::CompletionDone(response)).await?;
        }
        Ok(())
    }
}
1011}