//! mistralrs_core/pipeline/sampling.rs

1use std::sync::Arc;
2
3use candle_core::{DType, Device, Result, Tensor};
4use rand_isaac::Isaac64Rng;
5
6use crate::{
7    prefix_cacher::PrefixCacheManagerV2,
8    sampler::Logprobs,
9    sequence::{Sequence, SequenceRecognizer, SequenceState, StopReason},
10    tools::parse_text_tools,
11};
12
13use super::Pipeline;
14
// Replaces SentencePiece's "▁" (U+2581) word-boundary marker with a regular
// space so decoded text reads naturally in responses.
macro_rules! fixup_sentencepiece {
    ($txt:expr) => {
        $txt.to_string().replace("▁", " ")
    };
    // Option variant: applies the fixup inside `Some`, preserves `None`.
    (Option $txt:expr) => {
        match &$txt {
            Some(txt) => Some(fixup_sentencepiece!(txt)),
            None => None,
        }
    };
}
26
/// Commits a freshly sampled token to `seq`, then either streams a delta chunk
/// (for streaming requests) or, when the sequence just finished, assembles and
/// sends the final chat/completion response.
///
/// When `use_prefix_cacher` is set, finished sequences are registered with
/// `prefix_cacher` and entries may be evicted to CPU.
///
/// # Errors
/// Fails if the pipeline has no token trie / tokenizer, if tool-call parsing
/// fails, or if delivering the final response fails.
pub(crate) async fn finish_or_add_toks_to_seq(
    this: &dyn Pipeline,
    prefix_cacher: &mut PrefixCacheManagerV2,
    seq: &mut Sequence,
    logprobs: Logprobs,
    eos_tok: Option<&[u32]>,
    use_prefix_cacher: bool,
) -> Result<()> {
    // Determine up-front whether this token terminates the sequence
    // (EOS / stop token / stop string / length limits).
    let mut is_done = seq.is_done(logprobs.token, eos_tok, this.get_metadata().max_seq_len);
    seq.add_token(
        logprobs.clone(),
        this.get_metadata()
            .tok_env
            .as_ref()
            .ok_or(candle_core::Error::Msg(
                "`finish_or_add_toks_to_seq` requires the pipeline to have a token trie"
                    .to_string(),
            ))?
            .tok_trie()
            .decode(&[logprobs.token]),
        &is_done,
    );

    // If we can have a tool and we got a tool, stop the sequence early.
    // Doesn't conflict with the logic below because it does the same thing anyway.
    if let Some(ref t) = seq.tools {
        if let Ok(Some(ref d)) = seq.peek_delta() {
            let (_tool_use_still_possible, tool_use_is_done) =
                t.prefix_could_be_tool(this, d.as_str())?;

            // A complete tool call that yields no plain text (None content)
            // is treated as an implicit EOS.
            if tool_use_is_done
                && matches!(
                    parse_text_tools(this, d, seq.tools.clone()),
                    Ok((None, _tools))
                )
            {
                seq.set_state(SequenceState::Done(StopReason::Eos));
                is_done = Some(StopReason::Eos);
            }
        }
    };

    // Handle streaming requests
    if seq.get_mut_group().is_streaming {
        // Hold back the delta while the accumulated text could still turn
        // into a tool call; flush once that is settled either way.
        let mut tool_use_still_possible = false;
        let mut tool_use_is_done = false;
        if let Some(ref t) = seq.tools {
            if let Ok(Some(ref d)) = seq.peek_delta() {
                (tool_use_still_possible, tool_use_is_done) =
                    t.prefix_could_be_tool(this, d.as_str())?;
            }
        };

        if !tool_use_still_possible || tool_use_is_done {
            if let Some(delta) = crate::handle_seq_error_ok!(seq.get_delta(), seq.responder()) {
                if seq.get_mut_group().is_chat {
                    let (text_new, tool_calls) =
                        parse_text_tools(this, delta.as_str(), seq.tools.clone())
                            .map_err(candle_core::Error::msg)?;

                    // Emitting a tool call ends generation with an EOS finish reason.
                    if !tool_calls.is_empty() && is_done.is_none() {
                        is_done = Some(StopReason::Eos);
                    };
                    seq.add_streaming_chunk_choice_to_group(crate::ChunkChoice {
                        delta: crate::Delta {
                            content: fixup_sentencepiece!(
                                Option text_new.map(ToString::to_string)
                            ),
                            role: "assistant".to_string(),
                            tool_calls: Some(tool_calls).filter(|v| !v.is_empty()),
                        },
                        index: seq.get_response_index(),
                        finish_reason: is_done.map(|x| x.to_string()),
                        logprobs: if seq.return_logprobs() {
                            Some(crate::ResponseLogprob {
                                token: delta,
                                bytes: logprobs.bytes.clone().map(|b| b.into_bytes()),
                                logprob: logprobs.logprob,
                                top_logprobs: logprobs.top_logprobs.unwrap().clone(),
                            })
                        } else {
                            None
                        },
                    });
                } else {
                    seq.add_streaming_completion_chunk_choice_to_group(
                        crate::CompletionChunkChoice {
                            text: fixup_sentencepiece!(delta),
                            index: seq.get_response_index(),
                            finish_reason: is_done.map(|x| x.to_string()),
                            logprobs: if seq.return_logprobs() {
                                Some(crate::ResponseLogprob {
                                    token: delta,
                                    bytes: logprobs.bytes.clone().map(|b| b.into_bytes()),
                                    logprob: logprobs.logprob,
                                    top_logprobs: logprobs.top_logprobs.unwrap().clone(),
                                })
                            } else {
                                None
                            },
                        },
                    );
                }

                if let Some(reason) = is_done {
                    if use_prefix_cacher {
                        prefix_cacher.add_sequence(seq);
                        prefix_cacher.evict_to_cpu()?;
                    }
                    seq.set_state(crate::sequence::SequenceState::Done(reason));
                    this.reset_non_granular_state();
                }

                // Send usage on final chunk.
                let usage_opt = if is_done.is_some() {
                    let usage = seq.get_mut_group().get_usage();
                    seq.get_mut_group().total_prompt_toks = 0;
                    seq.get_mut_group().total_toks = 0;
                    Some(usage)
                } else {
                    None
                };

                if seq
                    .get_mut_group()
                    .maybe_send_streaming_response(seq, this.name().clone(), usage_opt)
                    .await
                    .is_err()
                {
                    // If we can't send the response, cancel the sequence
                    seq.set_state(crate::sequence::SequenceState::Done(
                        crate::sequence::StopReason::Canceled,
                    ));
                    this.reset_non_granular_state();
                }
            }
        }
    } else if let Some(reason) = is_done {
        /*
        ***********************
        Finish the sequence now
        ***********************
        */
        {
            seq.set_state(crate::sequence::SequenceState::Done(reason));
            let (tokenizer, pipeline_name) = {
                let pipeline_name = this.name();
                let tokenizer = this.tokenizer();
                (tokenizer, pipeline_name)
            };

            // Re-decode per-token logprobs with the full tokenizer (not the
            // trie), only when the client requested them.
            let logprobs = if seq.return_logprobs() {
                let mut logprobs = Vec::new();
                for logprob in seq.logprobs() {
                    let resp_logprob = crate::ResponseLogprob {
                        token: crate::handle_seq_error_ok!(
                        tokenizer
                        .as_ref()
                        .ok_or(candle_core::Error::Msg(
                            "`finish_or_add_toks_to_seq` requires the pipeline to have a tokenizer"
                                .to_string(),
                        ))?.decode(&[logprob.token], false),
                        seq.responder()
                    ),
                        bytes: logprob.bytes.clone().map(|b| b.into_bytes()),
                        logprob: logprob.logprob,
                        top_logprobs: logprob.top_logprobs.clone().unwrap(),
                    };
                    logprobs.push(resp_logprob);
                }
                Some(logprobs)
            } else {
                None
            };

            // Select the completion text according to why we stopped.
            let text = match reason {
                crate::sequence::StopReason::Length(_)
                | crate::sequence::StopReason::ModelLength(_)
                | crate::sequence::StopReason::Eos
                | crate::sequence::StopReason::StopTok(_)
                | crate::sequence::StopReason::Canceled => {
                    String::from_utf8_lossy(seq.completion_bytes())
                        .trim_start()
                        .to_string()
                }
                crate::sequence::StopReason::StopString {
                    completion_bytes_pos,
                    ..
                } => {
                    // Truncate at the byte position where the stop string matched.
                    let txt = String::from_utf8_lossy(seq.completion_bytes());
                    txt[..completion_bytes_pos].trim_start().to_string()
                }
                crate::sequence::StopReason::GeneratedImage => {
                    candle_core::bail!("Stop reason was `GeneratedImage`.")
                }
            };

            if seq.get_mut_group().is_chat {
                let (text_new, tool_calls) =
                    parse_text_tools(this, text.as_str(), seq.tools.clone())
                        .map_err(candle_core::Error::msg)?;
                let choice = crate::Choice {
                    finish_reason: fixup_sentencepiece!(reason),
                    index: seq.get_response_index(),
                    message: crate::ResponseMessage {
                        content: text_new.map(ToString::to_string),
                        role: "assistant".to_string(),
                        tool_calls: Some(tool_calls).filter(|v| !v.is_empty()),
                    },
                    logprobs: logprobs.map(|l| crate::Logprobs { content: Some(l) }),
                };
                seq.add_choice_to_group(choice);
            } else {
                let choice = crate::CompletionChoice {
                    finish_reason: fixup_sentencepiece!(reason),
                    index: seq.get_response_index(),
                    text,
                    logprobs: None,
                };
                seq.add_completion_choice_to_group(choice);
            }

            if use_prefix_cacher {
                prefix_cacher.add_sequence(seq);
                prefix_cacher.evict_to_cpu()?;
            }

            let group = seq.get_mut_group();
            if group.is_chat {
                group
                    .maybe_send_chat_done_response(
                        crate::ChatCompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name,
                            system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
                            object: "chat.completion".to_string(),
                            usage: group.get_usage(),
                        },
                        seq.responder(),
                    )
                    .await
                    .map_err(candle_core::Error::msg)?;
            } else {
                group
                    .maybe_send_completion_done_response(
                        crate::CompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_completion_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name,
                            system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
                            object: "text_completion".to_string(),
                            usage: group.get_usage(),
                        },
                        seq.responder(),
                    )
                    .await
                    .map_err(candle_core::Error::msg)?;
            }
        }
        this.reset_non_granular_state();
    }

    Ok(())
}
294
295pub async fn sample_and_add_toks(
296    this: &dyn Pipeline,
297    seqs: &mut [&mut Sequence],
298    logits_seq: Vec<Tensor>,
299    prefix_cacher: &mut PrefixCacheManagerV2,
300    disable_eos_stop: bool,
301    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
302) -> Result<()> {
303    let seqs_len = seqs.len();
304    debug_assert_eq!(logits_seq.len(), seqs_len);
305
306    let use_async_pool = seqs_len > 1;
307
308    let sampling_futures: Vec<_> = std::iter::zip(logits_seq, seqs.iter_mut())
309        .map(|(logits_per_seq, seq)| {
310            let return_logprobs = seq.return_logprobs();
311            sample_sequence(
312                logits_per_seq,
313                seq,
314                return_logprobs,
315                rng.clone(),
316                use_async_pool,
317                true, // Append result to trie
318                false,
319            )
320        })
321        .collect();
322    let sampled_vec = futures::future::join_all(sampling_futures).await;
323
324    for (sampled, seq) in std::iter::zip(sampled_vec, seqs.iter_mut()) {
325        let next_token = crate::handle_seq_error_stateaware_ok!(sampled, seq);
326
327        let metadata = this.get_metadata();
328        let eos_tok = if disable_eos_stop {
329            None
330        } else {
331            Some(&metadata.eos_tok[..])
332        };
333
334        finish_or_add_toks_to_seq(this, prefix_cacher, seq, next_token, eos_tok, true).await?;
335    }
336
337    Ok(())
338}
339
340/// Async sample optionally adding to trie.
341#[allow(clippy::too_many_arguments)]
342pub async fn sample_sequence(
343    logits: Tensor,
344    seq: &mut Sequence,
345    return_logprobs: bool,
346    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
347    use_async_pool: bool,
348    add_to_trie: bool,
349    sample_speculative: bool,
350) -> Result<Logprobs> {
351    let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
352
353    let sampler = seq.sampler();
354    let ctx_clone = seq.get_toks().to_vec();
355    let rng_clone = rng.clone();
356    let logits_clone = logits.clone();
357    let first_lobprobs_response = if use_async_pool {
358        tokio_rayon::spawn(move || {
359            sampler.sample(
360                logits_clone,
361                &ctx_clone,
362                return_logprobs,
363                rng_clone,
364                sample_speculative,
365            )
366        })
367        .await?
368    } else {
369        sampler.sample(
370            logits_clone,
371            &ctx_clone,
372            return_logprobs,
373            rng_clone,
374            sample_speculative,
375        )?
376    };
377
378    let bias_if_not_allowed = match &mut seq.recognizer {
379        SequenceRecognizer::Llguidance(ref mut llg) => {
380            let step_res = llg.compute_mask().map_err(candle_core::Error::msg)?;
381            if let Some(mask) = &step_res.sample_mask {
382                if mask.is_allowed(first_lobprobs_response.token) {
383                    None
384                } else {
385                    let mut acc = vec![-f32::INFINITY; logits.shape().dims1().unwrap()];
386                    mask.iter_set_entries(|idx| {
387                        if idx < acc.len() {
388                            acc[idx] = 0.0;
389                        }
390                    });
391
392                    Some(acc)
393                }
394            } else if step_res.is_stop() {
395                let mut acc = vec![-f32::INFINITY; logits.shape().dims1().unwrap()];
396                for eos_tok in seq.eos_tokens() {
397                    acc[*eos_tok as usize] = 0.0;
398                }
399                Some(acc)
400            } else {
401                None
402            }
403        }
404        SequenceRecognizer::None => None,
405    };
406    let second_logprobs_response = match bias_if_not_allowed {
407        Some(acc) => {
408            let new_logits = (logits + Tensor::from_slice(&acc, acc.len(), &Device::Cpu)?)?;
409
410            let ctx_clone = seq.get_toks().to_vec();
411            let rng_clone = rng.clone();
412            let sampler = seq.sampler();
413            if use_async_pool {
414                tokio_rayon::spawn(move || {
415                    sampler.sample(
416                        new_logits,
417                        &ctx_clone,
418                        return_logprobs,
419                        rng_clone,
420                        sample_speculative,
421                    )
422                })
423                .await?
424            } else {
425                sampler.sample(
426                    new_logits,
427                    &ctx_clone,
428                    return_logprobs,
429                    rng_clone,
430                    sample_speculative,
431                )?
432            }
433        }
434        None => first_lobprobs_response,
435    };
436
437    if add_to_trie {
438        match seq.recognizer {
439            SequenceRecognizer::Llguidance(ref mut llg) => {
440                llg.commit_token(Some(second_logprobs_response.token))
441                    .map_err(candle_core::Error::msg)?;
442            }
443            SequenceRecognizer::None => {}
444        }
445    }
446    Ok(second_logprobs_response)
447}
448
/// A single token drawn during speculative decoding, wrapping the sampler's
/// [`Logprobs`] output for that position.
#[derive(Clone)]
pub struct SpeculativeSample {
    pub sample: Logprobs,
}
453
454/// Async sample without modifying sequence.
455pub async fn sample_target_sequence_speculative(
456    logits: Tensor,
457    seq: &mut Sequence,
458    return_logprobs: bool,
459    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
460    n_toks: usize,
461) -> Result<Vec<SpeculativeSample>> {
462    let mut sampled = Vec::new();
463    for chunk in logits.chunk(n_toks, 1)? {
464        sampled.push(SpeculativeSample {
465            sample: sample_sequence(
466                chunk,
467                seq,
468                return_logprobs,
469                rng.clone(),
470                true,  // TODO(EricLBuehler): does this hurt perf?
471                false, // Do not append to trie (yet)
472                true,
473            )
474            .await?,
475        });
476    }
477    Ok(sampled)
478}