mistralrs_core/engine/add_request.rs

use crate::{
    pipeline::NormalCache,
    request::{DetokenizationRequest, NormalRequest, SearchContextSize, TokenizationRequest},
    search::{self, SearchFunctionParameters, SearchResult},
    sequence::SeqStepType,
    tools::{ToolCallingMatcher, ToolChoice},
    MessageContent, RequestMessage, Response, ResponseOk,
};
use candle_core::Tensor;
use either::Either;
use indexmap::IndexMap;
use std::{
    borrow::Cow,
    ops::Deref,
    sync::{atomic::Ordering, Arc},
    time::{SystemTime, UNIX_EPOCH},
};
use tokenizers::InputSequence;
use tracing::{info, warn};

use crate::{
    get_mut_arcmutex, handle_seq_error,
    request::Request,
    sampler::Sampler,
    sequence::{Sequence, SequenceGroup},
    StopTokens,
};

use super::{Engine, TERMINATE_ALL_NEXT_STEP};

impl Engine {
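    /// Entry point for a single `Request` received by the engine. Chat and
    /// completion requests are forwarded to `add_request`; tokenize/detokenize,
    /// re-ISQ, and termination requests are handled inline.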
    pub async fn handle_request(self: Arc<Self>, request: Request) {
        match request {
            Request::Normal(request) => {
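                // A chat request with `web_search_options` (non-streaming, and with a
                // BERT pipeline loaded for reranking) is served in two passes: the
                // first pass lets the model emit a search tool call, the search is
                // executed and its results are appended to the conversation, then a
                // second pass produces the final answer.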
                if matches!(
                    request.messages,
                    RequestMessage::Chat { .. } | RequestMessage::VisionChat { .. }
                ) && request.web_search_options.is_some()
                    && !request.is_streaming
                    && get_mut_arcmutex!(self.bert_pipeline).is_some()
                {
                    let Some(web_search_options) = request.web_search_options.clone() else {
                        unreachable!()
                    };
                    let mut first_request = request.clone();
                    // Actually add the search tool here
                    first_request
                        .tools
                        .get_or_insert_with(Vec::new)
                        .push(search::get_search_tool(&web_search_options).unwrap());

                    let mut second_request = first_request.clone();
                    first_request.web_search_options = None;
                    second_request.web_search_options = None;

                    let this = self.clone();
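                    // Run the two-pass flow on a background task so the engine can keep
                    // handling other requests while the first pass and the web search
                    // complete.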
                    let handle = tokio::spawn(async move {
                        let (new_sender, mut first_receiver) = tokio::sync::mpsc::channel(1);
                        second_request.response = new_sender;
                        std::mem::swap(&mut first_request.response, &mut second_request.response);

                        this.add_request(first_request).await;
                        let ResponseOk::Done(done) =
                            first_receiver.recv().await.unwrap().as_result().unwrap()
                        else {
                            unreachable!()
                        };

                        let tool_calls = match &done.choices[0].message.tool_calls {
                            Some(tool_calls)
                                if tool_calls.len() == 1
                                    && tool_calls[0].function.name == search::SEARCH_TOOL_NAME =>
                            {
                                &tool_calls[0]
                            }
                            None => {
                                second_request
                                    .response
                                    .send(Response::Done(done))
                                    .await
                                    .unwrap();
                                return;
                            }
                            Some(_) => {
                                second_request
                                    .response
                                    .send(Response::Done(done))
                                    .await
                                    .unwrap();
                                return;
                            }
                        };

                        let RequestMessage::Chat(messages) = &mut second_request.messages else {
                            unreachable!()
                        };

                        // Add assistant call message
                        {
                            let mut message: IndexMap<String, MessageContent> = IndexMap::new();
                            message
                                .insert("role".to_string(), Either::Left("assistant".to_string()));
                            message.insert(
                                "content".to_string(),
                                Either::Left(format!(
                                    "{{\"name\":\"{}\",\"arguments\":\"{}\"}}",
                                    tool_calls.function.name, tool_calls.function.arguments
                                )),
                            );
                            messages.push(message);
                        }
                        let tool_call_params: SearchFunctionParameters =
                            serde_json::from_str(&tool_calls.function.arguments).unwrap();

                        // Add tool response
                        {
                            let tokenizer = get_mut_arcmutex!(this.pipeline)
                                .tokenizer()
                                .expect("A tokenizer is expected for non-diffusion models.");
                            let mut results = search::run_search_tool(&tool_call_params)
                                .unwrap()
                                .into_iter()
                                .map(|result| {
                                    let len = {
                                        let inp = InputSequence::Raw(Cow::from(&result.content));
                                        tokenizer
                                            .encode_fast(inp, false)
                                            .map(|x| x.len())
                                            .unwrap_or(usize::MAX)
                                    };
                                    (result, len)
                                })
                                .collect::<Vec<_>>();
                            // Sort by increasing tokenized length; results whose tokenization failed (len == usize::MAX) sort to the end.
                            results.sort_by_key(|(_, len)| *len);

                            {
                                let device = get_mut_arcmutex!(this.pipeline).device();

                                let Some(bert_pipeline) =
                                    &mut *get_mut_arcmutex!(this.bert_pipeline)
                                else {
                                    unreachable!()
                                };

                                let decreasing_indexes = search::rag::compute_most_similar(
                                    &device,
                                    &tool_call_params.query,
                                    results.iter().map(|(res, _)| res).collect::<Vec<_>>(),
                                    bert_pipeline,
                                )
                                .unwrap();

                                // Rerank the results
                                let mut results_old = Vec::new();
                                std::mem::swap(&mut results_old, &mut results);
                                for &index in &decreasing_indexes {
                                    let mut current_result: (SearchResult, usize) =
                                        Default::default();
                                    std::mem::swap(&mut current_result, &mut results_old[index]);

                                    results.push(current_result);
                                }
                            }

                            // Manage context size by # of tokens. Apply default here.
                            let max_results_budget_toks =
                                match web_search_options.search_context_size.unwrap_or_default() {
                                    SearchContextSize::High => 10000_usize,
                                    SearchContextSize::Medium => 7500_usize,
                                    SearchContextSize::Low => 3000_usize,
                                };
                            let mut used_results = Vec::new();
                            let mut used_len = 0;
                            for (item, len) in results {
                                if used_len + len >= max_results_budget_toks {
                                    break;
                                }
                                // So the info! below gets the correct value
                                used_len += len;
                                used_results.push(item);
                            }

                            let tool_result = serde_json::to_string(&used_results)
                                .unwrap()
                                .replace("\\n", "\n")
                                .replace("\\\"", "\"")
                                .replace("\\\\", "\\");
                            info!("Web search executed, using {used_len} tokens of {} search results.", used_results.len());

                            let mut message: IndexMap<String, MessageContent> = IndexMap::new();
                            message.insert("role".to_string(), Either::Left("tool".to_string()));
                            message.insert(
                                "content".to_string(),
                                Either::Left(format!("{{\"output\": \"{tool_result}\"}}")),
                            );
                            messages.push(message);
                        }

                        this.add_request(second_request).await;
                    });
                    get_mut_arcmutex!(self.handles).push(handle);
                } else {
                    self.add_request(request).await
                }
            }
            Request::ReIsq(level) => {
                if let Err(e) = get_mut_arcmutex!(self.pipeline).re_isq_model(level) {
                    warn!("ISQ requantization failed: {e:?}");
                }
            }
            Request::Tokenize(req) => self.tokenize_text(req).await,
            Request::Detokenize(req) => self.detokenize_text(req).await,
            Request::Terminate => (),
            Request::TerminateAllSeqsNextStep => {
                TERMINATE_ALL_NEXT_STEP.store(true, Ordering::SeqCst)
            }
        }
    }

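    /// Validate a `NormalRequest`, tokenize its prompt, resolve sampling and stop
    /// criteria, and schedule one `Sequence` per requested choice.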
    async fn add_request(&self, request: NormalRequest) {
        let is_chat = matches!(
            request.messages,
            RequestMessage::Chat(_) | RequestMessage::VisionChat { .. }
        );
        let echo_prompt = matches!(
            request.messages,
            RequestMessage::Completion {
                echo_prompt: true,
                ..
            }
        );

        let best_of = match request.messages {
            RequestMessage::Completion { best_of, .. } => best_of,
            RequestMessage::Chat(_)
            | RequestMessage::CompletionTokens(_)
            | RequestMessage::VisionChat { .. }
            | RequestMessage::ImageGeneration { .. } => None,
        };
        if is_chat
            && !get_mut_arcmutex!(self.pipeline)
                .get_chat_template()
                .as_ref()
                .is_some_and(|ch_t| ch_t.has_chat_template())
        {
            request
                .response
                .send(Response::ValidationError(
                    "Received messages for a model which does not have a chat template. Either use a different model or pass a single string as the prompt".into(),
                )).await.expect("Expected receiver.");
            return;
        }

        let images = match request.messages {
            RequestMessage::VisionChat {
                ref images,
                messages: _,
            } => Some(images.clone()),
            _ => None,
        };

        let matcher = Arc::new(handle_seq_error!(
            ToolCallingMatcher::new(request.tool_choice.unwrap_or(ToolChoice::Auto)),
            request.response
        ));

        let image_generation_format = match &request.messages {
            RequestMessage::ImageGeneration { format, .. } => Some(*format),
            _ => None,
        };

        let seq_step_type = match &request.messages {
            RequestMessage::ImageGeneration { .. } => SeqStepType::OneShot,
            _ => SeqStepType::PromptAndDecode,
        };

        let diffusion_params = match &request.messages {
            RequestMessage::ImageGeneration {
                generation_params, ..
            } => Some(generation_params.clone()),
            _ => None,
        };

        let (mut prompt_tokens, prompt_text) = match request.messages {
            RequestMessage::Chat(messages)
            | RequestMessage::VisionChat {
                images: _,
                messages,
            } => {
                let pipeline = &*get_mut_arcmutex!(self.pipeline);
                let tools = request.tools.unwrap_or_default();
                let template = pipeline
                    .get_processor()
                    .process(pipeline, messages, true, true, tools);
                handle_seq_error!(template, request.response)
            }
            RequestMessage::Completion { text, .. } => {
                let Some(tokenizer) = &get_mut_arcmutex!(self.pipeline).tokenizer() else {
                    request
                        .response
                        .send(Response::ValidationError(
                            "Completion requests require the pipeline to have a tokenizer".into(),
                        ))
                        .await
                        .expect("Expected receiver.");
                    return;
                };
                let prompt = tokenizer
                    .encode_fast(text.clone(), true)
                    .map_err(anyhow::Error::msg);
                (
                    handle_seq_error!(prompt, request.response)
                        .get_ids()
                        .to_vec(),
                    text,
                )
            }
            RequestMessage::ImageGeneration { prompt, .. } => (vec![u32::MAX], prompt),
            RequestMessage::CompletionTokens(it) => {
                let Some(tokenizer) = &get_mut_arcmutex!(self.pipeline).tokenizer() else {
                    request
                        .response
                        .send(Response::ValidationError(
                            "Completion requests w/ raw tokens require the pipeline to have a tokenizer".into(),
                        ))
                        .await
                        .expect("Expected receiver.");
                    return;
                };
                let prompt = tokenizer
                    .decode(&it, false)
                    .map_err(|e| anyhow::Error::msg(e.to_string()));
                (it, handle_seq_error!(prompt, request.response))
            }
        };
        if prompt_tokens.is_empty() {
            request
                .response
                .send(Response::ValidationError(
                    "Received an empty prompt.".into(),
                ))
                .await
                .expect("Expected receiver.");
            return;
        }

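        // If the prompt exceeds the model's maximum sequence length, either reject it
        // or (when `truncate_sequence` is enabled) drop tokens from the start of the
        // prompt, leaving room for generation.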
        if prompt_tokens.len() > get_mut_arcmutex!(self.pipeline).get_metadata().max_seq_len {
            if !self.truncate_sequence {
                request
                    .response
                    .send(Response::ValidationError(
                        format!("Prompt sequence length is greater than {}, perhaps consider using `truncate_sequence`?", get_mut_arcmutex!(self.pipeline).get_metadata().max_seq_len).into(),
                    )).await.expect("Expected receiver.");
                return;
            } else {
                let prompt_len = prompt_tokens.len();
                let max_len = get_mut_arcmutex!(self.pipeline).get_metadata().max_seq_len;
                let currently_over = prompt_len - max_len;
                let sampling_max = if let Some(sampling_max) = request.sampling_params.max_len {
                    if currently_over + sampling_max >= prompt_len {
                        10
                    } else {
                        sampling_max
                    }
                } else {
                    10
                };
                prompt_tokens = prompt_tokens[(currently_over + sampling_max)..].to_vec();
                warn!("Prompt for request {} was {} tokens over the model maximum length. The first {} tokens were truncated to make space for generation.", request.id, currently_over, prompt_len - prompt_tokens.len());
            }
        }
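        // Look up the prefix cacher for a cached prefix matching this prompt (and
        // whether images are present); a hit is applied to each new sequence below.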
        let prefill_cache = handle_seq_error!(
            get_mut_arcmutex!(self.prefix_cacher).search_for_matching_cache(
                &prompt_tokens,
                images.as_ref().is_some_and(|x| !x.is_empty())
            ),
            request.response
        );

        let topk = request
            .sampling_params
            .top_k
            .map(|x| x as i64)
            .unwrap_or(-1);
        let topp = request.sampling_params.top_p.unwrap_or(1.0);
        let minp = request.sampling_params.min_p.unwrap_or(0.0);
        let num_hidden_layers = get_mut_arcmutex!(self.pipeline)
            .get_metadata()
            .num_hidden_layers;

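        // Normalize stop criteria: explicit stop token ids are validated against the
        // token trie (a token that is a prefix of other tokens cannot be a stop token),
        // while stop strings are kept as strings unless they tokenize to a single,
        // non-prefix token.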
        let (stop_toks, stop_strings) = match request.sampling_params.stop_toks {
            None => (vec![], vec![]),
            Some(StopTokens::Ids(ref i)) => {
                let tok_env = {
                    let pipeline = get_mut_arcmutex!(self.pipeline);
                    pipeline.get_metadata().tok_env.clone()
                };
                for id in i {
                    // We can't use ` ` (space) as a stop token because other tokens like ` moon` start with a space.
                    if let Some(tok_env) = tok_env.as_ref() {
                        let tok_trie = tok_env.tok_trie();
                        if tok_trie.has_extensions(tok_trie.token(*id)) {
                            request
                                .response
                                .send(Response::ValidationError(
                                    format!("Stop token {:?} is also a prefix of other tokens and cannot be used as a stop token.", tok_trie.token_str(*id)).into(),
                                ))
                                .await
                                .expect("Expected receiver.");
                            return;
                        }
                    }
                }

                (i.clone(), vec![])
            }
            Some(StopTokens::Seqs(ref s)) => {
                let mut stop_toks = Vec::new();
                let mut stop_strings: Vec<String> = Vec::new();

                let (tok_env, tokenizer) = {
                    let pipeline = get_mut_arcmutex!(self.pipeline);
                    let tok_env = pipeline.get_metadata().tok_env.clone();
                    let tokenizer = pipeline.tokenizer();
                    (tok_env, tokenizer)
                };

                for stop_txt in s {
                    let Some(tokenizer) = &tokenizer else {
                        request
                            .response
                            .send(Response::ValidationError(
                                "Completion requests require the pipeline to have a tokenizer"
                                    .into(),
                            ))
                            .await
                            .expect("Expected receiver.");
                        return;
                    };
                    let encoded = tokenizer.encode_fast(stop_txt.to_string(), true);
                    let toks = handle_seq_error!(encoded, request.response)
                        .get_ids()
                        .to_vec();

                    if toks.len() == 1 {
                        if tok_env.as_ref().is_some_and(|tok_env| {
                            let tok_trie = tok_env.tok_trie();
                            tok_trie.has_extensions(tok_trie.token(toks[0]))
                        }) {
                            stop_strings.push(stop_txt.clone());
                        } else {
                            stop_toks.push(toks[0]);
                        }
                    } else {
                        stop_strings.push(stop_txt.clone());
                    }
                }

                (stop_toks, stop_strings)
            }
        };

        let group = Arc::new(tokio::sync::Mutex::new(SequenceGroup::new(
            request.sampling_params.n_choices,
            request.is_streaming,
            is_chat,
            best_of,
        )));

        let tokenizer = get_mut_arcmutex!(self.pipeline).tokenizer();

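        // Build the sampler from the request's sampling parameters, applying the
        // defaults resolved above (temperature 1.0, top-k disabled, top-p 1.0, min-p 0.0).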
        let sampler = Sampler::new(
            Some(request.sampling_params.temperature.unwrap_or(1.0)),
            request.sampling_params.top_n_logprobs,
            tokenizer,
            request.sampling_params.frequency_penalty,
            request.sampling_params.presence_penalty,
            request.sampling_params.dry_params,
            topk,
            topp,
            minp,
            request.logits_processors.unwrap_or_default(),
        );
        let sampler = handle_seq_error!(sampler, request.response);

        if request.sampling_params.n_choices == 0 {
            request
                .response
                .send(Response::ValidationError(
                    "Number of choices must be greater than 0.".into(),
                ))
                .await
                .expect("Expected receiver.");
            return;
        }

        // Add sequences
        for response_index in 0..request.sampling_params.n_choices {
            let trie = get_mut_arcmutex!(self.pipeline)
                .get_metadata()
                .tok_env
                .clone();
            let recognizer = match Self::build_sequence_recognizer(&trie, &request.constraint) {
                Ok(recognizer) => recognizer,
                Err(err) => {
                    request
                        .response
                        .send(Response::ValidationError(
                            format!("Invalid grammar. {}", err).into(),
                        ))
                        .await
                        .expect("Expected receiver.");
                    return;
                }
            };

            let block_size = get_mut_arcmutex!(self.pipeline)
                .get_metadata()
                .cache_config
                .clone()
                .map(|conf| conf.block_size);

            let eos_toks = get_mut_arcmutex!(self.pipeline)
                .get_metadata()
                .eos_tok
                .clone();

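            // For pipelines that use a preallocated (NormalCache) KV cache, allocate
            // zeroed per-sequence K/V tensors, rounding the prompt length up to a
            // multiple of NormalCache::CACHE_GROW_SIZE; allocation failures are
            // reported back as an InternalError.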
            let seq_preallocated_cache = if get_mut_arcmutex!(self.pipeline).do_preallocated_cache()
            {
                let metadata = get_mut_arcmutex!(self.pipeline).get_metadata();
                let model_metadata = metadata
                    .model_metadata
                    .as_ref()
                    .expect("If a model has a NormalCache it must have a model metadata");
                let n_tokens = prompt_tokens.len();
                let required_blocks = n_tokens.div_ceil(NormalCache::CACHE_GROW_SIZE);
                let max_seq_len = required_blocks * NormalCache::CACHE_GROW_SIZE;
                let k_shape = (
                    1usize,
                    model_metadata.num_kv_heads(),
                    max_seq_len,
                    model_metadata.k_head_dim(),
                );
                let v_shape = (
                    1usize,
                    model_metadata.num_kv_heads(),
                    max_seq_len,
                    model_metadata.v_head_dim(),
                );
                let dtype = get_mut_arcmutex!(self.pipeline)
                    .get_metadata()
                    .activation_dtype;

                let k_seq_cache = {
                    let k_seq_cache =
                        Tensor::zeros(k_shape, dtype, &get_mut_arcmutex!(self.pipeline).device());
                    match k_seq_cache {
                        Ok(x) => x,
                        Err(_) => {
                            request
                                .response
                                .send(Response::InternalError(
                                    "Failed to allocate preallocated KV cache."
                                        .to_string()
                                        .into(),
                                ))
                                .await
                                .expect("Expected receiver.");
                            return;
                        }
                    }
                };
                let v_seq_cache = if k_shape == v_shape {
                    k_seq_cache.clone()
                } else {
                    let v_seq_cache =
                        Tensor::zeros(v_shape, dtype, &get_mut_arcmutex!(self.pipeline).device());
                    match v_seq_cache {
                        Ok(x) => x,
                        Err(_) => {
                            request
                                .response
                                .send(Response::InternalError(
                                    "Failed to allocate preallocated KV cache."
                                        .to_string()
                                        .into(),
                                ))
                                .await
                                .expect("Expected receiver.");
                            return;
                        }
                    }
                };
                Some((k_seq_cache, v_seq_cache))
            } else {
                None
            };

            let now = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("Time travel has occurred!");
            let seq = Sequence::new_waiting(
                prompt_tokens.clone(),
                prompt_text.clone(),
                *get_mut_arcmutex!(self.id).deref(),
                now.as_millis(),
                num_hidden_layers,
                request.response.clone(),
                sampler.clone(),
                stop_toks.clone(),
                stop_strings.clone(),
                request.sampling_params.max_len,
                request.return_logprobs,
                get_mut_arcmutex!(self.pipeline).get_metadata().is_xlora,
                group.clone(),
                response_index,
                now.as_secs(),
                recognizer,
                request.suffix.clone(),
                if echo_prompt {
                    Some(prompt_text.clone())
                } else {
                    None
                },
                images.clone(),
                block_size,
                Some(matcher.clone()),
                image_generation_format,
                seq_step_type,
                diffusion_params.clone(),
                seq_preallocated_cache,
                request.return_raw_logits,
                eos_toks,
            );
            self.logger.add_new_sequence();
            let seq = if let Some(prefill_cache) = prefill_cache.clone() {
                self.logger.add_prefix_cache_hit();

                seq.prefill_v2(
                    prefill_cache.normal,
                    prefill_cache.toks,
                    prefill_cache.offset,
                )
            } else {
                seq
            };
            *get_mut_arcmutex!(self.id) += 1;
            get_mut_arcmutex!(self.scheduler).add_seq(seq);
        }
    }

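    /// Tokenize either a list of chat messages (via the pipeline's processor and chat
    /// template) or a raw string, sending the resulting token ids back on the
    /// request's channel.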
    async fn tokenize_text(&self, request: TokenizationRequest) {
        match request.text {
            Either::Left(messages) => {
                let pipeline = &*get_mut_arcmutex!(self.pipeline);
                let tools = request.tools.unwrap_or_default();
                let template = pipeline.get_processor().process(
                    pipeline,
                    messages,
                    request.add_generation_prompt,
                    request.add_special_tokens,
                    tools,
                );
                let toks = match template {
                    Ok((toks, _)) => toks,
                    Err(e) => {
                        request
                            .response
                            .send(Err(e))
                            .await
                            .expect("Expected receiver.");
                        return;
                    }
                };
                request
                    .response
                    .send(Ok(toks))
                    .await
                    .expect("Sender disconnected unexpectedly!");
            }
            Either::Right(text) => {
                let pipeline = &*get_mut_arcmutex!(self.pipeline);
                let tokenizer = pipeline.tokenizer();
                let tokenizer = match tokenizer {
                    Some(tokenizer) => tokenizer,
                    None => {
                        request
                            .response
                            .send(Err(anyhow::Error::msg(
                                "Pipeline does not include a tokenizer.",
                            )))
                            .await
                            .expect("Expected receiver.");
                        return;
                    }
                };
                let toks = tokenizer.encode_fast(text, request.add_special_tokens);
                let toks = match toks {
                    Ok(tokenizer) => tokenizer,
                    Err(e) => {
                        request
                            .response
                            .send(Err(anyhow::Error::msg(e)))
                            .await
                            .expect("Expected receiver.");
                        return;
                    }
                };
                request
                    .response
                    .send(Ok(toks.get_ids().to_vec()))
                    .await
                    .expect("Sender disconnected unexpectedly!");
            }
        };
    }

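    /// Decode a list of token ids back into text with the pipeline's tokenizer,
    /// sending the result (or an error) back on the request's channel.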
    async fn detokenize_text(&self, request: DetokenizationRequest) {
        let pipeline = &*get_mut_arcmutex!(self.pipeline);
        let tokenizer = pipeline.tokenizer();
        let tokenizer = match tokenizer {
            Some(tokenizer) => tokenizer,
            None => {
                request
                    .response
                    .send(Err(anyhow::Error::msg(
                        "Pipeline does not include a tokenizer.",
                    )))
                    .await
                    .expect("Expected receiver.");
                return;
            }
        };
        let txt = tokenizer.decode(&request.tokens, request.skip_special_tokens);
        let txt = match txt {
            Ok(tokenizer) => tokenizer,
            Err(e) => {
                request
                    .response
                    .send(Err(anyhow::Error::msg(e)))
                    .await
                    .expect("Expected receiver.");
                return;
            }
        };
        request
            .response
            .send(Ok(txt))
            .await
            .expect("Sender disconnected unexpectedly!");
    }
}