mistralrs_core/pipeline/sampling.rs

use std::sync::Arc;

use candle_core::{DType, Result, Tensor};
use rand_isaac::Isaac64Rng;

use crate::{
    prefix_cacher::PrefixCacheManagerV2,
    sampler::Logprobs,
    sequence::{Sequence, SequenceRecognizer, SequenceState, StopReason},
    tools::{parse_text_tools, ToolCallResponse, ToolCallType},
};
use mistralrs_mcp::CalledFunction;

use super::Pipeline;

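// SentencePiece tokenizers mark word boundaries with `▁` (U+2581, "lower one
// eighth block"); replace it with a plain space before text reaches clients.
// For example (hypothetical input), `fixup_sentencepiece!("▁Hello")` yields
// " Hello". The `Option` arm maps the same fixup over an `Option` value.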
macro_rules! fixup_sentencepiece {
    ($txt:expr) => {
        $txt.to_string().replace("▁", " ")
    };
    (Option $txt:expr) => {
        match &$txt {
            Some(txt) => Some(fixup_sentencepiece!(txt)),
            None => None,
        }
    };
}

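/// Decode the newly sampled token, append it to `seq`, and either emit a
/// streaming chunk or finish the sequence. This covers EOS/stop detection,
/// tool-call parsing (plain and Harmony mode), logprob reporting, and
/// prefix-cache bookkeeping.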
pub(crate) async fn finish_or_add_toks_to_seq(
    this: &dyn Pipeline,
    prefix_cacher: &mut PrefixCacheManagerV2,
    seq: &mut Sequence,
    logprobs: Logprobs,
    eos_tok: Option<&[u32]>,
    use_prefix_cacher: bool,
) -> Result<()> {
    let mut is_done = seq.is_done(logprobs.token, eos_tok, this.get_metadata().max_seq_len);
    seq.add_token(
        logprobs.clone(),
        this.get_metadata()
            .tok_env()
            .ok_or(candle_core::Error::Msg(
                "`finish_or_add_toks_to_seq` requires the pipeline to have a token trie"
                    .to_string(),
            ))?
            .tok_trie()
            .decode(&[logprobs.token]),
        &is_done,
    );

    // If tools are available and the accumulated delta already forms a complete
    // tool call, stop the sequence early. This doesn't conflict with the logic
    // below, which reaches the same conclusion on its own.
    if let Some(ref t) = seq.tools {
        if let Ok(Some(ref d)) = seq.peek_delta() {
            let (_tool_use_still_possible, tool_use_is_done) =
                t.prefix_could_be_tool(this, d.as_str())?;

            if tool_use_is_done
                && matches!(
                    parse_text_tools(this, d, seq.tools.clone()),
                    Ok((None, _tools))
                )
            {
                seq.set_state(SequenceState::Done(StopReason::Eos));
                is_done = Some(StopReason::Eos);
            }
        }
    }

    // Handle streaming requests
    if seq.get_mut_group().is_streaming {
        let mut tool_use_still_possible = false;
        let mut tool_use_is_done = false;
        if let Some(ref t) = seq.tools {
            if let Ok(Some(ref d)) = seq.peek_delta() {
                (tool_use_still_possible, tool_use_is_done) =
                    t.prefix_could_be_tool(this, d.as_str())?;
            }
        }

        // let send = seq.get_toks().len() % 2 == 0 || is_done.is_some();
        let send = true;
        // Send chunks when:
        // 1. Tool call is not possible (!tool_use_still_possible) - normal streaming
        // 2. Tool call is complete (tool_use_is_done) - send the tool call
        // 3. Sequence is done (is_done.is_some()) - send buffered output as text since it wasn't a valid tool call
        if !tool_use_still_possible || tool_use_is_done || is_done.is_some() {
            if send {
                let delta_result = seq.get_delta();
                if let Some(delta) = crate::handle_seq_error_ok!(delta_result, seq.responder()) {
                    if seq.get_mut_group().is_chat {
                        // Check if we're in Harmony mode and use parsed content
                        let (content_delta, reasoning_delta) = if seq.is_harmony_mode() {
                            // In Harmony mode, use the parsed final content and reasoning
                            let final_delta = seq.get_harmony_final_delta();
                            let reasoning = seq.get_harmony_reasoning_delta();
                            (final_delta, reasoning)
                        } else {
                            // Not in Harmony mode, use the raw delta
                            let (text_new, _) =
                                parse_text_tools(this, delta.as_str(), seq.tools.clone())
                                    .map_err(candle_core::Error::msg)?;
                            (text_new.map(ToString::to_string), None)
                        };

                        // Detect tool calls
                        let tool_calls = if seq.is_harmony_mode() {
                            // In Harmony mode, only finalize tool calls when the sequence is done
                            // (EOS token or stop string), not when we first detect a tool call.
                            // This ensures tool call arguments are fully generated.
                            if is_done.is_some() && seq.has_harmony_tool_calls() {
                                // Sequence is done and has tool calls - finalize and send them
                                is_done = Some(StopReason::ToolCalls);
                                let harmony_tool_calls = seq.get_harmony_tool_calls();
                                harmony_tool_calls
                                    .into_iter()
                                    .enumerate()
                                    .map(|(i, tc)| ToolCallResponse {
                                        index: i,
                                        id: tc.id,
                                        tp: ToolCallType::Function,
                                        function: CalledFunction {
                                            name: tc.name,
                                            arguments: tc.arguments,
                                        },
                                    })
                                    .collect()
                            } else {
                                vec![]
                            }
                        } else {
                            // Not in Harmony mode - parse text for tool calls
                            let (_, tool_calls) =
                                parse_text_tools(this, delta.as_str(), seq.tools.clone())
                                    .map_err(candle_core::Error::msg)?;
                            if !tool_calls.is_empty() {
                                is_done = Some(StopReason::ToolCalls);
                            }
                            tool_calls
                        };

                        seq.add_streaming_chunk_choice_to_group(crate::ChunkChoice {
                            delta: crate::Delta {
                                content: fixup_sentencepiece!(Option content_delta),
                                role: "assistant".to_string(),
                                tool_calls: Some(tool_calls).filter(|v| !v.is_empty()),
                                reasoning_content: reasoning_delta,
                            },
                            index: seq.get_response_index(),
                            finish_reason: is_done.map(|x| x.to_string()),
                            logprobs: if seq.return_logprobs() {
                                Some(crate::ResponseLogprob {
                                    token: delta,
                                    bytes: logprobs.bytes.clone().map(|b| b.into_bytes()),
                                    logprob: logprobs.logprob,
                                    top_logprobs: logprobs.top_logprobs.unwrap().clone(),
                                })
                            } else {
                                None
                            },
                        });
                    } else {
                        seq.add_streaming_completion_chunk_choice_to_group(
                            crate::CompletionChunkChoice {
                                text: fixup_sentencepiece!(delta),
                                index: seq.get_response_index(),
                                finish_reason: is_done.map(|x| x.to_string()),
                                logprobs: if seq.return_logprobs() {
                                    Some(crate::ResponseLogprob {
                                        token: delta,
                                        bytes: logprobs.bytes.clone().map(|b| b.into_bytes()),
                                        logprob: logprobs.logprob,
                                        top_logprobs: logprobs.top_logprobs.unwrap().clone(),
                                    })
                                } else {
                                    None
                                },
                            },
                        );
                    }
                }
            }

            // Send usage on final chunk.
            let usage_opt = if is_done.is_some() {
                let usage = seq.get_mut_group().get_usage();
                seq.get_mut_group().total_prompt_toks = 0;
                seq.get_mut_group().total_toks = 0;
                Some(usage)
            } else {
                None
            };

            if seq
                .get_mut_group()
                .maybe_send_streaming_response(seq, this.name().clone(), usage_opt)
                .await
                .is_err()
            {
                // If we can't send the response, cancel the sequence
                seq.set_state(crate::sequence::SequenceState::Done(
                    crate::sequence::StopReason::Canceled,
                ));
                this.reset_non_granular_state();
            }
        }

        // Handle the Done state regardless of tool detection. This must live
        // outside the tool_use check so the sequence completes even when tool
        // detection thinks the output might still be a tool call.
        if let Some(reason) = is_done {
            if use_prefix_cacher {
                prefix_cacher.add_sequence(seq);
                prefix_cacher.evict_caches()?;
            }
            seq.set_state(crate::sequence::SequenceState::Done(reason));
            this.reset_non_granular_state();
        }
    } else if let Some(mut reason) = is_done {
        /*
        ***********************
        Finish the sequence now
        ***********************
        */
        {
            seq.set_state(crate::sequence::SequenceState::Done(reason));
            let (tokenizer, pipeline_name) = {
                let pipeline_name = this.name();
                let tokenizer = this.tokenizer();
                (tokenizer, pipeline_name)
            };

            let logprobs = if seq.return_logprobs() {
                let mut logprobs = Vec::new();
                for logprob in seq.logprobs() {
                    let resp_logprob = crate::ResponseLogprob {
                        token: crate::handle_seq_error_ok!(
                            tokenizer
                                .as_ref()
                                .ok_or(candle_core::Error::Msg(
                                    "`finish_or_add_toks_to_seq` requires the pipeline to have a tokenizer"
                                        .to_string(),
                                ))?
                                .decode(&[logprob.token], false),
                            seq.responder()
                        ),
                        bytes: logprob.bytes.clone().map(|b| b.into_bytes()),
                        logprob: logprob.logprob,
                        top_logprobs: logprob.top_logprobs.clone().unwrap(),
                    };
                    logprobs.push(resp_logprob);
                }
                Some(logprobs)
            } else {
                None
            };

            // Signal EOS to Harmony parser if in Harmony mode
            seq.harmony_process_eos();

            let text = match reason {
                crate::sequence::StopReason::Length(_)
                | crate::sequence::StopReason::ModelLength(_)
                | crate::sequence::StopReason::Eos
                | crate::sequence::StopReason::StopTok(_)
                | crate::sequence::StopReason::Canceled
                | crate::sequence::StopReason::ToolCalls => {
                    String::from_utf8_lossy(seq.completion_bytes())
                        .trim_start()
                        .to_string()
                }
                crate::sequence::StopReason::StopString {
                    completion_bytes_pos,
                    ..
                } => {
                    let txt = String::from_utf8_lossy(seq.completion_bytes());
                    txt[..completion_bytes_pos].trim_start().to_string()
                }
                crate::sequence::StopReason::GeneratedImage
                | crate::sequence::StopReason::GeneratedSpeech => {
                    candle_core::bail!("Stop reason was `GeneratedImage` or `GeneratedSpeech`.")
                }
            };

            if seq.get_mut_group().is_chat {
                // In Harmony mode, use Harmony's parsed content and tool calls
                let (text_new, tool_calls, reasoning_content) = if seq.is_harmony_mode() {
                    let final_content = seq.get_harmony_final_content();
                    let reasoning = seq.get_harmony_reasoning_content();

                    // Get Harmony tool calls
                    let harmony_tool_calls = seq.get_harmony_tool_calls();
                    let tool_calls: Vec<ToolCallResponse> = harmony_tool_calls
                        .into_iter()
                        .enumerate()
                        .map(|(i, tc)| ToolCallResponse {
                            index: i,
                            id: tc.id,
                            tp: ToolCallType::Function,
                            function: CalledFunction {
                                name: tc.name,
                                arguments: tc.arguments,
                            },
                        })
                        .collect();

                    (final_content, tool_calls, reasoning)
                } else {
                    // Not in Harmony mode - parse text for tool calls
                    let (text_new, tool_calls) =
                        parse_text_tools(this, text.as_str(), seq.tools.clone())
                            .map_err(candle_core::Error::msg)?;
                    (text_new.map(ToString::to_string), tool_calls, None)
                };

                if !tool_calls.is_empty() {
                    reason = StopReason::ToolCalls;
                }

                let choice = crate::Choice {
                    finish_reason: fixup_sentencepiece!(reason),
                    index: seq.get_response_index(),
                    message: crate::ResponseMessage {
                        content: text_new,
                        role: "assistant".to_string(),
                        tool_calls: Some(tool_calls).filter(|v| !v.is_empty()),
                        reasoning_content,
                    },
                    logprobs: logprobs.map(|l| crate::Logprobs { content: Some(l) }),
                };
                seq.add_choice_to_group(choice);
            } else {
                let choice = crate::CompletionChoice {
                    finish_reason: fixup_sentencepiece!(reason),
                    index: seq.get_response_index(),
                    text,
                    logprobs: None,
                };
                seq.add_completion_choice_to_group(choice);
            }

            if use_prefix_cacher {
                prefix_cacher.add_sequence(seq);
                prefix_cacher.evict_caches()?;
            }

            let group = seq.get_mut_group();
            if group.is_chat {
                group
                    .maybe_send_chat_done_response(
                        crate::ChatCompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name,
                            system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
                            object: "chat.completion".to_string(),
                            usage: group.get_usage(),
                        },
                        seq.responder(),
                    )
                    .await
                    .map_err(candle_core::Error::msg)?;
            } else {
                group
                    .maybe_send_completion_done_response(
                        crate::CompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_completion_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name,
                            system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
                            object: "text_completion".to_string(),
                            usage: group.get_usage(),
                        },
                        seq.responder(),
                    )
                    .await
                    .map_err(candle_core::Error::msg)?;
            }
        }
        this.reset_non_granular_state();
    }

    Ok(())
}

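/// Sample one token for each sequence in `seqs` from the corresponding tensor
/// in `logits_seq`, then hand each result to [`finish_or_add_toks_to_seq`].
/// When more than one sequence is present, sampling is dispatched to the rayon
/// pool via `tokio_rayon`.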
pub async fn sample_and_add_toks(
    this: &dyn Pipeline,
    seqs: &mut [&mut Sequence],
    logits_seq: Vec<Tensor>,
    prefix_cacher: &mut PrefixCacheManagerV2,
    disable_eos_stop: bool,
    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
) -> Result<()> {
    let seqs_len = seqs.len();
    debug_assert_eq!(logits_seq.len(), seqs_len);

    let use_async_pool = seqs_len > 1;

    let sampling_futures: Vec<_> = std::iter::zip(logits_seq, seqs.iter_mut())
        .map(|(logits_per_seq, seq)| {
            let return_logprobs = seq.return_logprobs();
            sample_sequence(
                logits_per_seq,
                seq,
                return_logprobs,
                rng.clone(),
                use_async_pool,
                false,
                use_async_pool,
            )
        })
        .collect();
    let sampled_vec = futures::future::join_all(sampling_futures).await;

    for (sampled, seq) in std::iter::zip(sampled_vec, seqs.iter_mut()) {
        let next_token = crate::handle_seq_error_stateaware_ok!(sampled, seq);

        let metadata = this.get_metadata();
        let eos_tok = if disable_eos_stop {
            None
        } else {
            Some(&metadata.eos_tok[..])
        };

        finish_or_add_toks_to_seq(this, prefix_cacher, seq, next_token, eos_tok, true).await?;
    }

    Ok(())
}

/// Async sample, optionally adding the sampled token to the constraint trie.
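///
/// A minimal sketch of a call site (hypothetical `logits`, `seq`, and `rng`
/// values; not compiled as a doctest):
///
/// ```ignore
/// let return_logprobs = seq.return_logprobs();
/// let logprobs = sample_sequence(
///     logits, // shape [1, 1, vocab]
///     &mut seq,
///     return_logprobs,
///     rng.clone(),
///     true,  // use_async_pool: run the sampler on the rayon pool
///     false, // sample_speculative
///     false, // multiple_sequences
/// )
/// .await?;
/// ```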
#[allow(clippy::too_many_arguments)]
pub async fn sample_sequence(
    logits: Tensor,
    seq: &mut Sequence,
    return_logprobs: bool,
    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    use_async_pool: bool,
    sample_speculative: bool,
    multiple_sequences: bool,
) -> Result<Logprobs> {
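    // Logits arrive with shape [1, 1, vocab]; drop the batch and sequence dims
    // and convert to f32 before sampling.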
    let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;

    let sampler = seq.sampler();
    let ctx_clone = seq.get_toks().to_vec();
    let rng_clone = rng.clone();
    let logits_clone = logits.clone();
    let first_logprobs_response = if use_async_pool {
        tokio_rayon::spawn(move || {
            sampler.sample(
                logits_clone,
                &ctx_clone,
                return_logprobs,
                rng_clone,
                sample_speculative,
                multiple_sequences,
            )
        })
        .await?
    } else {
        sampler.sample(
            logits_clone,
            &ctx_clone,
            return_logprobs,
            rng_clone,
            sample_speculative,
            multiple_sequences,
        )?
    };

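    // If an llguidance grammar constraint is active and it rejects the sampled
    // token, build an additive bias over the vocabulary (-inf for disallowed
    // tokens, 0.0 for allowed ones) so the token can be resampled from the
    // masked logits below.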
    let bias_if_not_allowed = match &mut seq.recognizer {
        SequenceRecognizer::Llguidance(ref mut llg) => {
            if !llg.is_stopped()
                && llg
                    .validate_tokens(&[first_logprobs_response.token])
                    .unwrap_or(0)
                    == 1
            {
                None
            } else {
                let mask = llg.compute_mask_or_eos().map_err(candle_core::Error::msg)?;
                if mask.is_allowed(first_logprobs_response.token) {
                    // Shouldn't really happen, except for EOS
                    None
                } else {
                    let mut acc = vec![-f32::INFINITY; logits.shape().dims1()?];
                    mask.iter_set_entries(|idx| {
                        if idx < acc.len() {
                            acc[idx] = 0.0;
                        }
                    });

                    Some(acc)
                }
            }
        }
        SequenceRecognizer::None => None,
    };
    let second_logprobs_response = match bias_if_not_allowed {
        Some(acc) => {
            let new_logits = (&logits + Tensor::from_slice(&acc, acc.len(), logits.device())?)?;

            let ctx_clone = seq.get_toks().to_vec();
            let rng_clone = rng.clone();
            let sampler = seq.sampler();
            if use_async_pool {
                tokio_rayon::spawn(move || {
                    sampler.sample(
                        new_logits,
                        &ctx_clone,
                        return_logprobs,
                        rng_clone,
                        sample_speculative,
                        multiple_sequences,
                    )
                })
                .await?
            } else {
                sampler.sample(
                    new_logits,
                    &ctx_clone,
                    return_logprobs,
                    rng_clone,
                    sample_speculative,
                    multiple_sequences,
                )?
            }
        }
        None => first_logprobs_response,
    };

    match seq.recognizer {
        SequenceRecognizer::Llguidance(ref mut llg) => {
            if !llg.is_stopped() {
                llg.consume_token(second_logprobs_response.token)
                    .map_err(candle_core::Error::msg)?;
            }
        }
        SequenceRecognizer::None => {}
    }

    Ok(second_logprobs_response)
}

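/// A single sampled token (with logprobs) produced during speculative decoding.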
#[derive(Clone)]
pub struct SpeculativeSample {
    pub sample: Logprobs,
}

/// Async sample without modifying the sequence (except for the constraint).
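///
/// Verifies the draft tokens in order: each logits chunk is sampled with the
/// target model and compared against the corresponding draft sample, stopping
/// after the first mismatch.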
pub async fn sample_target_sequence_speculative(
    logits: Tensor,
    seq: &mut Sequence,
    return_logprobs: bool,
    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    draft_samples: &[SpeculativeSample],
) -> Result<Vec<SpeculativeSample>> {
    let n_toks = draft_samples.len();

    // First, roll back the llg
    match &mut seq.recognizer {
        SequenceRecognizer::Llguidance(ref mut llg) => {
            llg.rollback(n_toks).map_err(candle_core::Error::msg)?;
        }
        SequenceRecognizer::None => {}
    }

    let mut sampled = Vec::new();
    for (chunk, draft) in logits
        .chunk(n_toks, 1)?
        .into_iter()
        .zip(draft_samples.iter())
    {
        let sample = sample_sequence(
            chunk,
            seq,
            return_logprobs,
            rng.clone(),
            true, // TODO(EricLBuehler): does this hurt perf?
            true,
            false,
        )
        .await?;
        let sampled_token = sample.token;
        sampled.push(SpeculativeSample { sample });
        if sampled_token != draft.sample.token {
            break;
        }
    }
    Ok(sampled)
}