mistralrs_core/pipeline/sampling.rs

use std::sync::Arc;

use candle_core::{DType, Device, Result, Tensor};
use rand_isaac::Isaac64Rng;

use crate::{
    prefix_cacher::PrefixCacheManagerV2,
    sampler::Logprobs,
    sequence::{Sequence, SequenceRecognizer, SequenceState, StopReason},
    tools::parse_text_tools,
};

use super::Pipeline;

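// SentencePiece tokenizers mark word boundaries with "▁" (U+2581); swap it for
// a plain space before text reaches clients, e.g. "▁Hello▁world" -> " Hello world".
// The `Option` arm applies the same fixup through an `Option<String>`.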
macro_rules! fixup_sentencepiece {
    ($txt:expr) => {
        $txt.to_string().replace("▁", " ")
    };
    (Option $txt:expr) => {
        match &$txt {
            Some(txt) => Some(fixup_sentencepiece!(txt)),
            None => None,
        }
    };
}

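/// Record the freshly sampled token on `seq`, then either emit a streaming
/// chunk or, if the sequence just finished, assemble and send the final
/// chat/completion response. A completed tool call can force an early EOS in
/// both paths.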
pub(crate) async fn finish_or_add_toks_to_seq(
    this: &dyn Pipeline,
    prefix_cacher: &mut PrefixCacheManagerV2,
    seq: &mut Sequence,
    logprobs: Logprobs,
    eos_tok: Option<&[u32]>,
    use_prefix_cacher: bool,
) -> Result<()> {
    // Look up the pipeline metadata once and reuse it below.
    let meta = this.get_metadata();
    let max_len = meta.max_seq_len;
    let mut is_done = seq.is_done(logprobs.token, eos_tok, max_len);
    seq.add_token(
        logprobs.clone(),
        meta.tok_env()
            .ok_or(candle_core::Error::Msg(
                "`finish_or_add_toks_to_seq` requires the pipeline to have a token trie"
                    .to_string(),
            ))?
            .tok_trie()
            .decode(&[logprobs.token]),
        &is_done,
    );

    // If tools are enabled and the accumulated delta already parses as a
    // complete tool call, stop the sequence early. This mirrors what the
    // streaming/finish logic below would do anyway, so the two paths agree.
    if let Some(ref t) = seq.tools {
        if let Ok(Some(ref d)) = seq.peek_delta() {
            let (_tool_use_still_possible, tool_use_is_done) =
                t.prefix_could_be_tool(this, d.as_str())?;

            if tool_use_is_done
                && matches!(
                    parse_text_tools(this, d, seq.tools.clone()),
                    Ok((None, _tools))
                )
            {
                seq.set_state(SequenceState::Done(StopReason::Eos));
                is_done = Some(StopReason::Eos);
            }
        }
    }

    // Handle streaming requests
    if seq.get_mut_group().is_streaming {
        let mut tool_use_still_possible = false;
        let mut tool_use_is_done = false;
        if let Some(ref t) = seq.tools {
            if let Ok(Some(ref d)) = seq.peek_delta() {
                (tool_use_still_possible, tool_use_is_done) =
                    t.prefix_could_be_tool(this, d.as_str())?;
            }
        }

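        // Emit a chunk only on every other token, or on the final one; this
        // halves the number of messages sent over the response channel.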
        let send = seq.get_toks().len() % 2 == 0 || is_done.is_some();
        if !tool_use_still_possible || tool_use_is_done {
            if send {
                if let Some(delta) = crate::handle_seq_error_ok!(seq.get_delta(), seq.responder()) {
                    if seq.get_mut_group().is_chat {
                        let (text_new, tool_calls) =
                            parse_text_tools(this, delta.as_str(), seq.tools.clone())
                                .map_err(candle_core::Error::msg)?;

                        if !tool_calls.is_empty() && is_done.is_none() {
                            is_done = Some(StopReason::Eos);
                        }
                        seq.add_streaming_chunk_choice_to_group(crate::ChunkChoice {
                            delta: crate::Delta {
                                content: fixup_sentencepiece!(
                                    Option text_new.map(ToString::to_string)
                                ),
                                role: "assistant".to_string(),
                                tool_calls: Some(tool_calls).filter(|v| !v.is_empty()),
                            },
                            index: seq.get_response_index(),
                            finish_reason: is_done.map(|x| x.to_string()),
                            logprobs: if seq.return_logprobs() {
                                Some(crate::ResponseLogprob {
                                    token: delta,
                                    bytes: logprobs.bytes.clone().map(|b| b.into_bytes()),
                                    logprob: logprobs.logprob,
                                    top_logprobs: logprobs.top_logprobs.clone().unwrap(),
                                })
                            } else {
                                None
                            },
                        });
                    } else {
                        seq.add_streaming_completion_chunk_choice_to_group(
                            crate::CompletionChunkChoice {
                                text: fixup_sentencepiece!(delta),
                                index: seq.get_response_index(),
                                finish_reason: is_done.map(|x| x.to_string()),
                                logprobs: if seq.return_logprobs() {
                                    Some(crate::ResponseLogprob {
                                        token: delta,
                                        bytes: logprobs.bytes.clone().map(|b| b.into_bytes()),
                                        logprob: logprobs.logprob,
                                        top_logprobs: logprobs.top_logprobs.clone().unwrap(),
                                    })
                                } else {
                                    None
                                },
                            },
                        );
                    }
                }
            }

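            // On the final token: remember the sequence in the prefix cache,
            // mark it done, and reset any non-granular pipeline state.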
            if let Some(reason) = is_done {
                if use_prefix_cacher {
                    prefix_cacher.add_sequence(seq);
                    prefix_cacher.evict_to_cpu()?;
                }
                seq.set_state(crate::sequence::SequenceState::Done(reason));
                this.reset_non_granular_state();
            }

            // Send usage on final chunk.
            let usage_opt = if is_done.is_some() {
                let usage = seq.get_mut_group().get_usage();
                seq.get_mut_group().total_prompt_toks = 0;
                seq.get_mut_group().total_toks = 0;
                Some(usage)
            } else {
                None
            };

            if seq
                .get_mut_group()
                .maybe_send_streaming_response(seq, this.name().clone(), usage_opt)
                .await
                .is_err()
            {
                // If we can't send the response, cancel the sequence
                seq.set_state(crate::sequence::SequenceState::Done(
                    crate::sequence::StopReason::Canceled,
                ));
                this.reset_non_granular_state();
            }
        }
    } else if let Some(reason) = is_done {
        /*
        ***********************
        Finish the sequence now
        ***********************
        */
        {
            seq.set_state(crate::sequence::SequenceState::Done(reason));
            let (tokenizer, pipeline_name) = {
                let pipeline_name = this.name();
                let tokenizer = this.tokenizer();
                (tokenizer, pipeline_name)
            };

            let logprobs = if seq.return_logprobs() {
                let mut logprobs = Vec::new();
                for logprob in seq.logprobs() {
                    let resp_logprob = crate::ResponseLogprob {
                        token: crate::handle_seq_error_ok!(
                            tokenizer
                                .as_ref()
                                .ok_or(candle_core::Error::Msg(
                                    "`finish_or_add_toks_to_seq` requires the pipeline to have a tokenizer"
                                        .to_string(),
                                ))?
                                .decode(&[logprob.token], false),
                            seq.responder()
                        ),
                        bytes: logprob.bytes.clone().map(|b| b.into_bytes()),
                        logprob: logprob.logprob,
                        top_logprobs: logprob.top_logprobs.clone().unwrap(),
                    };
                    logprobs.push(resp_logprob);
                }
                Some(logprobs)
            } else {
                None
            };

            let text = match reason {
                crate::sequence::StopReason::Length(_)
                | crate::sequence::StopReason::ModelLength(_)
                | crate::sequence::StopReason::Eos
                | crate::sequence::StopReason::StopTok(_)
                | crate::sequence::StopReason::Canceled => {
                    String::from_utf8_lossy(seq.completion_bytes())
                        .trim_start()
                        .to_string()
                }
                crate::sequence::StopReason::StopString {
                    completion_bytes_pos,
                    ..
                } => {
                    let txt = String::from_utf8_lossy(seq.completion_bytes());
                    txt[..completion_bytes_pos].trim_start().to_string()
                }
                crate::sequence::StopReason::GeneratedImage => {
                    candle_core::bail!("Stop reason was `GeneratedImage`.")
                }
            };

            if seq.get_mut_group().is_chat {
                let (text_new, tool_calls) =
                    parse_text_tools(this, text.as_str(), seq.tools.clone())
                        .map_err(candle_core::Error::msg)?;
                let choice = crate::Choice {
                    finish_reason: fixup_sentencepiece!(reason),
                    index: seq.get_response_index(),
                    message: crate::ResponseMessage {
                        content: text_new.map(ToString::to_string),
                        role: "assistant".to_string(),
                        tool_calls: Some(tool_calls).filter(|v| !v.is_empty()),
                    },
                    logprobs: logprobs.map(|l| crate::Logprobs { content: Some(l) }),
                };
                seq.add_choice_to_group(choice);
            } else {
                let choice = crate::CompletionChoice {
                    finish_reason: fixup_sentencepiece!(reason),
                    index: seq.get_response_index(),
                    text,
                    logprobs: None,
                };
                seq.add_completion_choice_to_group(choice);
            }

            if use_prefix_cacher {
                prefix_cacher.add_sequence(seq);
                prefix_cacher.evict_to_cpu()?;
            }

            let group = seq.get_mut_group();
            if group.is_chat {
                group
                    .maybe_send_chat_done_response(
                        crate::ChatCompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name,
                            system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
                            object: "chat.completion".to_string(),
                            usage: group.get_usage(),
                        },
                        seq.responder(),
                    )
                    .await
                    .map_err(candle_core::Error::msg)?;
            } else {
                group
                    .maybe_send_completion_done_response(
                        crate::CompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_completion_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name,
                            system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
                            object: "text_completion".to_string(),
                            usage: group.get_usage(),
                        },
                        seq.responder(),
                    )
                    .await
                    .map_err(candle_core::Error::msg)?;
            }
        }
        this.reset_non_granular_state();
    }

    Ok(())
}

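/// Sample one token for each sequence from its logits, then hand each result
/// to [`finish_or_add_toks_to_seq`]. Sampling runs on a rayon pool whenever
/// more than one sequence is in flight.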
pub async fn sample_and_add_toks(
    this: &dyn Pipeline,
    seqs: &mut [&mut Sequence],
    logits_seq: Vec<Tensor>,
    prefix_cacher: &mut PrefixCacheManagerV2,
    disable_eos_stop: bool,
    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
) -> Result<()> {
    let seqs_len = seqs.len();
    debug_assert_eq!(logits_seq.len(), seqs_len);

    let use_async_pool = seqs_len > 1;

    let sampling_futures: Vec<_> = std::iter::zip(logits_seq, seqs.iter_mut())
        .map(|(logits_per_seq, seq)| {
            let return_logprobs = seq.return_logprobs();
            sample_sequence(
                logits_per_seq,
                seq,
                return_logprobs,
                rng.clone(),
                use_async_pool,
                false,
            )
        })
        .collect();
    let sampled_vec = futures::future::join_all(sampling_futures).await;

    for (sampled, seq) in std::iter::zip(sampled_vec, seqs.iter_mut()) {
        let next_token = crate::handle_seq_error_stateaware_ok!(sampled, seq);

        let metadata = this.get_metadata();
        let eos_tok = if disable_eos_stop {
            None
        } else {
            Some(&metadata.eos_tok[..])
        };

        finish_or_add_toks_to_seq(this, prefix_cacher, seq, next_token, eos_tok, true).await?;
    }

    Ok(())
}

/// Asynchronously sample one token from `logits` for `seq`, enforcing the
/// sequence's llguidance constraint (if any) by masking and re-sampling when
/// the first choice is not allowed.
#[allow(clippy::too_many_arguments)]
pub async fn sample_sequence(
    logits: Tensor,
    seq: &mut Sequence,
    return_logprobs: bool,
    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    use_async_pool: bool,
    sample_speculative: bool,
) -> Result<Logprobs> {
    let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;

    let sampler = seq.sampler();
    let ctx_clone = seq.get_toks().to_vec();
    let rng_clone = rng.clone();
    let logits_clone = logits.clone();
    let first_logprobs_response = if use_async_pool {
        tokio_rayon::spawn(move || {
            sampler.sample(
                logits_clone,
                &ctx_clone,
                return_logprobs,
                rng_clone,
                sample_speculative,
            )
        })
        .await?
    } else {
        sampler.sample(
            logits_clone,
            &ctx_clone,
            return_logprobs,
            rng_clone,
            sample_speculative,
        )?
    };

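    // If an llguidance constraint is attached, check the sampled token against
    // it. When the token is disallowed, build an additive bias: 0.0 for tokens
    // the constraint permits, -inf for everything else.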
    let bias_if_not_allowed = match &mut seq.recognizer {
        SequenceRecognizer::Llguidance(ref mut llg) => {
            if !llg.is_stopped()
                && llg
                    .validate_tokens(&[first_logprobs_response.token])
                    .unwrap_or(0)
                    == 1
            {
                None
            } else {
                let mask = llg.compute_mask_or_eos().map_err(candle_core::Error::msg)?;
                if mask.is_allowed(first_logprobs_response.token) {
                    // shouldn't really happen, except for EOS
                    None
                } else {
                    let mut acc = vec![-f32::INFINITY; logits.shape().dims1()?];
                    mask.iter_set_entries(|idx| {
                        if idx < acc.len() {
                            acc[idx] = 0.0;
                        }
                    });

                    Some(acc)
                }
            }
        }
        SequenceRecognizer::None => None,
    };
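    // When a bias was produced, add it to the logits (driving disallowed
    // tokens to -inf) and sample a second time; otherwise keep the first draw.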
    let second_logprobs_response = match bias_if_not_allowed {
        Some(acc) => {
            let new_logits = (logits + Tensor::from_slice(&acc, acc.len(), &Device::Cpu)?)?;

            let ctx_clone = seq.get_toks().to_vec();
            let rng_clone = rng.clone();
            let sampler = seq.sampler();
            if use_async_pool {
                tokio_rayon::spawn(move || {
                    sampler.sample(
                        new_logits,
                        &ctx_clone,
                        return_logprobs,
                        rng_clone,
                        sample_speculative,
                    )
                })
                .await?
            } else {
                sampler.sample(
                    new_logits,
                    &ctx_clone,
                    return_logprobs,
                    rng_clone,
                    sample_speculative,
                )?
            }
        }
        None => first_logprobs_response,
    };

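    // Advance the constraint with the token that was actually committed.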
    match seq.recognizer {
        SequenceRecognizer::Llguidance(ref mut llg) => {
            if !llg.is_stopped() {
                llg.consume_token(second_logprobs_response.token)
                    .map_err(candle_core::Error::msg)?;
            }
        }
        SequenceRecognizer::None => {}
    }

    Ok(second_logprobs_response)
}

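/// A single token drawn during speculative decoding, from either the draft or
/// the target model.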
#[derive(Clone)]
pub struct SpeculativeSample {
    pub sample: Logprobs,
}

/// Async sample without modifying the sequence (except for its constraint).
pub async fn sample_target_sequence_speculative(
    logits: Tensor,
    seq: &mut Sequence,
    return_logprobs: bool,
    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    draft_samples: &[SpeculativeSample],
) -> Result<Vec<SpeculativeSample>> {
    let n_toks = draft_samples.len();

    // First, roll the llguidance constraint back over the draft tokens.
    match &mut seq.recognizer {
        SequenceRecognizer::Llguidance(ref mut llg) => {
            llg.rollback(n_toks).map_err(candle_core::Error::msg)?;
        }
        SequenceRecognizer::None => {}
    }

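    // Re-sample each draft position with the target logits, stopping at the
    // first position where the target disagrees with the draft token; the
    // mismatching target sample is kept as the correction.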
    let mut sampled = Vec::new();
    for (chunk, draft) in logits
        .chunk(n_toks, 1)?
        .into_iter()
        .zip(draft_samples.iter())
    {
        let sample = sample_sequence(
            chunk,
            seq,
            return_logprobs,
            rng.clone(),
            true, // TODO(EricLBuehler): does this hurt perf?
            true,
        )
        .await?;
        let sampled_token = sample.token;
        sampled.push(SpeculativeSample { sample });
        if sampled_token != draft.sample.token {
            break;
        }
    }
    Ok(sampled)
}