// mistralrs_core/pipeline/sampling.rs

use std::sync::Arc;

use candle_core::{DType, Result, Tensor};
use rand_isaac::Isaac64Rng;

use crate::{
    prefix_cacher::PrefixCacheManagerV2,
    sampler::Logprobs,
    sequence::{Sequence, SequenceRecognizer, SequenceState, StopReason},
    tools::parse_text_tools,
};

use super::Pipeline;

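// Replace the SentencePiece word-boundary marker `▁` with a plain space. The `Option`
// arm applies the same replacement inside an `Option`, leaving `None` untouched.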
macro_rules! fixup_sentencepiece {
    ($txt:expr) => {
        $txt.to_string().replace("▁", " ")
    };
    (Option $txt:expr) => {
        match &$txt {
            Some(txt) => Some(fixup_sentencepiece!(txt)),
            None => None,
        }
    };
}

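/// Record the freshly sampled token on `seq`, then either stream a response chunk
/// (for streaming requests) or, if the sequence just finished, build and send the
/// final chat/completion response, updating the prefix cache along the way.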
pub(crate) async fn finish_or_add_toks_to_seq(
    this: &dyn Pipeline,
    prefix_cacher: &mut PrefixCacheManagerV2,
    seq: &mut Sequence,
    logprobs: Logprobs,
    eos_tok: Option<&[u32]>,
    use_prefix_cacher: bool,
) -> Result<()> {
    let mut is_done = seq.is_done(logprobs.token, eos_tok, this.get_metadata().max_seq_len);
    seq.add_token(
        logprobs.clone(),
        this.get_metadata()
            .tok_env()
            .ok_or(candle_core::Error::Msg(
                "`finish_or_add_toks_to_seq` requires the pipeline to have a token trie"
                    .to_string(),
            ))?
            .tok_trie()
            .decode(&[logprobs.token]),
        &is_done,
    );

    // If we can have a tool and we got a tool, stop the sequence early.
    // Doesn't conflict with the logic below because it does the same thing anyway.
    if let Some(ref t) = seq.tools {
        if let Ok(Some(ref d)) = seq.peek_delta() {
            let (_tool_use_still_possible, tool_use_is_done) =
                t.prefix_could_be_tool(this, d.as_str())?;

            if tool_use_is_done
                && matches!(
                    parse_text_tools(this, d, seq.tools.clone()),
                    Ok((None, _tools))
                )
            {
                seq.set_state(SequenceState::Done(StopReason::Eos));
                is_done = Some(StopReason::Eos);
            }
        }
    };

    // Handle streaming requests
    if seq.get_mut_group().is_streaming {
        let mut tool_use_still_possible = false;
        let mut tool_use_is_done = false;
        if let Some(ref t) = seq.tools {
            if let Ok(Some(ref d)) = seq.peek_delta() {
                (tool_use_still_possible, tool_use_is_done) =
                    t.prefix_could_be_tool(this, d.as_str())?;
            }
        };

        // let send = seq.get_toks().len() % 2 == 0 || is_done.is_some();
        let send = true;
        if !tool_use_still_possible || tool_use_is_done {
            if send {
                if let Some(delta) = crate::handle_seq_error_ok!(seq.get_delta(), seq.responder()) {
                    if seq.get_mut_group().is_chat {
                        let (text_new, tool_calls) =
                            parse_text_tools(this, delta.as_str(), seq.tools.clone())
                                .map_err(candle_core::Error::msg)?;
                        if !tool_calls.is_empty() {
                            is_done = Some(StopReason::ToolCalls);
                        }

                        seq.add_streaming_chunk_choice_to_group(crate::ChunkChoice {
                            delta: crate::Delta {
                                content: fixup_sentencepiece!(
                                    Option text_new.map(ToString::to_string)
                                ),
                                role: "assistant".to_string(),
                                tool_calls: Some(tool_calls).filter(|v| !v.is_empty()),
                            },
                            index: seq.get_response_index(),
                            finish_reason: is_done.map(|x| x.to_string()),
                            logprobs: if seq.return_logprobs() {
                                Some(crate::ResponseLogprob {
                                    token: delta,
                                    bytes: logprobs.bytes.clone().map(|b| b.into_bytes()),
                                    logprob: logprobs.logprob,
                                    top_logprobs: logprobs.top_logprobs.unwrap().clone(),
                                })
                            } else {
                                None
                            },
                        });
                    } else {
                        seq.add_streaming_completion_chunk_choice_to_group(
                            crate::CompletionChunkChoice {
                                text: fixup_sentencepiece!(delta),
                                index: seq.get_response_index(),
                                finish_reason: is_done.map(|x| x.to_string()),
                                logprobs: if seq.return_logprobs() {
                                    Some(crate::ResponseLogprob {
                                        token: delta,
                                        bytes: logprobs.bytes.clone().map(|b| b.into_bytes()),
                                        logprob: logprobs.logprob,
                                        top_logprobs: logprobs.top_logprobs.unwrap().clone(),
                                    })
                                } else {
                                    None
                                },
                            },
                        );
                    }
                }
            }

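            // The sequence just finished: record it in the prefix cache (if enabled) and
            // mark it done so the scheduler stops running it.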
            if let Some(reason) = is_done {
                if use_prefix_cacher {
                    prefix_cacher.add_sequence(seq);
                    prefix_cacher.evict_caches()?;
                }
                seq.set_state(crate::sequence::SequenceState::Done(reason));
                this.reset_non_granular_state();
            }

            // Send usage on final chunk.
            let usage_opt = if is_done.is_some() {
                let usage = seq.get_mut_group().get_usage();
                seq.get_mut_group().total_prompt_toks = 0;
                seq.get_mut_group().total_toks = 0;
                Some(usage)
            } else {
                None
            };

            if seq
                .get_mut_group()
                .maybe_send_streaming_response(seq, this.name().clone(), usage_opt)
                .await
                .is_err()
            {
                // If we can't send the response, cancel the sequence
                seq.set_state(crate::sequence::SequenceState::Done(
                    crate::sequence::StopReason::Canceled,
                ));
                this.reset_non_granular_state();
            }
        }
    } else if let Some(mut reason) = is_done {
        /*
        ***********************
        Finish the sequence now
        ***********************
        */
        {
            seq.set_state(crate::sequence::SequenceState::Done(reason));
            let (tokenizer, pipeline_name) = {
                let pipeline_name = this.name();
                let tokenizer = this.tokenizer();
                (tokenizer, pipeline_name)
            };

            let logprobs = if seq.return_logprobs() {
                let mut logprobs = Vec::new();
                for logprob in seq.logprobs() {
                    let resp_logprob = crate::ResponseLogprob {
                        token: crate::handle_seq_error_ok!(
                            tokenizer
                                .as_ref()
                                .ok_or(candle_core::Error::Msg(
                                    "`finish_or_add_toks_to_seq` requires the pipeline to have a tokenizer"
                                        .to_string(),
                                ))?
                                .decode(&[logprob.token], false),
                            seq.responder()
                        ),
                        bytes: logprob.bytes.clone().map(|b| b.into_bytes()),
                        logprob: logprob.logprob,
                        top_logprobs: logprob.top_logprobs.clone().unwrap(),
                    };
                    logprobs.push(resp_logprob);
                }
                Some(logprobs)
            } else {
                None
            };

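            // Recover the final text from the accumulated completion bytes; for a stop-string
            // match, truncate at the position where the stop string began.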
            let text = match reason {
                crate::sequence::StopReason::Length(_)
                | crate::sequence::StopReason::ModelLength(_)
                | crate::sequence::StopReason::Eos
                | crate::sequence::StopReason::StopTok(_)
                | crate::sequence::StopReason::Canceled
                | crate::sequence::StopReason::ToolCalls => {
                    String::from_utf8_lossy(seq.completion_bytes())
                        .trim_start()
                        .to_string()
                }
                crate::sequence::StopReason::StopString {
                    completion_bytes_pos,
                    ..
                } => {
                    let txt = String::from_utf8_lossy(seq.completion_bytes());
                    txt[..completion_bytes_pos].trim_start().to_string()
                }
                crate::sequence::StopReason::GeneratedImage
                | crate::sequence::StopReason::GeneratedSpeech => {
                    candle_core::bail!("Stop reason was `GeneratedImage` or `GeneratedSpeech`.")
                }
            };

            if seq.get_mut_group().is_chat {
                let (text_new, tool_calls) =
                    parse_text_tools(this, text.as_str(), seq.tools.clone())
                        .map_err(candle_core::Error::msg)?;
                if !tool_calls.is_empty() {
                    reason = StopReason::ToolCalls;
                }

                let choice = crate::Choice {
                    finish_reason: fixup_sentencepiece!(reason),
                    index: seq.get_response_index(),
                    message: crate::ResponseMessage {
                        content: text_new.map(ToString::to_string),
                        role: "assistant".to_string(),
                        tool_calls: Some(tool_calls).filter(|v| !v.is_empty()),
                    },
                    logprobs: logprobs.map(|l| crate::Logprobs { content: Some(l) }),
                };
                seq.add_choice_to_group(choice);
            } else {
                let choice = crate::CompletionChoice {
                    finish_reason: fixup_sentencepiece!(reason),
                    index: seq.get_response_index(),
                    text,
                    logprobs: None,
                };
                seq.add_completion_choice_to_group(choice);
            }

            if use_prefix_cacher {
                prefix_cacher.add_sequence(seq);
                prefix_cacher.evict_caches()?;
            }

            let group = seq.get_mut_group();
            if group.is_chat {
                group
                    .maybe_send_chat_done_response(
                        crate::ChatCompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name,
                            system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
                            object: "chat.completion".to_string(),
                            usage: group.get_usage(),
                        },
                        seq.responder(),
                    )
                    .await
                    .map_err(candle_core::Error::msg)?;
            } else {
                group
                    .maybe_send_completion_done_response(
                        crate::CompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_completion_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name,
                            system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
                            object: "text_completion".to_string(),
                            usage: group.get_usage(),
                        },
                        seq.responder(),
                    )
                    .await
                    .map_err(candle_core::Error::msg)?;
            }
        }
        this.reset_non_granular_state();
    }

    Ok(())
}

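/// Sample the next token for every sequence in the batch from its logits, then append
/// each sampled token (or finish the sequence) via [`finish_or_add_toks_to_seq`].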
pub async fn sample_and_add_toks(
    this: &dyn Pipeline,
    seqs: &mut [&mut Sequence],
    logits_seq: Vec<Tensor>,
    prefix_cacher: &mut PrefixCacheManagerV2,
    disable_eos_stop: bool,
    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
) -> Result<()> {
    let seqs_len = seqs.len();
    debug_assert_eq!(logits_seq.len(), seqs_len);

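    // Only offload sampling to the rayon-backed async pool when there is more than
    // one sequence to sample.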
    let use_async_pool = seqs_len > 1;

    let sampling_futures: Vec<_> = std::iter::zip(logits_seq, seqs.iter_mut())
        .map(|(logits_per_seq, seq)| {
            let return_logprobs = seq.return_logprobs();
            sample_sequence(
                logits_per_seq,
                seq,
                return_logprobs,
                rng.clone(),
                use_async_pool,
                false,
                use_async_pool,
            )
        })
        .collect();
    let sampled_vec = futures::future::join_all(sampling_futures).await;

    for (sampled, seq) in std::iter::zip(sampled_vec, seqs.iter_mut()) {
        let next_token = crate::handle_seq_error_stateaware_ok!(sampled, seq);

        let metadata = this.get_metadata();
        let eos_tok = if disable_eos_stop {
            None
        } else {
            Some(&metadata.eos_tok[..])
        };

        finish_or_add_toks_to_seq(this, prefix_cacher, seq, next_token, eos_tok, true).await?;
    }

    Ok(())
}

/// Async sample optionally adding to trie.
#[allow(clippy::too_many_arguments)]
pub async fn sample_sequence(
    logits: Tensor,
    seq: &mut Sequence,
    return_logprobs: bool,
    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    use_async_pool: bool,
    sample_speculative: bool,
    multiple_sequences: bool,
) -> Result<Logprobs> {
    let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;

    let sampler = seq.sampler();
    let ctx_clone = seq.get_toks().to_vec();
    let rng_clone = rng.clone();
    let logits_clone = logits.clone();
    let first_logprobs_response = if use_async_pool {
        tokio_rayon::spawn(move || {
            sampler.sample(
                logits_clone,
                &ctx_clone,
                return_logprobs,
                rng_clone,
                sample_speculative,
                multiple_sequences,
            )
        })
        .await?
    } else {
        sampler.sample(
            logits_clone,
            &ctx_clone,
            return_logprobs,
            rng_clone,
            sample_speculative,
            multiple_sequences,
        )?
    };

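    // If a llguidance constraint is active and the sampled token is not valid under it,
    // build an additive bias (-inf for disallowed token ids, 0.0 for allowed ones) so we
    // can resample below.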
    let bias_if_not_allowed = match &mut seq.recognizer {
        SequenceRecognizer::Llguidance(ref mut llg) => {
            if !llg.is_stopped()
                && llg
                    .validate_tokens(&[first_logprobs_response.token])
                    .unwrap_or(0)
                    == 1
            {
                None
            } else {
                let mask = llg.compute_mask_or_eos().map_err(candle_core::Error::msg)?;
                if mask.is_allowed(first_logprobs_response.token) {
                    // shouldn't really happen, except for EOS
                    None
                } else {
                    let mut acc = vec![-f32::INFINITY; logits.shape().dims1().unwrap()];
                    mask.iter_set_entries(|idx| {
                        if idx < acc.len() {
                            acc[idx] = 0.0;
                        }
                    });

                    Some(acc)
                }
            }
        }
        SequenceRecognizer::None => None,
    };
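    // Resample with the bias added to the logits: disallowed tokens become -inf while
    // allowed tokens keep their original values.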
    let second_logprobs_response = match bias_if_not_allowed {
        Some(acc) => {
            let new_logits = (&logits + Tensor::from_slice(&acc, acc.len(), logits.device())?)?;

            let ctx_clone = seq.get_toks().to_vec();
            let rng_clone = rng.clone();
            let sampler = seq.sampler();
            if use_async_pool {
                tokio_rayon::spawn(move || {
                    sampler.sample(
                        new_logits,
                        &ctx_clone,
                        return_logprobs,
                        rng_clone,
                        sample_speculative,
                        multiple_sequences,
                    )
                })
                .await?
            } else {
                sampler.sample(
                    new_logits,
                    &ctx_clone,
                    return_logprobs,
                    rng_clone,
                    sample_speculative,
                    multiple_sequences,
                )?
            }
        }
        None => first_logprobs_response,
    };

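    // Advance the constraint state with the token that was finally accepted.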
    match seq.recognizer {
        SequenceRecognizer::Llguidance(ref mut llg) => {
            if !llg.is_stopped() {
                llg.consume_token(second_logprobs_response.token)
                    .map_err(candle_core::Error::msg)?;
            }
        }
        SequenceRecognizer::None => {}
    }

    Ok(second_logprobs_response)
}

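/// A single sampled token produced during speculative decoding, wrapping its [`Logprobs`].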
#[derive(Clone)]
pub struct SpeculativeSample {
    pub sample: Logprobs,
}

/// Async sample without modifying sequence (except for the constraint).
pub async fn sample_target_sequence_speculative(
    logits: Tensor,
    seq: &mut Sequence,
    return_logprobs: bool,
    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    draft_samples: &[SpeculativeSample],
) -> Result<Vec<SpeculativeSample>> {
    let n_toks = draft_samples.len();

    // First, roll back the llguidance constraint by the number of draft tokens.
    match &mut seq.recognizer {
        SequenceRecognizer::Llguidance(ref mut llg) => {
            llg.rollback(n_toks).map_err(candle_core::Error::msg)?;
        }
        SequenceRecognizer::None => {}
    }

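    // Sample one target token per draft position; stop at the first position where the
    // target sample disagrees with the draft token.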
    let mut sampled = Vec::new();
    for (chunk, draft) in logits
        .chunk(n_toks, 1)?
        .into_iter()
        .zip(draft_samples.iter())
    {
        let sample = sample_sequence(
            chunk,
            seq,
            return_logprobs,
            rng.clone(),
            true, // TODO(EricLBuehler): does this hurt perf?
            true,
            false,
        )
        .await?;
        let sampled_token = sample.token;
        sampled.push(SpeculativeSample { sample });
        if sampled_token != draft.sample.token {
            break;
        }
    }
    Ok(sampled)
}