// mistralrs_core/engine/mod.rs
use candle_core::Tensor;
use either::Either;
use llguidance::toktrie::TokEnv;
use once_cell::sync::Lazy;
use std::{
collections::HashMap,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::{Instant, SystemTime, UNIX_EPOCH},
};
use tokio::sync::{mpsc::Receiver, Mutex};
use crate::{
pipeline::{
llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
text_models_inputs_processor::PagedAttentionMeta,
AdapterInstruction, CacheBackendMetadata, CacheInstruction, EitherCache, NormalCache,
},
prefix_cacher_v2::PrefixCacheManagerV2,
request::{DetokenizationRequest, NormalRequest, TokenizationRequest},
response::CompletionChoice,
scheduler::{Scheduler, SchedulerOutput},
sequence::{SeqStepType, StopReason},
tools::{ToolCallingMatcher, ToolChoice},
CompletionResponse, RequestMessage, Response, SchedulerConfig, DEBUG,
};
use rand::SeedableRng;
use rand_isaac::Isaac64Rng;
use tracing::{info, warn};
use crate::{
get_mut_arcmutex, handle_pipeline_forward_error, handle_seq_error,
pipeline::Pipeline,
request::Request,
response::{ChatCompletionResponse, Choice, ResponseMessage},
sampler::Sampler,
sequence::{Sequence, SequenceGroup, SequenceRecognizer, SequenceState},
Constraint, StopTokens,
};
/// Out-of-band instruction for a running [`Engine`], looked up in [`ENGINE_INSTRUCTIONS`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EngineInstruction {
    /// Ask the engine's `run` loop to exit; it is checked at the top of every loop iteration.
    Terminate,
}
/// Fixed seed for the `Isaac64Rng` that `Engine::run` shares with every pipeline step.
const SEED: u64 = 0_u64;

/// When set, all sequences are terminated on the next scheduling step.
///
/// Whoever sets this flag is responsible for clearing it again afterwards.
pub static TERMINATE_ALL_NEXT_STEP: AtomicBool = AtomicBool::new(false);
/// Pending [`EngineInstruction`]s, keyed by engine (MistralRs) ID.
///
/// Starts as an empty map on first access. `Engine::run` only acts on an entry holding
/// `Some(EngineInstruction::Terminate)`.
pub static ENGINE_INSTRUCTIONS: Lazy<std::sync::Mutex<HashMap<usize, Option<EngineInstruction>>>> =
    Lazy::new(Default::default);
