mistralrs_core/
sampler.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{
4    collections::{HashMap, HashSet},
5    sync::{Arc, Mutex},
6};
7
8use candle_core::{DType, Device, Error, Result, Tensor, D};
9use mistralrs_quant::{CumSumOp, SortOp};
10#[cfg(feature = "pyo3_macros")]
11use pyo3::pyclass;
12
13use once_cell::sync::Lazy;
14use rand::distr::{weighted::WeightedIndex, Distribution};
15use rand_isaac::Isaac64Rng;
16use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
17use serde::{Deserialize, Serialize};
18use tokenizers::Tokenizer;
19
20static DRY_SEQUENCE_BREAKERS: Lazy<Vec<String>> =
21    Lazy::new(|| ["\n", ":", "\"", "*"].map(String::from).to_vec());
22
#[derive(Clone, Debug, Serialize, Deserialize)]
/// Stop sequences or ids.
///
/// Variant order is kept as-is because the type is serde-(de)serialized.
pub enum StopTokens {
    /// Stop criteria given as text sequences.
    Seqs(Vec<String>),
    /// Stop criteria given as token ids.
    Ids(Vec<u32>),
}
29
#[derive(Clone, Debug, Serialize, Deserialize)]
/// Sampling params are used to control sampling.
///
/// NOTE(review): the conversion from these params into a [`Sampler`] is not visible here; the
/// field notes below follow the names and how `Sampler` treats its matching arguments.
pub struct SamplingParams {
    /// Softmax temperature. `Sampler::new` treats `None` or a value below `1e-7` as greedy decoding.
    pub temperature: Option<f64>,
    /// Keep only the `k` highest-scoring tokens.
    pub top_k: Option<usize>,
    /// Nucleus (top-p) threshold.
    pub top_p: Option<f64>,
    /// Min-p threshold, relative to the most likely token's probability.
    pub min_p: Option<f64>,
    /// How many alternative tokens to report logprobs for.
    pub top_n_logprobs: usize,
    /// Penalty subtracted once per occurrence of a token in the context.
    pub frequency_penalty: Option<f32>,
    /// Penalty subtracted once from any token already present in the context.
    pub presence_penalty: Option<f32>,
    /// Optional stop sequences or stop token ids.
    pub stop_toks: Option<StopTokens>,
    /// Optional maximum generation length (presumably in tokens — confirm with the caller).
    pub max_len: Option<usize>,
    /// Optional per-token-id bias (presumably added to the logits outside this file — confirm).
    pub logits_bias: Option<HashMap<u32, f32>>,
    /// Number of choices to generate (presumably independent completions — confirm).
    pub n_choices: usize,
    /// Optional DRY repetition-penalty settings.
    pub dry_params: Option<DrySamplingParams>,
}
46
47impl SamplingParams {
48    /// This sets up the parameters so that there is:
49    /// - No temperature, topk, topp, minp
50    /// - No penalties, stop tokens, or logit bias
51    /// - No maximum length
52    pub fn deterministic() -> Self {
53        Self {
54            temperature: None,
55            top_k: Some(1),
56            top_p: None,
57            min_p: None,
58            top_n_logprobs: 0,
59            frequency_penalty: None,
60            presence_penalty: None,
61            stop_toks: None,
62            max_len: None,
63            logits_bias: None,
64            n_choices: 1,
65            dry_params: None,
66        }
67    }
68}
69
#[derive(Clone, Debug, Serialize, Deserialize)]
/// User-facing settings for the DRY ("Don't Repeat Yourself") repetition penalty.
pub struct DrySamplingParams {
    /// Text sequences that stop a repetition match from being extended further.
    pub sequence_breakers: Vec<String>,
    /// Penalty scale; `0.0` disables DRY (the sampler returns early in that case).
    pub multiplier: f32,
    /// Exponential base: the penalty is `multiplier * base^(match_len - allowed_length)`.
    pub base: f32,
    /// Repetitions shorter than this many tokens are not penalized.
    pub allowed_length: usize,
}
77
78impl DrySamplingParams {
79    pub fn new_with_defaults(
80        multiplier: f32,
81        sequence_breakers: Option<Vec<String>>,
82        base: Option<f32>,
83        allowed_length: Option<usize>,
84    ) -> anyhow::Result<Self> {
85        Ok(Self {
86            base: base.unwrap_or(1.75),
87            allowed_length: allowed_length.unwrap_or(2),
88            sequence_breakers: sequence_breakers.unwrap_or(DRY_SEQUENCE_BREAKERS.clone()),
89            multiplier,
90        })
91    }
92}
93
94impl Default for DrySamplingParams {
95    fn default() -> Self {
96        Self {
97            multiplier: 0.0,
98            base: 1.75,
99            allowed_length: 2,
100            sequence_breakers: DRY_SEQUENCE_BREAKERS.clone(),
101        }
102    }
103}
104
/// Resolved DRY settings used by the sampler: same as [`DrySamplingParams`], but with the
/// sequence breakers converted to token ids.
#[derive(Clone, Debug)]
struct DrySamplingParamsInner {
    /// Token ids that stop a repetition match.
    pub sequence_breakers: HashSet<u32>,
    /// Penalty scale; `0.0` disables DRY.
    pub multiplier: f32,
    /// Exponential base of the penalty.
    pub base: f32,
    /// Minimum match length that is penalized.
    pub allowed_length: usize,
}
112
113impl DrySamplingParamsInner {
114    pub fn from(other: DrySamplingParams, tokenizer: &Tokenizer) -> anyhow::Result<Self> {
115        Ok(Self {
116            base: other.base,
117            allowed_length: other.allowed_length,
118            sequence_breakers: HashSet::from_iter(
119                other
120                    .sequence_breakers
121                    .into_iter()
122                    .map(|breaker| {
123                        tokenizer
124                            // Prefix with 'a' to get the correct encoding of the token at the end of a text.
125                            //
126                            // FIXME: This is a hack. See https://github.com/LostRuins/koboldcpp/pull/982
127                            //        for the correct solution which covers multi-token sequence breakers
128                            //        and ambiguous encodings.
129                            .encode_fast(["a", &breaker].concat(), true)
130                            .map_err(anyhow::Error::msg)
131                            .map(|enc| {
132                                let ids = enc.get_ids();
133                                if !ids.is_empty() {
134                                    None
135                                } else {
136                                    Some(ids[ids.len() - 1])
137                                }
138                            })
139                    })
140                    .collect::<anyhow::Result<Vec<_>>>()?
141                    .into_iter()
142                    .flatten()
143                    .collect::<Vec<_>>(),
144            ),
145            multiplier: other.multiplier,
146        })
147    }
148}
149
/// Customizable logits processor.
///
/// An implementor receives the logits for the next token together with the sequence context,
/// and returns a (possibly modified) logits tensor. Any `Fn(&Tensor, &[u32]) -> Result<Tensor>`
/// that is also `Send + Sync` implements this trait automatically through a blanket impl, so
/// closures can be used directly.
///
/// Within `Sampler::sample`, processors run in registration order. On the CPU path they run
/// after the DRY/frequency/presence penalties; on the fast (`metal`) path they run before
/// the frequency/presence penalties.
///
/// # Example
/// ```rust
/// use std::{sync::Arc, ops::Mul};
/// use mistralrs_core::CustomLogitsProcessor;
/// use candle_core::{Result, Tensor};
///
/// struct ThresholdLogitsProcessor;
/// impl CustomLogitsProcessor for ThresholdLogitsProcessor {
///     fn apply(&self, logits: &Tensor, _context: &[u32]) -> Result<Tensor> {
///         // Mask is 1 for true, 0 for false.
///         let mask = logits.ge(0.5)?;
///         logits.broadcast_mul(&mask.to_dtype(logits.dtype())?)
///     }
/// }
/// let processor1: Arc<dyn CustomLogitsProcessor> = Arc::new(|logits: &Tensor, _context: &[u32]| logits * 1.23);
/// let processor2: Arc<dyn CustomLogitsProcessor> = Arc::new(ThresholdLogitsProcessor);
/// ```
pub trait CustomLogitsProcessor: Send + Sync {
    /// Given the current `logits` and the sequence `context` (prompt and generated token ids),
    /// returns the modified logits.
    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor>;
}
173
174impl<T: Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync> CustomLogitsProcessor for T {
175    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor> {
176        self(logits, context)
177    }
178}
179
/// Sampler for sampling.
///
/// Holds the configuration used by `Sampler::sample` to turn a logits vector into the next
/// token (plus optional logprobs).
#[derive(Clone)]
pub struct Sampler {
    /// `None` means greedy decoding; `Sampler::new` maps values below `1e-7` to `None`.
    temperature: Option<f64>,
    /// Number of alternatives reported when logprobs are requested.
    top_n_logprobs: usize,
    /// Optional tokenizer, used to decode sampled tokens to text and to resolve DRY breakers.
    tokenizer: Option<Arc<Tokenizer>>,
    /// Subtracted once per occurrence of a token in the context.
    frequency_penalty: Option<f32>,
    /// Subtracted once from every token present in the context.
    presence_penalty: Option<f32>,
    /// Resolved DRY settings; `None` if DRY was not requested or no tokenizer was given.
    dry_params: Option<DrySamplingParamsInner>,
    /// Top-k cut-off; `<= 0` disables it.
    top_k: i64,
    /// Nucleus (top-p) threshold.
    top_p: f64,
    /// Min-p threshold relative to the most likely token.
    min_p: f64,
    /// Custom processors applied on every `sample` call.
    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
}
194
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
/// Top-n logprobs element: one alternative token and its score.
pub struct TopLogprob {
    /// Token id.
    pub token: u32,
    /// Base-10 log of the value the sampler scored this token with: a probability on the
    /// temperature paths, or the raw logit on the greedy path.
    pub logprob: f32,
    /// The token decoded to text (despite the field name); `None` when no tokenizer is set.
    pub bytes: Option<String>,
}
204
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Result of sampling one token.
pub struct Logprobs {
    /// The sampled token id.
    pub token: u32,
    /// Base-10 log of the sampled token's score (see [`TopLogprob::logprob`]).
    pub logprob: f32,
    /// The sampled token decoded to text; `None` when no tokenizer is set.
    pub bytes: Option<String>,
    /// The most likely alternatives; `None` unless logprobs were requested.
    pub top_logprobs: Option<Vec<TopLogprob>>,
}
212
213fn argmax_sample_last_dim(logits: &Tensor) -> Result<Tensor> {
214    logits.argmax(D::Minus1)
215}
216
217impl Sampler {
218    #[allow(clippy::too_many_arguments)]
219    pub fn new(
220        temperature: Option<f64>,
221        top_n_logprobs: usize,
222        tokenizer: Option<Arc<Tokenizer>>,
223        frequency_penalty: Option<f32>,
224        presence_penalty: Option<f32>,
225        dry_params: Option<DrySamplingParams>,
226        top_k: i64,
227        top_p: f64,
228        min_p: f64,
229        logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
230    ) -> anyhow::Result<Self> {
231        let temperature = if temperature.is_none_or(|v| v < 1e-7) {
232            None
233        } else {
234            temperature
235        };
236        let dry_params = if let Some(ref tokenizer) = tokenizer {
237            dry_params.map(|params| DrySamplingParamsInner::from(params, tokenizer))
238        } else {
239            None
240        };
241        let dry_params = match dry_params {
242            Some(fallible) => Some(fallible?),
243            None => None,
244        };
245        Ok(Self {
246            temperature,
247            top_n_logprobs,
248            tokenizer,
249            frequency_penalty,
250            presence_penalty,
251            dry_params,
252            top_k,
253            top_p,
254            min_p,
255            logits_processors,
256        })
257    }
258
259    fn get_top_logprobs(&self, probs: &[f32], _argsort_indices: &[u32]) -> Result<Vec<TopLogprob>> {
260        // Fast top-k selection without sorting the entire vocabulary
261        let k = self.top_n_logprobs.min(probs.len());
262        if k == 0 {
263            return Ok(Vec::new());
264        }
265        // Build (token, probability) pairs
266        let mut idx_probs: Vec<(u32, f32)> = (0..probs.len() as u32)
267            .map(|i| (i, probs[i as usize]))
268            .collect();
269        // Partition so that the top k probabilities are in the first k positions
270        let (top_k_slice, _, _) = idx_probs.select_nth_unstable_by(k, |a, b| {
271            b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
272        });
273        // Copy and sort only the top k elements by descending probability
274        let mut top_k: Vec<(u32, f32)> = top_k_slice.to_vec();
275        top_k.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
276        // Build the result vector with log10 of probabilities and optional decoding
277        let mut result = Vec::with_capacity(k);
278        if let Some(tokenizer) = &self.tokenizer {
279            for (token, prob) in top_k {
280                let decoded = tokenizer
281                    .decode(&[token], false)
282                    .map_err(|e| Error::Msg(e.to_string()))?;
283                result.push(TopLogprob {
284                    token,
285                    logprob: prob.log(10.0),
286                    bytes: Some(decoded),
287                });
288            }
289        } else {
290            for (token, prob) in top_k {
291                result.push(TopLogprob {
292                    token,
293                    logprob: prob.log(10.0),
294                    bytes: None,
295                });
296            }
297        }
298        Ok(result)
299    }
300
301    fn sample_argmax(&self, logits: Tensor, return_logprobs: bool) -> Result<Logprobs> {
302        let next_token = logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
303
304        let probs: Vec<f32> = logits.to_vec1()?;
305
306        let argsort_indices = (0..probs.len() as u32).collect::<Vec<_>>();
307        let logprob = probs[next_token as usize].log(10.0);
308
309        let top_logprobs = if return_logprobs {
310            Some(self.get_top_logprobs(&probs, &argsort_indices)?)
311        } else {
312            None
313        };
314
315        let bytes = if let Some(tokenizer) = &self.tokenizer {
316            Some(
317                tokenizer
318                    .decode(&[next_token], false)
319                    .map_err(|x| Error::Msg(x.to_string()))?,
320            )
321        } else {
322            None
323        };
324
325        Ok(Logprobs {
326            token: next_token,
327            logprob,
328            top_logprobs,
329            bytes,
330        })
331    }
332
333    #[allow(unused)]
334    fn sample_fast(
335        &self,
336        logits: Tensor,
337        context: &[u32],
338        return_logprobs: bool,
339        top_k: i64,
340        top_p: f64,
341        min_p: f64,
342    ) -> Result<Logprobs> {
343        let mut probs = logits.to_dtype(DType::F32)?;
344
345        for processor in &self.logits_processors {
346            probs = processor.apply(&probs, context)?;
347        }
348
349        let context = Tensor::new(context, logits.device())?;
350        let mut counts = logits.zeros_like()?;
351        counts = counts.scatter_add(
352            &context,
353            &context.ones_like()?.to_dtype(counts.dtype())?,
354            D::Minus1,
355        )?;
356
357        let presence = counts
358            .gt(0.)?
359            .where_cond(&counts.ones_like()?, &counts.zeros_like()?)?;
360
361        match self.frequency_penalty {
362            Some(freq_penalty) if freq_penalty != 0. => {
363                probs = (probs - (freq_penalty as f64 * counts)?)?;
364            }
365            _ => (),
366        }
367
368        match self.presence_penalty {
369            Some(pres_penalty) if pres_penalty != 0. => {
370                probs = (probs - (pres_penalty as f64 * presence)?)?;
371            }
372            _ => (),
373        }
374
375        probs = candle_nn::ops::softmax_last_dim(&(probs / self.temperature.unwrap_or(1.))?)?;
376
377        // Top-K
378        if top_k > 0 {
379            let sorted_values = probs.fast_sort_asc(D::Minus1)?;
380            let topk_values = sorted_values.narrow(
381                D::Minus1,
382                sorted_values.dim(D::Minus1)? - top_k as usize,
383                top_k as usize,
384            )?;
385
386            // select the kth largest value as threshold
387            let threshold = topk_values.get_on_dim(D::Minus1, 0)?.unsqueeze(0)?;
388            let mask_topk = probs.broadcast_ge(&threshold)?;
389            probs = mask_topk.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
390        }
391
392        // Top-P (nucleus)
393        if top_p > 0.0 && top_p < 1.0 {
394            let sorted_probs = probs.fast_sort_asc(D::Minus1)?;
395
396            let cumsum = sorted_probs.fast_cumsum(D::Minus1)?;
397
398            let mask_topp = cumsum.le(top_p)?;
399
400            let masked_sorted =
401                mask_topp.where_cond(&sorted_probs, &Tensor::zeros_like(&sorted_probs)?)?;
402
403            let threshold = masked_sorted.max(D::Minus1)?;
404            let threshold = threshold.unsqueeze(D::Minus1)?;
405            let mask_full = probs.broadcast_ge(&threshold)?;
406            probs = mask_full.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
407        }
408
409        // Min-P
410        if min_p > 0.0 && min_p < 1.0 {
411            let max_vals = probs.max(D::Minus1)?;
412            let threshold_min = (max_vals.unsqueeze(D::Minus1)? * min_p)?;
413            let mask_minp = probs.broadcast_gt(&threshold_min)?;
414            probs = mask_minp.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
415        }
416
417        let next_token = probs.argmax(D::Minus1)?.to_scalar::<u32>()?;
418
419        // Extract the top‑n log‑probs if the caller asked for them.
420        let (top_logprobs, logprob) = if return_logprobs {
421            let k = self.top_n_logprobs;
422
423            let sorted_values = probs.fast_sort_asc(D::Minus1)?;
424            let topk_values = sorted_values
425                .narrow(
426                    D::Minus1,
427                    sorted_values.dim(D::Minus1)? - top_k as usize,
428                    top_k as usize,
429                )?
430                .to_vec1::<f32>()?;
431
432            let sorted_idxs = probs.fast_argsort_asc(D::Minus1)?;
433            let topk_idxs = sorted_idxs
434                .narrow(
435                    D::Minus1,
436                    sorted_values.dim(D::Minus1)? - top_k as usize,
437                    top_k as usize,
438                )?
439                .to_vec1::<u32>()?;
440
441            let mut result = Vec::with_capacity(k);
442            if let Some(tokenizer) = &self.tokenizer {
443                for (prob, token) in topk_values.iter().zip(topk_idxs) {
444                    let decoded = tokenizer
445                        .decode(&[token], false)
446                        .map_err(|e| Error::Msg(e.to_string()))?;
447                    result.push(TopLogprob {
448                        token,
449                        logprob: prob.log(10.0),
450                        bytes: Some(decoded),
451                    });
452                }
453            } else {
454                for (prob, token) in topk_values.iter().zip(topk_idxs) {
455                    result.push(TopLogprob {
456                        token,
457                        logprob: prob.log(10.0),
458                        bytes: None,
459                    });
460                }
461            }
462
463            let logprob = result.last().map(|res| res.logprob).unwrap_or(1.);
464
465            (Some(result), logprob)
466        } else {
467            (None, 1.)
468        };
469
470        let bytes = if let Some(tokenizer) = &self.tokenizer {
471            Some(
472                tokenizer
473                    .decode(&[next_token], false)
474                    .map_err(|x| Error::Msg(x.to_string()))?,
475            )
476        } else {
477            None
478        };
479
480        Ok(Logprobs {
481            token: next_token,
482            logprob,
483            top_logprobs,
484            bytes,
485        })
486    }
487    fn sample_speculative_top_kp_min_p(
488        &self,
489        logits: Tensor,
490        return_logprobs: bool,
491        top_k: i64,
492        top_p: f32,
493        min_p: f32,
494    ) -> Result<Logprobs> {
495        let mut probs: Vec<f32> = logits.to_vec1()?;
496        let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
497
498        if top_k > 0 {
499            // Clamp smaller probabilities to zero.
500            for (index, val) in argsort_indices.iter().enumerate() {
501                if index >= top_k as usize {
502                    probs[*val as usize] = 0.0;
503                }
504            }
505        }
506
507        // TOP P
508
509        // top-p sampling (or "nucleus sampling") samples from the smallest set of
510        // tokens that exceed probability top_p. This way we never sample tokens that
511        // have very low probabilities and are less likely to go "off the rails".
512
513        // Clamp smaller probabilities to zero.
514        let mut cumsum = 0.;
515        for index in &argsort_indices {
516            if cumsum >= top_p {
517                probs[*index as usize] = 0.0;
518            } else {
519                cumsum += probs[*index as usize];
520            }
521        }
522
523        let max_p = probs[argsort_indices[0] as usize];
524
525        // MIN P
526
527        // min-p sampling samples from the tokens whose prob are greater than
528        // (max prob of token in dist) * min_p
529
530        // Clamp smaller probabilities to zero.
531        for index in &argsort_indices {
532            if max_p * min_p >= probs[*index as usize] {
533                probs[*index as usize] = 0.0;
534            }
535        }
536
537        let logits = Tensor::from_slice(&probs, logits.shape(), &Device::Cpu)?;
538
539        let next_token = argmax_sample_last_dim(&logits)?.to_scalar::<u32>()?;
540
541        let logprob = probs[next_token as usize].log(10.0);
542
543        let top_logprobs = if return_logprobs {
544            Some(self.get_top_logprobs(&probs, &argsort_indices)?)
545        } else {
546            None
547        };
548
549        let bytes = if let Some(tokenizer) = &self.tokenizer {
550            Some(
551                tokenizer
552                    .decode(&[next_token], false)
553                    .map_err(|x| Error::Msg(x.to_string()))?,
554            )
555        } else {
556            None
557        };
558
559        Ok(Logprobs {
560            token: next_token,
561            logprob,
562            top_logprobs,
563            bytes,
564        })
565    }
566
567    fn sample_multinomial(
568        &self,
569        probs: &mut Vec<f32>,
570        argsort_indices: Vec<u32>,
571        return_logprobs: bool,
572        rng: Arc<Mutex<Isaac64Rng>>,
573    ) -> Result<Logprobs> {
574        let distr = WeightedIndex::new(&*probs).map_err(Error::wrap)?;
575
576        let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
577        let next_token = distr.sample(&mut mut_ref_rng); // "Find the first item which has a weight *higher* than the chosen weight."
578        let logprob = probs[next_token].log(10.0);
579
580        let top_logprobs = if return_logprobs {
581            Some(self.get_top_logprobs(probs, &argsort_indices)?)
582        } else {
583            None
584        };
585
586        let bytes = if let Some(tokenizer) = &self.tokenizer {
587            Some(
588                tokenizer
589                    .decode(&[next_token.try_into().unwrap()], false)
590                    .map_err(|x| Error::Msg(x.to_string()))?,
591            )
592        } else {
593            None
594        };
595
596        Ok(Logprobs {
597            token: next_token as u32,
598            logprob,
599            top_logprobs,
600            bytes,
601        })
602    }
603
604    #[allow(clippy::too_many_arguments)]
605    fn sample_top_kp_min_p(
606        &self,
607        probs: &mut Vec<f32>,
608        logits: &Tensor,
609        top_k: i64,
610        top_p: f32,
611        min_p: f32,
612        return_logprobs: bool,
613        rng: Arc<Mutex<Isaac64Rng>>,
614    ) -> Result<Logprobs> {
615        let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
616
617        if top_k > 0 {
618            // Clamp smaller probabilities to zero.
619            for (index, val) in argsort_indices.iter().enumerate() {
620                if index >= top_k as usize {
621                    probs[*val as usize] = 0.0;
622                }
623            }
624        }
625
626        if top_p <= 0.0 || top_p >= 1.0 {
627            return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
628        }
629
630        // TOP P
631
632        // top-p sampling (or "nucleus sampling") samples from the smallest set of
633        // tokens that exceed probability top_p. This way we never sample tokens that
634        // have very low probabilities and are less likely to go "off the rails".
635
636        // Clamp smaller probabilities to zero.
637        let mut cumsum = 0.;
638        for index in &argsort_indices {
639            if cumsum >= top_p {
640                probs[*index as usize] = 0.0;
641            } else {
642                cumsum += probs[*index as usize];
643            }
644        }
645
646        if min_p <= 0.0 || min_p >= 1.0 {
647            return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
648        }
649
650        let max_p = probs[argsort_indices[0] as usize];
651
652        // MIN P
653
654        // min-p sampling samples from the tokens whose prob are greater than
655        // (max prob of token in dist) * min_p
656
657        // Clamp smaller probabilities to zero.
658        for index in &argsort_indices {
659            if max_p * min_p >= probs[*index as usize] {
660                probs[*index as usize] = 0.0;
661            }
662        }
663
664        // Sample with clamped probabilities.
665        self.sample_multinomial(probs, argsort_indices, return_logprobs, rng)
666    }
667
668    fn apply_penalties(&self, mut logits: Vec<f32>, context: &[u32]) -> Result<Tensor> {
669        if context.is_empty() {
670            candle_core::bail!("Penalty context is empty, this should not happen.");
671        }
672
673        // Dry penalty
674        self.apply_dry_penalty(&mut logits, context)?;
675
676        // Frequency and Presence penalty
677        self.apply_freq_presc_penalty(&mut logits, context)?;
678
679        let vocab_size = logits.len();
680        Tensor::from_vec(logits, vocab_size, &Device::Cpu)
681    }
682
683    fn apply_freq_presc_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
684        if self.frequency_penalty.is_some() || self.presence_penalty.is_some() {
685            let frequency_penalty = self.frequency_penalty.unwrap_or(0.);
686            let presence_penalty = self.presence_penalty.unwrap_or(0.);
687
688            //mu[j] -> mu[j] - c[j] * alpha_frequency - float(c[j] > 0) * alpha_presence
689
690            let mut counts = vec![0.0f32; logits.len()];
691            for ctx in context.iter() {
692                // Llama 3.2 uses a hack triggering this error... we wouldn't want a weight on it anyway
693                if *ctx as usize >= logits.len() {
694                    continue;
695                }
696                counts[*ctx as usize] += 1.0;
697            }
698
699            for (token_id, logit) in logits.iter_mut().enumerate() {
700                let count = counts[token_id];
701                *logit = *logit
702                    - count * frequency_penalty
703                    - if count > 0.0 { 1. } else { 0. } * presence_penalty;
704            }
705        }
706        Ok(())
707    }
708
709    fn apply_dry_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
710        if let Some(ref params) = self.dry_params {
711            if params.multiplier == 0. {
712                return Ok(());
713            }
714
715            let match_indices = context
716                .par_iter()
717                .enumerate()
718                .take(context.len() - 1)
719                .filter(|(_i, x)| *context.last().unwrap() == **x)
720                .map(|(i, _)| i)
721                .collect::<Vec<_>>();
722
723            let mut match_lengths = HashMap::new();
724
725            for i in match_indices {
726                let next_token = context[i + 1];
727
728                if params.sequence_breakers.contains(&next_token) {
729                    continue;
730                }
731
732                let mut match_length = 1;
733
734                // Limit match length to avoid quadratic runtime and potential DoS with adversarial inputs.
735                while match_length < 50 {
736                    if match_length > i {
737                        // Start of input
738                        break;
739                    }
740
741                    let j = i - match_length;
742
743                    let prev_tok = context[context.len() - (match_length + 1)];
744                    if context[j] != prev_tok {
745                        // Start of match reached
746                        break;
747                    }
748
749                    if params.sequence_breakers.contains(&prev_tok) {
750                        // Seq breaking tok reached
751                        break;
752                    }
753
754                    match_length += 1;
755                }
756
757                #[allow(clippy::map_entry)]
758                if match_lengths.contains_key(&next_token) {
759                    match_lengths.insert(next_token, match_length.max(match_lengths[&next_token]));
760                } else {
761                    match_lengths.insert(next_token, match_length);
762                }
763            }
764
765            // Actually apply penalties
766            for (tok, match_len) in match_lengths {
767                if match_len >= params.allowed_length {
768                    // Llama 3.2 uses a hack triggering this error... we wouldn't want a weight on it anyway
769                    if tok as usize >= logits.len() {
770                        continue;
771                    }
772                    let penalty = params.multiplier
773                        * params.base.powf((match_len - params.allowed_length) as f32);
774                    logits[tok as usize] -= penalty;
775                }
776            }
777        }
778        Ok(())
779    }
780
781    #[allow(unused)]
782    /// Sample the provided tokens.
783    ///
784    /// If the temperature is `None`, argmax sampling is used. Otherwise, the selected sampling is used.
785    /// With `top-p` sampling, if the `top-p` value is `<= 0.0` or `>= 1.0`, multinomial sampling is used.
786    pub fn sample(
787        &self,
788        logits: Tensor,
789        context: &[u32],
790        return_logprobs: bool,
791        rng: Arc<Mutex<Isaac64Rng>>,
792        sample_speculative: bool,
793        multiple_sequences: bool,
794    ) -> Result<Logprobs> {
795        if cfg!(feature = "metal") && !multiple_sequences {
796            return self.sample_fast(
797                logits,
798                context,
799                return_logprobs,
800                self.top_k,
801                self.top_p,
802                self.min_p,
803            );
804        }
805
806        let logits = logits.to_vec1()?;
807        let mut logits = self.apply_penalties(logits, context)?;
808        for processor in &self.logits_processors {
809            logits = processor.apply(&logits, context)?;
810        }
811        let next_token = if sample_speculative {
812            match self.temperature {
813                None => self.sample_speculative_top_kp_min_p(
814                    logits,
815                    return_logprobs,
816                    self.top_k,
817                    self.top_p as f32,
818                    self.min_p as f32,
819                )?,
820                Some(temperature) => {
821                    let logits = (&logits / temperature)?;
822                    let probs = candle_nn::ops::softmax_last_dim(&logits)?;
823
824                    self.sample_speculative_top_kp_min_p(
825                        probs,
826                        return_logprobs,
827                        self.top_k,
828                        self.top_p as f32,
829                        self.min_p as f32,
830                    )?
831                }
832            }
833        } else {
834            match self.temperature {
835                None => self.sample_argmax(logits, return_logprobs)?,
836                Some(temperature) => {
837                    let logits = (&logits / temperature)?;
838                    let logits = candle_nn::ops::softmax_last_dim(&logits)?;
839                    let mut probs: Vec<f32> = logits.to_vec1()?;
840
841                    self.sample_top_kp_min_p(
842                        &mut probs,
843                        &logits,
844                        self.top_k,
845                        self.top_p as f32,
846                        self.min_p as f32,
847                        return_logprobs,
848                        rng,
849                    )?
850                }
851            }
852        };
853        Ok(next_token)
854    }
855}
856
#[cfg(test)]
mod tests {
    use super::{Logprobs, Sampler};
    use candle_core::{Device, Tensor};
    use rand::SeedableRng;
    use rand_isaac::Isaac64Rng;
    use std::sync::{Arc, Mutex};

    /// Samples greedily from the logits `0, 1, ..., 1023`, with the whole vocabulary as context.
    /// Settings: 10 top logprobs, top-k 32, top-p 0.1, min-p 0.05, no temperature.
    fn sample_arange(return_logprobs: bool, speculative: bool) -> Logprobs {
        let sampler =
            Sampler::new(None, 10, None, None, None, None, 32, 0.1, 0.05, vec![]).unwrap();
        let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
        let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
        sampler
            .sample(
                logits,
                &(0..1024).collect::<Vec<_>>(),
                return_logprobs,
                rng,
                speculative,
                false,
            )
            .unwrap()
    }

    #[test]
    fn test_argmax() {
        let res = sample_arange(false, false);
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }

    #[test]
    fn test_gumbel_speculative() {
        let res = sample_arange(false, true);
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }

    #[test]
    fn test_argmax_top_logprobs_are_descending() {
        let res = sample_arange(true, false);
        let top = res.top_logprobs.expect("logprobs were requested");
        assert_eq!(top.len(), 10);
        assert_eq!(top[0].token, 1023);
        // With arange logits, a higher score means a higher token id.
        assert!(top.windows(2).all(|w| w[0].token > w[1].token));
    }
}