use crate::{
    pipeline::NormalCache,
    request::{DetokenizationRequest, NormalRequest, SearchContextSize, TokenizationRequest},
    search::{self, SearchFunctionParameters, SearchResult},
    sequence::SeqStepType,
    tools::{ToolCallingMatcher, ToolChoice},
    MessageContent, RequestMessage, Response, ResponseOk,
};
use candle_core::Tensor;
use either::Either;
use indexmap::IndexMap;
use std::{
    borrow::Cow,
    ops::Deref,
    sync::{atomic::Ordering, Arc},
    time::{SystemTime, UNIX_EPOCH},
};
use tokenizers::InputSequence;
use tracing::{info, warn};

use crate::{
    get_mut_arcmutex, handle_seq_error,
    request::Request,
    sampler::Sampler,
    sequence::{Sequence, SequenceGroup},
    StopTokens,
};

use super::{Engine, TERMINATE_ALL_NEXT_STEP};

impl Engine {
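    /// Dispatch an incoming [`Request`] to the matching handler.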
    pub async fn handle_request(self: Arc<Self>, request: Request) {
        match request {
            Request::Normal(request) => {
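                // Non-streaming chat requests with web search enabled take a
                // two-pass path: the first pass lets the model decide whether to
                // call the search tool, and the second pass answers with the
                // search results appended to the conversation.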
                if matches!(
                    request.messages,
                    RequestMessage::Chat { .. } | RequestMessage::VisionChat { .. }
                ) && request.web_search_options.is_some()
                    && !request.is_streaming
                    && get_mut_arcmutex!(self.bert_pipeline).is_some()
                {
                    let Some(web_search_options) = request.web_search_options.clone() else {
                        unreachable!()
                    };
                    let mut first_request = request.clone();
                    first_request
                        .tools
                        .get_or_insert_with(Vec::new)
                        .push(search::get_search_tool(&web_search_options).unwrap());

                    let mut second_request = first_request.clone();
                    first_request.web_search_options = None;
                    second_request.web_search_options = None;

                    let this = self.clone();
                    let handle = tokio::spawn(async move {
                        let (new_sender, mut first_receiver) = tokio::sync::mpsc::channel(1);
                        second_request.response = new_sender;
                        std::mem::swap(&mut first_request.response, &mut second_request.response);
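                        // After the swap, the first pass reports to our local
                        // channel while the second pass owns the caller's
                        // response channel.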

                        this.add_request(first_request).await;
                        let ResponseOk::Done(done) =
                            first_receiver.recv().await.unwrap().as_result().unwrap()
                        else {
                            unreachable!()
                        };

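                        // Only intercept when the model produced exactly one call
                        // to the search tool; any other response is forwarded to
                        // the caller as-is.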
                        let tool_calls = match &done.choices[0].message.tool_calls {
                            Some(tool_calls)
                                if tool_calls.len() == 1
                                    && tool_calls[0].function.name == search::SEARCH_TOOL_NAME =>
                            {
                                &tool_calls[0]
                            }
                            _ => {
                                second_request
                                    .response
                                    .send(Response::Done(done))
                                    .await
                                    .unwrap();
                                return;
                            }
                        };

                        let RequestMessage::Chat(messages) = &mut second_request.messages else {
                            unreachable!()
                        };

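                        // Record the assistant's tool call in the conversation so
                        // the model sees its own call on the second pass.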
                        {
                            let mut message: IndexMap<String, MessageContent> = IndexMap::new();
                            message
                                .insert("role".to_string(), Either::Left("assistant".to_string()));
                            message.insert(
                                "content".to_string(),
                                Either::Left(format!(
                                    "{{\"name\":\"{}\",\"arguments\":\"{}\"}}",
                                    tool_calls.function.name, tool_calls.function.arguments
                                )),
                            );
                            messages.push(message);
                        }
                        let tool_call_params: SearchFunctionParameters =
                            serde_json::from_str(&tool_calls.function.arguments).unwrap();

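                        // Execute the search and measure each result's length in
                        // tokens, sorting shortest-first.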
                        {
                            let tokenizer = get_mut_arcmutex!(this.pipeline)
                                .tokenizer()
                                .expect("A tokenizer is expected for non-diffusion models.");
                            let mut results = search::run_search_tool(&tool_call_params)
                                .unwrap()
                                .into_iter()
                                .map(|result| {
                                    let len = {
                                        let inp = InputSequence::Raw(Cow::from(&result.content));
                                        tokenizer
                                            .encode_fast(inp, false)
                                            .map(|x| x.len())
                                            .unwrap_or(usize::MAX)
                                    };
                                    (result, len)
                                })
                                .collect::<Vec<_>>();
                            results.sort_by_key(|(_, len)| *len);

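                            // Reorder the results most-similar-first against the
                            // query using the BERT embedding pipeline.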
                            {
                                let device = get_mut_arcmutex!(this.pipeline).device();

                                let Some(bert_pipeline) =
                                    &mut *get_mut_arcmutex!(this.bert_pipeline)
                                else {
                                    unreachable!()
                                };

                                let decreasing_indexes = search::rag::compute_most_similar(
                                    &device,
                                    &tool_call_params.query,
                                    results.iter().map(|(res, _)| res).collect::<Vec<_>>(),
                                    bert_pipeline,
                                )
                                .unwrap();

                                let mut results_old = std::mem::take(&mut results);
                                for &index in &decreasing_indexes {
                                    let current_result: (SearchResult, usize) =
                                        std::mem::take(&mut results_old[index]);
                                    results.push(current_result);
                                }
                            }

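                            // Keep the most similar results until the token budget
                            // implied by `search_context_size` is exhausted.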
                            let max_results_budget_toks =
                                match web_search_options.search_context_size.unwrap_or_default() {
                                    SearchContextSize::High => 10000_usize,
                                    SearchContextSize::Medium => 7500_usize,
                                    SearchContextSize::Low => 3000_usize,
                                };
                            let mut used_results = Vec::new();
                            let mut used_len = 0;
                            for (item, len) in results {
                                if used_len + len >= max_results_budget_toks {
                                    break;
                                }
                                used_len += len;
                                used_results.push(item);
                            }

                            let tool_result = serde_json::to_string(&used_results)
                                .unwrap()
                                .replace("\\n", "\n")
                                .replace("\\\"", "\"")
                                .replace("\\\\", "\\");
                            info!(
                                "Web search executed, using {used_len} tokens of {} search results.",
                                used_results.len()
                            );

                            let mut message: IndexMap<String, MessageContent> = IndexMap::new();
                            message.insert("role".to_string(), Either::Left("tool".to_string()));
                            message.insert(
                                "content".to_string(),
                                Either::Left(format!("{{\"output\": \"{tool_result}\"}}")),
                            );
                            messages.push(message);
                        }

                        this.add_request(second_request).await;
                    });
                    get_mut_arcmutex!(self.handles).push(handle);
                } else {
                    self.add_request(request).await
                }
            }
            Request::ReIsq(level) => {
                if let Err(e) = get_mut_arcmutex!(self.pipeline).re_isq_model(level) {
                    warn!("ISQ requantization failed: {e:?}");
                }
            }
            Request::Tokenize(req) => self.tokenize_text(req).await,
            Request::Detokenize(req) => self.detokenize_text(req).await,
            Request::Terminate => (),
            Request::TerminateAllSeqsNextStep => {
                TERMINATE_ALL_NEXT_STEP.store(true, Ordering::SeqCst)
            }
        }
    }

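    /// Validate a [`NormalRequest`], tokenize its prompt, and schedule one
    /// sequence per requested choice.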
    async fn add_request(&self, request: NormalRequest) {
        let is_chat = matches!(
            request.messages,
            RequestMessage::Chat(_) | RequestMessage::VisionChat { .. }
        );
        let echo_prompt = matches!(
            request.messages,
            RequestMessage::Completion {
                echo_prompt: true,
                ..
            }
        );

        let best_of = match request.messages {
            RequestMessage::Completion { best_of, .. } => best_of,
            RequestMessage::Chat(_)
            | RequestMessage::CompletionTokens(_)
            | RequestMessage::VisionChat { .. }
            | RequestMessage::ImageGeneration { .. } => None,
        };
        if is_chat
            && !get_mut_arcmutex!(self.pipeline)
                .get_chat_template()
                .as_ref()
                .is_some_and(|ch_t| ch_t.has_chat_template())
        {
            request
                .response
                .send(Response::ValidationError(
                    "Received messages for a model which does not have a chat template. Either use a different model or pass a single string as the prompt".into(),
                ))
                .await
                .expect("Expected receiver.");
            return;
        }

        let images = match request.messages {
            RequestMessage::VisionChat {
                ref images,
                messages: _,
            } => Some(images.clone()),
            _ => None,
        };

        let matcher = Arc::new(handle_seq_error!(
            ToolCallingMatcher::new(request.tool_choice.unwrap_or(ToolChoice::Auto)),
            request.response
        ));

        let image_generation_format = match &request.messages {
            RequestMessage::ImageGeneration { format, .. } => Some(*format),
            _ => None,
        };

        let seq_step_type = match &request.messages {
            RequestMessage::ImageGeneration { .. } => SeqStepType::OneShot,
            _ => SeqStepType::PromptAndDecode,
        };

        let diffusion_params = match &request.messages {
            RequestMessage::ImageGeneration {
                generation_params, ..
            } => Some(generation_params.clone()),
            _ => None,
        };

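        // Turn the request into prompt tokens plus prompt text: chat messages go
        // through the chat template, completion text is encoded directly, and
        // raw tokens are decoded back to text.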
        let (mut prompt_tokens, prompt_text) = match request.messages {
            RequestMessage::Chat(messages)
            | RequestMessage::VisionChat {
                images: _,
                messages,
            } => {
                let pipeline = &*get_mut_arcmutex!(self.pipeline);
                let tools = request.tools.unwrap_or_default();
                let template = pipeline
                    .get_processor()
                    .process(pipeline, messages, true, true, tools);
                handle_seq_error!(template, request.response)
            }
            RequestMessage::Completion { text, .. } => {
                let Some(tokenizer) = &get_mut_arcmutex!(self.pipeline).tokenizer() else {
                    request
                        .response
                        .send(Response::ValidationError(
                            "Completion requests require the pipeline to have a tokenizer".into(),
                        ))
                        .await
                        .expect("Expected receiver.");
                    return;
                };
                let prompt = tokenizer
                    .encode_fast(text.clone(), true)
                    .map_err(anyhow::Error::msg);
                (
                    handle_seq_error!(prompt, request.response)
                        .get_ids()
                        .to_vec(),
                    text,
                )
            }
            RequestMessage::ImageGeneration { prompt, .. } => (vec![u32::MAX], prompt),
            RequestMessage::CompletionTokens(it) => {
                let Some(tokenizer) = &get_mut_arcmutex!(self.pipeline).tokenizer() else {
                    request
                        .response
                        .send(Response::ValidationError(
                            "Completion requests w/ raw tokens require the pipeline to have a tokenizer".into(),
                        ))
                        .await
                        .expect("Expected receiver.");
                    return;
                };
                let prompt = tokenizer
                    .decode(&it, false)
                    .map_err(|e| anyhow::Error::msg(e.to_string()));
                (it, handle_seq_error!(prompt, request.response))
            }
        };
        if prompt_tokens.is_empty() {
            request
                .response
                .send(Response::ValidationError(
                    "Received an empty prompt.".into(),
                ))
                .await
                .expect("Expected receiver.");
            return;
        }

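        // Enforce the model's maximum sequence length. With `truncate_sequence`
        // enabled, tokens are dropped from the front of the prompt so that the
        // tail plus the requested generation length fits.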
        if prompt_tokens.len() > get_mut_arcmutex!(self.pipeline).get_metadata().max_seq_len {
            if !self.truncate_sequence {
                request
                    .response
                    .send(Response::ValidationError(
                        format!("Prompt sequence length is greater than {}, perhaps consider using `truncate_sequence`?", get_mut_arcmutex!(self.pipeline).get_metadata().max_seq_len).into(),
                    ))
                    .await
                    .expect("Expected receiver.");
                return;
            } else {
                let prompt_len = prompt_tokens.len();
                let max_len = get_mut_arcmutex!(self.pipeline).get_metadata().max_seq_len;
                let currently_over = prompt_len - max_len;
                let sampling_max = if let Some(sampling_max) = request.sampling_params.max_len {
                    if currently_over + sampling_max >= prompt_len {
                        10
                    } else {
                        sampling_max
                    }
                } else {
                    10
                };
                prompt_tokens = prompt_tokens[(currently_over + sampling_max)..].to_vec();
                warn!(
                    "Prompt for request {} was {} tokens over the model maximum length. The first {} tokens were truncated to make space for generation.",
                    request.id,
                    currently_over,
                    prompt_len - prompt_tokens.len()
                );
            }
        }
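        // Check whether a previously computed prefix cache matches this prompt.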
        let prefill_cache = handle_seq_error!(
            get_mut_arcmutex!(self.prefix_cacher).search_for_matching_cache(
                &prompt_tokens,
                images.as_ref().is_some_and(|x| !x.is_empty())
            ),
            request.response
        );

        let topk = request
            .sampling_params
            .top_k
            .map(|x| x as i64)
            .unwrap_or(-1);
        let topp = request.sampling_params.top_p.unwrap_or(1.0);
        let minp = request.sampling_params.min_p.unwrap_or(0.0);
        let num_hidden_layers = get_mut_arcmutex!(self.pipeline)
            .get_metadata()
            .num_hidden_layers;

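        // Resolve stop conditions: a single-token stop sequence can be matched by
        // token id, but tokens that prefix other tokens (and multi-token strings)
        // must be matched on the decoded text instead.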
        let (stop_toks, stop_strings) = match request.sampling_params.stop_toks {
            None => (vec![], vec![]),
            Some(StopTokens::Ids(ref i)) => {
                let tok_env = {
                    let pipeline = get_mut_arcmutex!(self.pipeline);
                    pipeline.get_metadata().tok_env.clone()
                };
                for id in i {
                    if let Some(tok_env) = tok_env.as_ref() {
                        let tok_trie = tok_env.tok_trie();
                        if tok_trie.has_extensions(tok_trie.token(*id)) {
                            request
                                .response
                                .send(Response::ValidationError(
                                    format!("Stop token {:?} is also a prefix of other tokens and cannot be used as a stop token.", tok_trie.token_str(*id)).into(),
                                ))
                                .await
                                .expect("Expected receiver.");
                            return;
                        }
                    }
                }

                (i.clone(), vec![])
            }
            Some(StopTokens::Seqs(ref s)) => {
                let mut stop_toks = Vec::new();
                let mut stop_strings: Vec<String> = Vec::new();

                let (tok_env, tokenizer) = {
                    let pipeline = get_mut_arcmutex!(self.pipeline);
                    let tok_env = pipeline.get_metadata().tok_env.clone();
                    let tokenizer = pipeline.tokenizer();
                    (tok_env, tokenizer)
                };

                for stop_txt in s {
                    let Some(tokenizer) = &tokenizer else {
                        request
                            .response
                            .send(Response::ValidationError(
                                "Completion requests require the pipeline to have a tokenizer"
                                    .into(),
                            ))
                            .await
                            .expect("Expected receiver.");
                        return;
                    };
                    let encoded = tokenizer.encode_fast(stop_txt.to_string(), true);
                    let toks = handle_seq_error!(encoded, request.response)
                        .get_ids()
                        .to_vec();

                    if toks.len() == 1 {
                        if tok_env.as_ref().is_some_and(|tok_env| {
                            let tok_trie = tok_env.tok_trie();
                            tok_trie.has_extensions(tok_trie.token(toks[0]))
                        }) {
                            stop_strings.push(stop_txt.clone());
                        } else {
                            stop_toks.push(toks[0]);
                        }
                    } else {
                        stop_strings.push(stop_txt.clone());
                    }
                }

                (stop_toks, stop_strings)
            }
        };

        let group = Arc::new(tokio::sync::Mutex::new(SequenceGroup::new(
            request.sampling_params.n_choices,
            request.is_streaming,
            is_chat,
            best_of,
        )));

        let tokenizer = get_mut_arcmutex!(self.pipeline).tokenizer();

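        // Build one sampler from the request's sampling parameters; it is cloned
        // into each sequence below.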
        let sampler = Sampler::new(
            Some(request.sampling_params.temperature.unwrap_or(1.0)),
            request.sampling_params.top_n_logprobs,
            tokenizer,
            request.sampling_params.frequency_penalty,
            request.sampling_params.presence_penalty,
            request.sampling_params.dry_params,
            topk,
            topp,
            minp,
            request.logits_processors.unwrap_or_default(),
        );
        let sampler = handle_seq_error!(sampler, request.response);

        if request.sampling_params.n_choices == 0 {
            request
                .response
                .send(Response::ValidationError(
                    "Number of choices must be greater than 0.".into(),
                ))
                .await
                .expect("Expected receiver.");
            return;
        }

        for response_index in 0..request.sampling_params.n_choices {
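            // Each choice gets its own grammar recognizer, preallocated KV
            // cache, and sequence.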
            let trie = get_mut_arcmutex!(self.pipeline)
                .get_metadata()
                .tok_env
                .clone();
            let recognizer = match Self::build_sequence_recognizer(&trie, &request.constraint) {
                Ok(recognizer) => recognizer,
                Err(err) => {
                    request
                        .response
                        .send(Response::ValidationError(
                            format!("Invalid grammar. {}", err).into(),
                        ))
                        .await
                        .expect("Expected receiver.");
                    return;
                }
            };

            let block_size = get_mut_arcmutex!(self.pipeline)
                .get_metadata()
                .cache_config
                .clone()
                .map(|conf| conf.block_size);

            let eos_toks = get_mut_arcmutex!(self.pipeline)
                .get_metadata()
                .eos_tok
                .clone();

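            // Preallocate the sequence's KV cache, rounding the prompt length up
            // to the cache growth granularity; the V cache reuses the K tensor
            // when their shapes match.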
            let seq_preallocated_cache = if get_mut_arcmutex!(self.pipeline).do_preallocated_cache()
            {
                let metadata = get_mut_arcmutex!(self.pipeline).get_metadata();
                let model_metadata = metadata
                    .model_metadata
                    .as_ref()
                    .expect("If a model has a NormalCache it must have model metadata");
                let n_tokens = prompt_tokens.len();
                let required_blocks = n_tokens.div_ceil(NormalCache::CACHE_GROW_SIZE);
                let max_seq_len = required_blocks * NormalCache::CACHE_GROW_SIZE;
                let k_shape = (
                    1usize,
                    model_metadata.num_kv_heads(),
                    max_seq_len,
                    model_metadata.k_head_dim(),
                );
                let v_shape = (
                    1usize,
                    model_metadata.num_kv_heads(),
                    max_seq_len,
                    model_metadata.v_head_dim(),
                );
                let dtype = get_mut_arcmutex!(self.pipeline)
                    .get_metadata()
                    .activation_dtype;

                let k_seq_cache = {
                    let k_seq_cache =
                        Tensor::zeros(k_shape, dtype, &get_mut_arcmutex!(self.pipeline).device());
                    match k_seq_cache {
                        Ok(x) => x,
                        Err(_) => {
                            request
                                .response
                                .send(Response::InternalError(
                                    "Failed to allocate preallocated KV cache."
                                        .to_string()
                                        .into(),
                                ))
                                .await
                                .expect("Expected receiver.");
                            return;
                        }
                    }
                };
                let v_seq_cache = if k_shape == v_shape {
                    k_seq_cache.clone()
                } else {
                    let v_seq_cache =
                        Tensor::zeros(v_shape, dtype, &get_mut_arcmutex!(self.pipeline).device());
                    match v_seq_cache {
                        Ok(x) => x,
                        Err(_) => {
                            request
                                .response
                                .send(Response::InternalError(
                                    "Failed to allocate preallocated KV cache."
                                        .to_string()
                                        .into(),
                                ))
                                .await
                                .expect("Expected receiver.");
                            return;
                        }
                    }
                };
                Some((k_seq_cache, v_seq_cache))
            } else {
                None
            };

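            // Construct the sequence in the waiting state; apply any matched
            // prefix cache before handing it to the scheduler.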
            let now = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("Time travel has occurred!");
            let seq = Sequence::new_waiting(
                prompt_tokens.clone(),
                prompt_text.clone(),
                *get_mut_arcmutex!(self.id).deref(),
                now.as_millis(),
                num_hidden_layers,
                request.response.clone(),
                sampler.clone(),
                stop_toks.clone(),
                stop_strings.clone(),
                request.sampling_params.max_len,
                request.return_logprobs,
                get_mut_arcmutex!(self.pipeline).get_metadata().is_xlora,
                group.clone(),
                response_index,
                now.as_secs(),
                recognizer,
                request.suffix.clone(),
                if echo_prompt {
                    Some(prompt_text.clone())
                } else {
                    None
                },
                images.clone(),
                block_size,
                Some(matcher.clone()),
                image_generation_format,
                seq_step_type,
                diffusion_params.clone(),
                seq_preallocated_cache,
                request.return_raw_logits,
                eos_toks,
            );
            self.logger.add_new_sequence();
            let seq = if let Some(prefill_cache) = prefill_cache.clone() {
                self.logger.add_prefix_cache_hit();

                seq.prefill_v2(
                    prefill_cache.normal,
                    prefill_cache.toks,
                    prefill_cache.offset,
                )
            } else {
                seq
            };
            *get_mut_arcmutex!(self.id) += 1;
            get_mut_arcmutex!(self.scheduler).add_seq(seq);
        }
    }

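    /// Handle a [`TokenizationRequest`]: chat messages are rendered through the
    /// chat template, while raw text is encoded directly with the tokenizer.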
    async fn tokenize_text(&self, request: TokenizationRequest) {
        match request.text {
            Either::Left(messages) => {
                let pipeline = &*get_mut_arcmutex!(self.pipeline);
                let tools = request.tools.unwrap_or_default();
                let template = pipeline.get_processor().process(
                    pipeline,
                    messages,
                    request.add_generation_prompt,
                    request.add_special_tokens,
                    tools,
                );
                let toks = match template {
                    Ok((toks, _)) => toks,
                    Err(e) => {
                        request
                            .response
                            .send(Err(e))
                            .await
                            .expect("Expected receiver.");
                        return;
                    }
                };
                request
                    .response
                    .send(Ok(toks))
                    .await
                    .expect("Sender disconnected unexpectedly!");
            }
            Either::Right(text) => {
                let pipeline = &*get_mut_arcmutex!(self.pipeline);
                let tokenizer = pipeline.tokenizer();
                let tokenizer = match tokenizer {
                    Some(tokenizer) => tokenizer,
                    None => {
                        request
                            .response
                            .send(Err(anyhow::Error::msg(
                                "Pipeline does not include a tokenizer.",
                            )))
                            .await
                            .expect("Expected receiver.");
                        return;
                    }
                };
                let toks = tokenizer.encode_fast(text, request.add_special_tokens);
                let toks = match toks {
                    Ok(toks) => toks,
                    Err(e) => {
                        request
                            .response
                            .send(Err(anyhow::Error::msg(e)))
                            .await
                            .expect("Expected receiver.");
                        return;
                    }
                };
                request
                    .response
                    .send(Ok(toks.get_ids().to_vec()))
                    .await
                    .expect("Sender disconnected unexpectedly!");
            }
        };
    }

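    /// Handle a [`DetokenizationRequest`] by decoding the token ids back to text.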
    async fn detokenize_text(&self, request: DetokenizationRequest) {
        let pipeline = &*get_mut_arcmutex!(self.pipeline);
        let tokenizer = pipeline.tokenizer();
        let tokenizer = match tokenizer {
            Some(tokenizer) => tokenizer,
            None => {
                request
                    .response
                    .send(Err(anyhow::Error::msg(
                        "Pipeline does not include a tokenizer.",
                    )))
                    .await
                    .expect("Expected receiver.");
                return;
            }
        };
        let txt = tokenizer.decode(&request.tokens, request.skip_special_tokens);
        let txt = match txt {
            Ok(txt) => txt,
            Err(e) => {
                request
                    .response
                    .send(Err(anyhow::Error::msg(e)))
                    .await
                    .expect("Expected receiver.");
                return;
            }
        };
        request
            .response
            .send(Ok(txt))
            .await
            .expect("Sender disconnected unexpectedly!");
    }
}