/// Request-processing engine: owns the request channel and the scheduler, holds a shared
/// handle to the model pipeline, and drives all of them from `Engine::run`.
// Field order is also drop order; keep it stable.
pub struct Engine {
    /// Incoming requests; drained with `try_recv` each loop iteration of `run`.
    rx: Receiver<Request>,
    /// Shared model pipeline; locked with `get_mut_arcmutex!` wherever it is used.
    pipeline: Arc<Mutex<dyn Pipeline>>,
    /// Scheduler built from the `SchedulerConfig`; decides which sequences are stepped.
    scheduler: Box<dyn Scheduler>,
    /// Starts at 0 and is incremented for every sequence created in `add_request`
    /// (it is passed to `Sequence::new_waiting`). `run` also uses it as the key into
    /// `ENGINE_INSTRUCTIONS`.
    /// NOTE(review): these two roles look like they conflict — confirm intent.
    id: usize,
    /// Whether over-long prompts are truncated instead of rejected.
    truncate_sequence: bool,
    /// When true, steps end with a cache `Reset` instead of `CacheInstruction::Out`.
    /// NOTE(review): `Engine::new` stores `no_kv_cache & !has_no_kv_cache` after asserting
    /// both are equal when `no_kv_cache` is set, so this appears to always be `false`.
    no_kv_cache: bool,
    /// Prefix cache consulted when requests are added and passed to every pipeline step.
    prefix_cacher: PrefixCacheManagerV2,
    /// Snapshot of the global `DEBUG` flag taken at construction; enables per-step logs.
    is_debug: bool,
    /// Forwarded to `Pipeline::step` on every step.
    disable_eos_stop: bool,
    /// Whether prompt/completion tokens-per-second are logged after each step.
    throughput_logging_enabled: bool,
}
impl Engine {
#[allow(clippy::too_many_arguments)]
pub fn new(
rx: Receiver<Request>,
pipeline: Arc<Mutex<dyn Pipeline>>,
config: SchedulerConfig,
truncate_sequence: bool,
no_kv_cache: bool,
no_prefix_cache: bool,
prefix_cache_n: usize,
disable_eos_stop: bool,
throughput_logging_enabled: bool,
) -> Self {
let device = get_mut_arcmutex!(pipeline).device().clone();
let has_no_kv_cache = get_mut_arcmutex!(pipeline).get_metadata().has_no_kv_cache;
if no_kv_cache {
// Diffusion models...
assert_eq!(has_no_kv_cache, no_kv_cache);
}
// Prefix caching is always disabled if using PagedAttention for now.
// TODO
let no_prefix_cache = matches!(config, SchedulerConfig::PagedAttentionMeta { .. })
|| no_prefix_cache
|| has_no_kv_cache;
Self {
rx,
pipeline,
scheduler: config.into_scheduler(),
id: 0,
truncate_sequence,
no_kv_cache: no_kv_cache & !has_no_kv_cache,
prefix_cacher: PrefixCacheManagerV2::new(device, prefix_cache_n, no_prefix_cache),
is_debug: DEBUG.load(Ordering::Relaxed),
disable_eos_stop,
throughput_logging_enabled,
}
}
pub async fn run(&mut self) {
let rng = Arc::new(std::sync::Mutex::new(Isaac64Rng::seed_from_u64(SEED)));
let mut last_completion_ids: Vec<usize> = vec![];
'lp: loop {
if matches!(
ENGINE_INSTRUCTIONS
.lock()
.expect("`ENGINE_INSTRUCTIONS` was poisioned")
.get(&self.id),
Some(Some(EngineInstruction::Terminate))
) {
break 'lp;
}
while let Ok(request) = self.rx.try_recv() {
if matches!(request, Request::Terminate) {
break 'lp;
}
self.handle_request(request).await;
}
let run_start = Instant::now();
let scheduled = self.scheduler.schedule();
match scheduled {
SchedulerOutput::DefaultScheduler {
output: mut scheduled,
} => {
let mut prompt_ts = None;
let mut completion_ts = None;
if scheduled.completion.len() > 0 {
let throughput_start = Instant::now();
let current_completion_ids: Vec<usize> =
scheduled.completion.iter().map(|seq| *seq.id()).collect();
let res = {
let mut pipeline = get_mut_arcmutex!(self.pipeline);
let pre_op = if !self.no_kv_cache
&& last_completion_ids != current_completion_ids
{
CacheInstruction::In(
scheduled.completion[0]
.get_adapters()
.map(AdapterInstruction::Activate)
.unwrap_or(AdapterInstruction::None),
)
} else {
CacheInstruction::Nothing(
scheduled.completion[0]
.get_adapters()
.map(AdapterInstruction::Activate)
.unwrap_or(AdapterInstruction::None),
)
};
let post_op = if !self.no_kv_cache {
CacheInstruction::Out
} else {
CacheInstruction::Reset {
load_preallocated_cache: false,
reset_non_granular: false,
adapter_inst: AdapterInstruction::None,
}
};
let return_raw_logits = scheduled.completion[0].return_raw_logits;
assert!(
scheduled
.completion
.iter()
.all(|seq| seq.return_raw_logits == return_raw_logits),
"All sequences must either return raw logits, or not."
);
pipeline
.step(
&mut scheduled.completion,
false,
return_raw_logits,
&mut self.prefix_cacher,
self.disable_eos_stop,
rng.clone(),
CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
)
.await
};
handle_pipeline_forward_error!(
"completion step",
res,
&mut scheduled.completion,
self.pipeline,
'lp,
self.prefix_cacher
);
let throughput_end = Instant::now();
#[allow(clippy::cast_precision_loss)]
if self.throughput_logging_enabled {
completion_ts = Some(
scheduled.completion.len() as f64
/ throughput_end
.duration_since(throughput_start)
.as_secs_f64(),
);
}
last_completion_ids = current_completion_ids;
}
if scheduled.prompt.len() > 0 {
let prompt_exec_time = {
let mut pipeline = get_mut_arcmutex!(self.pipeline);
// Run the prompt seqs
let post_op = if !self.no_kv_cache {
CacheInstruction::Out
} else {
CacheInstruction::Reset {
load_preallocated_cache: false,
reset_non_granular: false,
adapter_inst: AdapterInstruction::None,
}
};
let adapter_inst = scheduled.prompt[0]
.get_adapters()
.map(AdapterInstruction::Activate)
.unwrap_or(AdapterInstruction::None);
let return_raw_logits = scheduled.prompt[0].return_raw_logits;
assert!(
scheduled
.prompt
.iter()
.all(|seq| seq.return_raw_logits == return_raw_logits),
"All sequences must either return raw logits, or not."
);
// Reset non granular state because the old sequence must be dead.
// Technically we don't need to do this but it is better to be safe.
pipeline
.step(
&mut scheduled.prompt,
true,
return_raw_logits,
&mut self.prefix_cacher,
self.disable_eos_stop,
rng.clone(),
CacheBackendMetadata::DefaultInstructions {
pre_op: CacheInstruction::Reset {
load_preallocated_cache: true,
reset_non_granular: false,
adapter_inst,
},
post_op,
},
)
.await
};
let prompt_exec_time = handle_pipeline_forward_error!(
"prompt step",
prompt_exec_time,
&mut scheduled.prompt,
self.pipeline,
'lp,
self.prefix_cacher
);
#[allow(clippy::cast_precision_loss)]
if self.throughput_logging_enabled {
prompt_ts = Some(
scheduled
.prompt
.iter()
.map(|seq| seq.get_toks().len())
.sum::<usize>() as f64
/ prompt_exec_time.as_secs_f64(),
);
}
for seq in scheduled.prompt.iter_mut() {
match seq.sequence_stepping_type() {
SeqStepType::OneShot => {
seq.set_state(SequenceState::Done(StopReason::GeneratedImage))
}
SeqStepType::PromptAndDecode => {
seq.set_state(SequenceState::RunningCompletion)
}
}
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time travel has occurred!")
.as_millis();
#[allow(clippy::cast_precision_loss)]
let prompt_tok_per_sec =
seq.len() as f32 / prompt_exec_time.as_secs_f32();
seq.prompt_tok_per_sec = prompt_tok_per_sec;
seq.prompt_timestamp = Some(now);
}
last_completion_ids = vec![];
}
if self.is_debug {
let ms_from_last_run = run_start.elapsed().as_secs_f64();
let total_len = scheduled.prompt.len() + scheduled.completion.len();
if total_len > 0 {
let prompt_lengths = scheduled
.prompt
.iter()
.map(|seq| seq.len().to_string())
.collect::<Vec<_>>()
.join(", ");
let completion_lengths = scheduled
.completion
.iter()
.map(|seq| seq.len().to_string())
.collect::<Vec<_>>()
.join(", ");
tracing::info!(
"Prompt[{}] Completion[{}] - {}ms",
prompt_lengths,
completion_lengths,
ms_from_last_run * 1000.,
);
}
}
if self.throughput_logging_enabled {
match (prompt_ts, completion_ts) {
(Some(prompt), Some(completion)) => {
info!("Throughput (scheduler V1): Prompt: {prompt} T/s Completion {completion} T/s");
}
(None, Some(completion)) => {
info!("Throughput (scheduler V1): Completion {completion} T/s");
}
(Some(prompt), None) => {
info!("Throughput (scheduler V1): Prompt: {prompt} T/s");
}
(None, None) => (),
}
}
if scheduled.prompt.len() == 0
&& scheduled.completion.len() == 0
&& self.scheduler.waiting_len() == 0
{
// If there is nothing to do, sleep until a request comes in
if let Some(request) = self.rx.recv().await {
if matches!(request, Request::Terminate) {
break 'lp;
}
self.handle_request(request).await;
}
}
}
SchedulerOutput::PagedAttention { mut output } => {
if !output.scheduled.is_empty() {
let throughput_start = Instant::now();
let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();
let mut guards = output
.scheduled
.iter_mut()
.map(|seq| seq.lock().unwrap())
.collect::<Vec<_>>();
let mut guards_mut =
guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>();
let res = {
let mut pipeline = get_mut_arcmutex!(self.pipeline);
let block_size = self.scheduler.block_size().unwrap();
let metadata = PagedAttentionMeta {
block_size,
sliding_window: pipeline.get_metadata().sliding_window,
block_engine: self.scheduler.block_engine().unwrap(),
};
let return_raw_logits = guards_mut[0].return_raw_logits;
assert!(
guards_mut
.iter()
.all(|seq| seq.return_raw_logits == return_raw_logits),
"All sequences must either return raw logits, or not."
);
pipeline
.step(
&mut guards_mut,
is_prompt,
return_raw_logits,
&mut self.prefix_cacher,
self.disable_eos_stop,
rng.clone(),
CacheBackendMetadata::PagedAttention {
metadata,
blocks_to_copy: output.blocks_to_copy,
blocks_to_swap_in: output.blocks_to_swap_in,
blocks_to_swap_out: output.blocks_to_swap_out,
},
)
.await
};
handle_pipeline_forward_error!(
"step",
res,
&mut guards_mut,
self.pipeline,
'lp,
self.prefix_cacher
);
if self.is_debug {
let ms_from_last_run = run_start.elapsed().as_secs_f64();
let total_len = guards.len();
if total_len > 0 {
let lengths = guards
.iter()
.map(|seq| seq.len().to_string())
.collect::<Vec<_>>()
.join(", ");
let (prompt_lengths, completion_lengths) = if is_prompt {
(lengths, "".to_string())
} else {
("".to_string(), lengths)
};
tracing::info!(
"Prompt[{}] Completion[{}] - {}ms",
prompt_lengths,
completion_lengths,
ms_from_last_run * 1000.,
);
}
}
let throughput_end = Instant::now();
#[allow(clippy::cast_precision_loss)]
if self.throughput_logging_enabled {
let n_toks = if is_prompt {
guards.iter().map(|seq| seq.get_toks().len()).sum::<usize>()
} else {
guards.len()
};
let ts = n_toks as f64
/ throughput_end
.duration_since(throughput_start)
.as_secs_f64();
info!("Throughput (scheduler V2): {ts} T/s");
}
if is_prompt {
for mut seq in guards {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time travel has occurred!")
.as_millis();
#[allow(clippy::cast_precision_loss)]
let prompt_tok_per_sec =
seq.len() as f32 / (now - seq.timestamp()) as f32;
seq.prompt_tok_per_sec = prompt_tok_per_sec * 1000.;
seq.prompt_timestamp = Some(now);
}
}
}
}
}
self.scheduler.free_finished_sequence_groups();
}
}
fn build_sequence_recognizer(
tok_env: &Option<TokEnv>,
constraint: &Constraint,
) -> anyhow::Result<SequenceRecognizer> {
if let Some(grm) = llg_grammar_from_constraint(constraint)? {
let tok_env = tok_env
.as_ref()
.ok_or_else(|| anyhow::anyhow!("No token environment found."))?;
let llg = constraint_from_llg_grammar(tok_env.clone(), grm)?;
Ok(SequenceRecognizer::Llguidance(Box::new(llg)))
} else {
Ok(SequenceRecognizer::None)
}
}
async fn handle_request(&mut self, request: Request) {
match request {
Request::ActivateAdapters(adapters) => {
match get_mut_arcmutex!(self.pipeline).activate_adapters(adapters) {
Ok(n) => info!("Swapped adapters in {n} LoRA layers."),
Err(e) => warn!("Adapter activation failed: {e:?}"),
}
}
Request::Normal(request) => self.add_request(request).await,
Request::ReIsq(level) => {
if let Err(e) = get_mut_arcmutex!(self.pipeline).re_isq_model(level) {
warn!("ISQ requantization failed: {e:?}");
}
}
Request::Tokenize(req) => self.tokenize_text(req).await,
Request::Detokenize(req) => self.detokenize_text(req).await,
Request::Terminate => panic!("This is unreachable in `handle_request`. Termination is handled in the `run` loop."),
}
}
async fn add_request(&mut self, request: NormalRequest) {
let is_chat = matches!(
request.messages,
RequestMessage::Chat(_) | RequestMessage::VisionChat { .. }
);
let echo_prompt = matches!(
request.messages,
RequestMessage::Completion {
echo_prompt: true,
..
}
);
let best_of = match request.messages {
RequestMessage::Completion { best_of, .. } => best_of,
RequestMessage::Chat(_)
| RequestMessage::CompletionTokens(_)
| RequestMessage::VisionChat { .. }
| RequestMessage::ImageGeneration { .. } => None,
};
if is_chat
&& !get_mut_arcmutex!(self.pipeline)
.get_chat_template()
.as_ref()
.is_some_and(|ch_t| ch_t.has_chat_template())
{
request
.response
.send(Response::ValidationError(
"Received messages for a model which does not have a chat template. Either use a different model or pass a single string as the prompt".into(),
)).await.expect("Expected receiver.");
return;
}
let images = match request.messages {
RequestMessage::VisionChat {
ref images,
messages: _,
} => Some(images.clone()),
_ => None,
};
let matcher = if request.tools.is_some() {
Some(Arc::new(handle_seq_error!(
ToolCallingMatcher::new(request.tool_choice.unwrap_or(ToolChoice::Auto),),
request.response
)))
} else {
None
};
let image_generation_format = match &request.messages {
RequestMessage::ImageGeneration { format, .. } => Some(*format),
_ => None,
};
let seq_step_type = match &request.messages {
RequestMessage::ImageGeneration { .. } => SeqStepType::OneShot,
_ => SeqStepType::PromptAndDecode,
};
let diffusion_params = match &request.messages {
RequestMessage::ImageGeneration {
generation_params, ..
} => Some(generation_params.clone()),
_ => None,
};
let (mut prompt_tokens, prompt_text) = match request.messages {
RequestMessage::Chat(messages)
| RequestMessage::VisionChat {
images: _,
messages,
} => {
let pipeline = &*get_mut_arcmutex!(self.pipeline);
let template = pipeline.get_processor().process(
pipeline,
messages,
true,
true,
request.tools.unwrap_or_default(),
);
handle_seq_error!(template, request.response)
}
RequestMessage::Completion { text, .. } => {
let Some(tokenizer) = &get_mut_arcmutex!(self.pipeline).tokenizer() else {
request
.response
.send(Response::ValidationError(
"Completion requests require the pipeline to have a tokenizer".into(),
))
.await
.expect("Expected receiver.");
return;
};
let prompt = tokenizer
.encode(text.clone(), true)
.map_err(anyhow::Error::msg);
(
handle_seq_error!(prompt, request.response)
.get_ids()
.to_vec(),
text,
)
}
RequestMessage::ImageGeneration { prompt, .. } => (vec![u32::MAX], prompt),
RequestMessage::CompletionTokens(it) => {
let Some(tokenizer) = &get_mut_arcmutex!(self.pipeline).tokenizer() else {
request
.response
.send(Response::ValidationError(
"Completion requests w/ raw tokens require the pipeline to have a tokenizer".into(),
))
.await
.expect("Expected receiver.");
return;
};
let prompt = tokenizer
.decode(&it, false)
.map_err(|e| anyhow::Error::msg(e.to_string()));
(it, handle_seq_error!(prompt, request.response))
}
};
if prompt_tokens.is_empty() {
request
.response
.send(Response::ValidationError(
"Received an empty prompt.".into(),
))
.await
.expect("Expected receiver.");
return;
}
if prompt_tokens.len() > get_mut_arcmutex!(self.pipeline).get_metadata().max_seq_len {
if !self.truncate_sequence {
request
.response
.send(Response::ValidationError(
format!("Prompt sequence length is greater than {}, perhaps consider using `truncate_sequence`?", get_mut_arcmutex!(self.pipeline).get_metadata().max_seq_len).into(),
)).await.expect("Expected receiver.");
return;
} else {
let prompt_len = prompt_tokens.len();
let max_len = get_mut_arcmutex!(self.pipeline).get_metadata().max_seq_len;
let currently_over = prompt_len - max_len;
let sampling_max = if let Some(sampling_max) = request.sampling_params.max_len {
if currently_over + sampling_max >= prompt_len {
10
} else {
sampling_max
}
} else {
10
};
prompt_tokens = prompt_tokens[(currently_over + sampling_max)..].to_vec();
warn!("Prompt for request {} was {} tokens over the model maximum length. The last {} tokens were truncated to make space for generation.", request.id, currently_over, prompt_len - prompt_tokens.len());
}
}
let prefill_cache = handle_seq_error!(
self.prefix_cacher.search_for_matching_cache(&prompt_tokens),
request.response
);
let topk = request
.sampling_params
.top_k
.map(|x| x as i64)
.unwrap_or(-1);
let topp = request.sampling_params.top_p.unwrap_or(1.0);
let minp = request.sampling_params.min_p.unwrap_or(0.0);
let num_hidden_layers = get_mut_arcmutex!(self.pipeline)
.get_metadata()
.num_hidden_layers;
let (stop_toks, stop_strings) = match request.sampling_params.stop_toks {
None => (vec![], vec![]),
Some(StopTokens::Ids(ref i)) => {
let tok_env = {
let pipeline = get_mut_arcmutex!(self.pipeline);
pipeline.get_metadata().tok_env.clone()
};
for id in i {
// We can't use ` ` (space) as a stop token because other tokens like ` moon` start with a space.
if let Some(tok_env) = tok_env.as_ref() {
let tok_trie = tok_env.tok_trie();
if tok_trie.has_extensions(tok_trie.token(*id)) {
request
.response
.send(Response::ValidationError(
format!("Stop token {:?} is also a prefix of other tokens and cannot be used as a stop token.", tok_trie.token_str(*id)).into(),
))
.await .expect("Expected receiver.");
return;
}
}
}
(i.clone(), vec![])
}
Some(StopTokens::Seqs(ref s)) => {
let mut stop_toks = Vec::new();
let mut stop_strings: Vec<String> = Vec::new();
let (tok_env, tokenizer) = {
let pipeline = get_mut_arcmutex!(self.pipeline);
let tok_env = pipeline.get_metadata().tok_env.clone();
let tokenizer = pipeline.tokenizer();
(tok_env, tokenizer)
};
for stop_txt in s {
let Some(tokenizer) = &tokenizer else {
request
.response
.send(Response::ValidationError(
"Completion requests require the pipeline to have a tokenizer"
.into(),
))
.await
.expect("Expected receiver.");
return;
};
let encoded = tokenizer.encode(stop_txt.to_string(), true);
let toks = handle_seq_error!(encoded, request.response)
.get_ids()
.to_vec();
if toks.len() == 1 {
if tok_env.as_ref().is_some_and(|tok_env| {
let tok_trie = tok_env.tok_trie();
tok_trie.has_extensions(tok_trie.token(toks[0]))
}) {
stop_strings.push(stop_txt.clone());
} else {
stop_toks.push(toks[0]);
}
} else {
stop_strings.push(stop_txt.clone());
}
}
(stop_toks, stop_strings)
}
};
let group = Arc::new(tokio::sync::Mutex::new(SequenceGroup::new(
request.sampling_params.n_choices,
request.is_streaming,
is_chat,
best_of,
)));
let tokenizer = get_mut_arcmutex!(self.pipeline).tokenizer();
let sampler = Sampler::new(
Some(request.sampling_params.temperature.unwrap_or(1.0)),
request.sampling_params.top_n_logprobs,
tokenizer,
request.sampling_params.frequency_penalty,
request.sampling_params.presence_penalty,
request.sampling_params.dry_params,
topk,
topp,
minp,
request.logits_processors.unwrap_or_default(),
);
let sampler = handle_seq_error!(sampler, request.response);
if request.sampling_params.n_choices == 0 {
request
.response
.send(Response::ValidationError(
"Number of choices must be greater than 0.".into(),
))
.await
.expect("Expected receiver.");
return;
}
// Add sequences
for response_index in 0..request.sampling_params.n_choices {
let trie = get_mut_arcmutex!(self.pipeline)
.get_metadata()
.tok_env
.clone();
let recognizer = match Self::build_sequence_recognizer(&trie, &request.constraint) {
Ok(recognizer) => recognizer,
Err(err) => {
request
.response
.send(Response::ValidationError(
format!("Invalid grammar. {}", err).into(),
))
.await
.expect("Expected receiver.");
return;
}
};
let block_size = get_mut_arcmutex!(self.pipeline)
.get_metadata()
.cache_config
.clone()
.map(|conf| conf.block_size);
let cache = get_mut_arcmutex!(self.pipeline).cache().clone();
let seq_preallocated_cache = if let EitherCache::Normal(_cache) = cache {
let metadata = get_mut_arcmutex!(self.pipeline).get_metadata();
let model_metadata = metadata
.model_metadata
.as_ref()
.expect("If a model has a NormalCache it must have a model metadata");
let n_tokens = prompt_tokens.len();
let required_blocks = n_tokens.div_ceil(NormalCache::CACHE_GROW_SIZE);
let max_seq_len = required_blocks * NormalCache::CACHE_GROW_SIZE;
let kv_shape = (
1usize,
model_metadata.num_kv_heads(),
max_seq_len,
model_metadata.head_dim(),
);
let dtype = get_mut_arcmutex!(self.pipeline)
.get_metadata()
.activation_dtype;
let seq_cache =
Tensor::zeros(kv_shape, dtype, &get_mut_arcmutex!(self.pipeline).device());
let seq_cache = match seq_cache {
Ok(x) => x,
Err(_) => {
request
.response
.send(Response::InternalError(
"Failed to allocate preallocated KV cache."
.to_string()
.into(),
))
.await
.expect("Expected receiver.");
return;
}
};
Some(seq_cache)
} else {
None
};
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time travel has occurred!");
let seq = Sequence::new_waiting(
prompt_tokens.clone(),
prompt_text.clone(),
self.id,
now.as_millis(),
num_hidden_layers,
request.response.clone(),
sampler.clone(),
stop_toks.clone(),
stop_strings.clone(),
request.sampling_params.max_len,
request.return_logprobs,
get_mut_arcmutex!(self.pipeline).get_metadata().is_xlora,
group.clone(),
response_index,
now.as_secs(),
recognizer,
request.suffix.clone(),
if echo_prompt {
Some(prompt_text.clone())
} else {
None
},
request.adapters.clone(),
images.clone(),
block_size,
matcher.clone(),
image_generation_format,
seq_step_type,
diffusion_params.clone(),
seq_preallocated_cache,
request.return_raw_logits,
);
let seq = if let Some(prefill_cache) = prefill_cache.clone() {
seq.prefill_v2(
prefill_cache.normal,
prefill_cache.toks,
prefill_cache.offset,
)
} else {
seq
};
self.id += 1;
self.scheduler.add_seq(seq);
}
}
async fn tokenize_text(&self, request: TokenizationRequest) {
match request.text {
Either::Left(messages) => {
let pipeline = &*get_mut_arcmutex!(self.pipeline);
let template = pipeline.get_processor().process(
pipeline,
messages,
request.add_generation_prompt,
request.add_special_tokens,
request.tools.unwrap_or_default(),
);
let toks = match template {
Ok((toks, _)) => toks,
Err(e) => {
request
.response
.send(Err(e))
.await
.expect("Expected receiver.");
return;
}
};
request
.response
.send(Ok(toks))
.await
.expect("Sender disconnected unexpectedly!");
}
Either::Right(text) => {
let pipeline = &*get_mut_arcmutex!(self.pipeline);
let tokenizer = pipeline.tokenizer();
let tokenizer = match tokenizer {
Some(tokenizer) => tokenizer,
None => {
request
.response
.send(Err(anyhow::Error::msg(
"Pipeline does not include a toksnizer.",
)))
.await
.expect("Expected receiver.");
return;
}
};
let toks = tokenizer.encode(text, request.add_special_tokens);
let toks = match toks {
Ok(tokenizer) => tokenizer,
Err(e) => {
request
.response
.send(Err(anyhow::Error::msg(e)))
.await
.expect("Expected receiver.");
return;
}
};
request
.response
.send(Ok(toks.get_ids().to_vec()))
.await
.expect("Sender disconnected unexpectedly!");
}
};
}
async fn detokenize_text(&self, request: DetokenizationRequest) {
let pipeline = &*get_mut_arcmutex!(self.pipeline);
let tokenizer = pipeline.tokenizer();
let tokenizer = match tokenizer {
Some(tokenizer) => tokenizer,
None => {
request
.response
.send(Err(anyhow::Error::msg(
"Pipeline does not include a toksnizer.",
)))
.await
.expect("Expected receiver.");
return;
}
};
let txt = tokenizer.decode(&request.tokens, request.skip_special_tokens);
let txt = match txt {
Ok(tokenizer) => tokenizer,
Err(e) => {
request
.response
.send(Err(anyhow::Error::msg(e)))
.await
.expect("Expected receiver.");
return;
}
};
request
.response
.send(Ok(txt))
.await
.expect("Sender disconnected unexpectedly!");
}
}