mistralrs_core/
sampler.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{
4    collections::{HashMap, HashSet},
5    sync::{Arc, LazyLock, Mutex},
6};
7
8use candle_core::{DType, Device, Error, Result, Tensor, D};
9use mistralrs_quant::{CumSumOp, SortOp};
10#[cfg(feature = "pyo3_macros")]
11use pyo3::pyclass;
12
13use rand::distr::{weighted::WeightedIndex, Distribution};
14use rand_isaac::Isaac64Rng;
15use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
16use serde::{Deserialize, Serialize};
17use tokenizers::Tokenizer;
18
/// Default DRY sequence breakers, used whenever the caller does not provide a list.
static DRY_SEQUENCE_BREAKERS: LazyLock<Vec<String>> = LazyLock::new(|| {
    ["\n", ":", "\"", "*"]
        .into_iter()
        .map(String::from)
        .collect()
});
21
22#[derive(Clone, Debug, Serialize, Deserialize)]
23/// Stop sequences or ids.
24pub enum StopTokens {
25    Seqs(Vec<String>),
26    Ids(Vec<u32>),
27}
28
29#[derive(Clone, Debug, Serialize, Deserialize)]
30/// Sampling params are used to control sampling.
31pub struct SamplingParams {
32    pub temperature: Option<f64>,
33    pub top_k: Option<usize>,
34    pub top_p: Option<f64>,
35    pub min_p: Option<f64>,
36    pub top_n_logprobs: usize,
37    pub frequency_penalty: Option<f32>,
38    pub presence_penalty: Option<f32>,
39    pub repetition_penalty: Option<f32>,
40    pub stop_toks: Option<StopTokens>,
41    pub max_len: Option<usize>,
42    pub logits_bias: Option<HashMap<u32, f32>>,
43    pub n_choices: usize,
44    pub dry_params: Option<DrySamplingParams>,
45}
46
47impl SamplingParams {
48    /// This sets up the parameters so that there is:
49    /// - No temperature, topk, topp, minp
50    /// - No penalties, stop tokens, or logit bias
51    /// - No maximum length
52    pub fn deterministic() -> Self {
53        Self {
54            temperature: None,
55            top_k: Some(1),
56            top_p: None,
57            min_p: None,
58            top_n_logprobs: 0,
59            frequency_penalty: None,
60            presence_penalty: None,
61            repetition_penalty: None,
62            stop_toks: None,
63            max_len: None,
64            logits_bias: None,
65            n_choices: 1,
66            dry_params: None,
67        }
68    }
69}
70
71#[derive(Clone, Debug, Serialize, Deserialize)]
72pub struct DrySamplingParams {
73    pub sequence_breakers: Vec<String>,
74    pub multiplier: f32,
75    pub base: f32,
76    pub allowed_length: usize,
77}
78
79impl DrySamplingParams {
80    pub fn new_with_defaults(
81        multiplier: f32,
82        sequence_breakers: Option<Vec<String>>,
83        base: Option<f32>,
84        allowed_length: Option<usize>,
85    ) -> anyhow::Result<Self> {
86        Ok(Self {
87            base: base.unwrap_or(1.75),
88            allowed_length: allowed_length.unwrap_or(2),
89            sequence_breakers: sequence_breakers.unwrap_or(DRY_SEQUENCE_BREAKERS.clone()),
90            multiplier,
91        })
92    }
93}
94
95impl Default for DrySamplingParams {
96    fn default() -> Self {
97        Self {
98            multiplier: 0.0,
99            base: 1.75,
100            allowed_length: 2,
101            sequence_breakers: DRY_SEQUENCE_BREAKERS.clone(),
102        }
103    }
104}
105
/// DRY parameters in the form used while sampling: the sequence breakers are stored as token
/// ids rather than strings.
#[derive(Clone, Debug)]
struct DrySamplingParamsInner {
    /// Token ids that interrupt sequence matching.
    pub sequence_breakers: HashSet<u32>,
    /// Scale of the penalty; `0.0` disables the DRY penalty.
    pub multiplier: f32,
    /// Base of the exponential growth of the penalty.
    pub base: f32,
    /// Match length from which the penalty starts to apply.
    pub allowed_length: usize,
}
113
114impl DrySamplingParamsInner {
115    pub fn from(other: DrySamplingParams, tokenizer: &Tokenizer) -> anyhow::Result<Self> {
116        Ok(Self {
117            base: other.base,
118            allowed_length: other.allowed_length,
119            sequence_breakers: HashSet::from_iter(
120                other
121                    .sequence_breakers
122                    .into_iter()
123                    .map(|breaker| {
124                        tokenizer
125                            // Prefix with 'a' to get the correct encoding of the token at the end of a text.
126                            //
127                            // FIXME: This is a hack. See https://github.com/LostRuins/koboldcpp/pull/982
128                            //        for the correct solution which covers multi-token sequence breakers
129                            //        and ambiguous encodings.
130                            .encode_fast(["a", &breaker].concat(), true)
131                            .map_err(anyhow::Error::msg)
132                            .map(|enc| {
133                                let ids = enc.get_ids();
134                                if !ids.is_empty() {
135                                    Some(ids[ids.len() - 1])
136                                } else {
137                                    None
138                                }
139                            })
140                    })
141                    .collect::<anyhow::Result<Vec<_>>>()?
142                    .into_iter()
143                    .flatten()
144                    .collect::<Vec<_>>(),
145            ),
146            multiplier: other.multiplier,
147        })
148    }
149}
150
151/// Customizable logits processor.
152///
153/// # Example
154/// ```rust
155/// use std::{sync::Arc, ops::Mul};
156/// use mistralrs_core::CustomLogitsProcessor;
157/// use candle_core::{Result, Tensor};
158///
159/// struct ThresholdLogitsProcessor;
160/// impl CustomLogitsProcessor for ThresholdLogitsProcessor {
161///     fn apply(&self, logits: &Tensor, _context: &[u32]) -> Result<Tensor> {
162///         // Mask is 1 for true, 0 for false.
163///         let mask = logits.ge(0.5)?;
164///         logits.broadcast_mul(&mask.to_dtype(logits.dtype())?)
165///     }
166/// }
167/// let processor1: Arc<dyn CustomLogitsProcessor> = Arc::new(|logits: &Tensor, _context: &[u32]| logits * 1.23);
168/// let processor2: Arc<dyn CustomLogitsProcessor> = Arc::new(ThresholdLogitsProcessor);
169/// ```
170pub trait CustomLogitsProcessor: Send + Sync {
171    /// Logits and sequence context (prompt and generated tokens), returning modified tokens.
172    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor>;
173}
174
175impl<T: Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync> CustomLogitsProcessor for T {
176    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor> {
177        self(logits, context)
178    }
179}
180
181/// Sampler for sampling.
182#[derive(Clone)]
183pub struct Sampler {
184    temperature: Option<f64>,
185    top_n_logprobs: usize,
186    tokenizer: Option<Arc<Tokenizer>>,
187    frequency_penalty: Option<f32>,
188    presence_penalty: Option<f32>,
189    repetition_penalty: Option<f32>,
190    dry_params: Option<DrySamplingParamsInner>,
191    top_k: i64,
192    top_p: f64,
193    min_p: f64,
194    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
195    /// Cached Gumbel noise tensor to avoid reallocating it.
196    gumbel_cache: Arc<Mutex<Option<Tensor>>>,
197}
198
199#[cfg_attr(feature = "pyo3_macros", pyclass)]
200#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
201#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
202/// Top-n logprobs element
203pub struct TopLogprob {
204    pub token: u32,
205    pub logprob: f32,
206    pub bytes: Option<String>,
207}
208
209#[derive(Debug, Clone, Serialize, Deserialize)]
210pub struct Logprobs {
211    pub token: u32,
212    pub logprob: f32,
213    pub bytes: Option<String>,
214    pub top_logprobs: Option<Vec<TopLogprob>>,
215}
216
/// Comparator for descending order by probability (second element of tuple).
///
/// NaN probabilities compare equal to each other and sort after every other value. This
/// keeps the comparator a consistent total order, which the slice sorting and selection
/// routines require; for non-NaN inputs the result is the plain descending comparison.
#[inline]
fn cmp_desc_by_prob(a: &(u32, f32), b: &(u32, f32)) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    match (a.1.is_nan(), b.1.is_nan()) {
        (true, true) => Ordering::Equal,
        // A NaN on the left goes after a real value on the right.
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // Neither side is NaN, so `partial_cmp` always yields an ordering here.
        (false, false) => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
    }
}
222
223/// Returns the top-k (index, probability) pairs from `probs`, sorted in descending order.
224/// Uses partial sort (O(n) + O(k log k)) instead of full sort (O(n log n)).
225///
226/// If `k >= probs.len()`, returns all elements sorted.
227/// Also zeros out elements in `probs` beyond top-k if `zero_rest` is true.
228fn partial_sort_top_k(probs: &mut [f32], k: usize, zero_rest: bool) -> Vec<(u32, f32)> {
229    let n = probs.len();
230    if n == 0 || k == 0 {
231        return Vec::new();
232    }
233
234    // Build (index, probability) pairs
235    let mut idx_probs: Vec<(u32, f32)> = (0..n as u32).map(|i| (i, probs[i as usize])).collect();
236
237    let k = k.min(n);
238
239    if k < n {
240        // Partial sort: partition so top k elements are in first k positions
241        // select_nth_unstable_by places the k-1th largest at position k-1,
242        // with all larger elements before it (unsorted) and smaller after
243        idx_probs.select_nth_unstable_by(k - 1, cmp_desc_by_prob);
244
245        if zero_rest {
246            // Zero out elements beyond top-k
247            for (idx, _) in idx_probs[k..].iter() {
248                probs[*idx as usize] = 0.0;
249            }
250        }
251
252        // Truncate to top k
253        idx_probs.truncate(k);
254    }
255
256    // Sort just the top k elements (descending by probability)
257    idx_probs.sort_unstable_by(cmp_desc_by_prob);
258
259    idx_probs
260}
261
/// Returns the index of the largest value in `values` with a single linear scan.
///
/// When several elements tie for the maximum the last of them wins, and an element that
/// cannot be ordered against the current best (NaN) replaces it. An empty slice yields `0`.
#[inline]
fn argmax_f32(values: &[f32]) -> u32 {
    use std::cmp::Ordering;

    let Some(&first) = values.first() else {
        return 0;
    };

    let mut best_idx = 0usize;
    let mut best_val = first;
    for (idx, &candidate) in values.iter().enumerate().skip(1) {
        // Keep the current best only when it is strictly greater than the candidate.
        if best_val.partial_cmp(&candidate) != Some(Ordering::Greater) {
            best_idx = idx;
            best_val = candidate;
        }
    }
    best_idx as u32
}
272
273impl Sampler {
274    #[allow(clippy::too_many_arguments)]
275    pub fn new(
276        temperature: Option<f64>,
277        top_n_logprobs: usize,
278        tokenizer: Option<Arc<Tokenizer>>,
279        frequency_penalty: Option<f32>,
280        presence_penalty: Option<f32>,
281        repetition_penalty: Option<f32>,
282        dry_params: Option<DrySamplingParams>,
283        top_k: i64,
284        top_p: f64,
285        min_p: f64,
286        logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
287    ) -> anyhow::Result<Self> {
288        let temperature = if temperature.is_none_or(|v| v < 1e-7) {
289            None
290        } else {
291            temperature
292        };
293        let dry_params = if let Some(ref tokenizer) = tokenizer {
294            dry_params.map(|params| DrySamplingParamsInner::from(params, tokenizer))
295        } else {
296            None
297        };
298        let dry_params = match dry_params {
299            Some(fallible) => Some(fallible?),
300            None => None,
301        };
302        Ok(Self {
303            temperature,
304            top_n_logprobs,
305            tokenizer,
306            frequency_penalty,
307            presence_penalty,
308            repetition_penalty,
309            dry_params,
310            top_k,
311            top_p,
312            min_p,
313            logits_processors,
314            gumbel_cache: Arc::new(Mutex::new(None)),
315        })
316    }
317
318    fn get_top_logprobs(&self, probs: &[f32]) -> Result<Vec<TopLogprob>> {
319        let k = self.top_n_logprobs.min(probs.len());
320        if k == 0 {
321            return Ok(Vec::new());
322        }
323
324        // Use partial sort helper (doesn't modify probs since we pass a copy)
325        let mut probs_copy = probs.to_vec();
326        let top_k = partial_sort_top_k(&mut probs_copy, k, false);
327
328        // Build the result vector with log10 of probabilities and optional decoding
329        let mut result = Vec::with_capacity(k);
330        if let Some(tokenizer) = &self.tokenizer {
331            for (token, prob) in top_k {
332                let decoded = tokenizer
333                    .decode(&[token], false)
334                    .map_err(|e| Error::Msg(e.to_string()))?;
335                result.push(TopLogprob {
336                    token,
337                    logprob: prob.log(10.0),
338                    bytes: Some(decoded),
339                });
340            }
341        } else {
342            for (token, prob) in top_k {
343                result.push(TopLogprob {
344                    token,
345                    logprob: prob.log(10.0),
346                    bytes: None,
347                });
348            }
349        }
350        Ok(result)
351    }
352
353    fn sample_argmax(&self, logits: Tensor, return_logprobs: bool) -> Result<Logprobs> {
354        let probs: Vec<f32> = logits.to_vec1()?;
355        let next_token = argmax_f32(&probs);
356        let logprob = probs[next_token as usize].log(10.0);
357
358        let top_logprobs = if return_logprobs {
359            Some(self.get_top_logprobs(&probs)?)
360        } else {
361            None
362        };
363
364        let bytes = if let Some(tokenizer) = &self.tokenizer {
365            Some(
366                tokenizer
367                    .decode(&[next_token], false)
368                    .map_err(|x| Error::Msg(x.to_string()))?,
369            )
370        } else {
371            None
372        };
373
374        Ok(Logprobs {
375            token: next_token,
376            logprob,
377            top_logprobs,
378            bytes,
379        })
380    }
381
382    #[allow(unused)]
383    fn sample_fast(
384        &self,
385        logits: Tensor,
386        context: &[u32],
387        return_logprobs: bool,
388        top_k: i64,
389        top_p: f64,
390        min_p: f64,
391    ) -> Result<Logprobs> {
392        let mut probs = logits.to_dtype(DType::F32)?;
393
394        for processor in &self.logits_processors {
395            probs = processor.apply(&probs, context)?;
396        }
397
398        let context = Tensor::new(context, logits.device())?;
399        let mut counts = logits.zeros_like()?;
400        counts = counts.scatter_add(
401            &context,
402            &context.ones_like()?.to_dtype(counts.dtype())?,
403            D::Minus1,
404        )?;
405
406        let presence = counts
407            .gt(0.)?
408            .where_cond(&counts.ones_like()?, &counts.zeros_like()?)?;
409
410        match self.frequency_penalty {
411            Some(freq_penalty) if freq_penalty != 0. => {
412                probs = (probs - (freq_penalty as f64 * counts)?)?;
413            }
414            _ => (),
415        }
416
417        match self.presence_penalty {
418            Some(pres_penalty) if pres_penalty != 0. => {
419                probs = (probs - (pres_penalty as f64 * &presence)?)?;
420            }
421            _ => (),
422        }
423
424        match self.repetition_penalty {
425            Some(rep_penalty) if rep_penalty != 1. => {
426                let pos_mask = probs.gt(0.)?;
427                let scaled_pos = (&probs / (rep_penalty as f64))?;
428                let scaled_neg = (&probs * (rep_penalty as f64))?;
429                let modified = pos_mask.where_cond(&scaled_pos, &scaled_neg)?;
430
431                let pres_mask = presence.gt(0.)?;
432                probs = pres_mask.where_cond(&modified, &probs)?;
433            }
434            _ => (),
435        }
436
437        probs = candle_nn::ops::softmax_last_dim(&(probs / self.temperature.unwrap_or(1.))?)?;
438
439        // Top-K
440        if top_k > 0 {
441            let sorted_values = probs.fast_sort_asc(D::Minus1)?;
442            let topk_values = sorted_values.narrow(
443                D::Minus1,
444                sorted_values.dim(D::Minus1)? - top_k as usize,
445                top_k as usize,
446            )?;
447
448            // select the kth largest value as threshold
449            let threshold = topk_values.get_on_dim(D::Minus1, 0)?.unsqueeze(0)?;
450            let mask_topk = probs.broadcast_ge(&threshold)?;
451            probs = mask_topk.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
452        }
453
454        // Top-P (nucleus)
455        if top_p > 0.0 && top_p < 1.0 {
456            let sorted_probs = probs.fast_sort_asc(D::Minus1)?;
457
458            let cumsum = sorted_probs.fast_cumsum(D::Minus1)?;
459
460            let mask_topp = cumsum.le(top_p)?;
461
462            let masked_sorted =
463                mask_topp.where_cond(&sorted_probs, &Tensor::zeros_like(&sorted_probs)?)?;
464
465            let threshold = masked_sorted.max(D::Minus1)?;
466            let threshold = threshold.unsqueeze(D::Minus1)?;
467            let mask_full = probs.broadcast_ge(&threshold)?;
468            probs = mask_full.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
469        }
470
471        // Min-P
472        if min_p > 0.0 && min_p < 1.0 {
473            let max_vals = probs.max(D::Minus1)?;
474            let threshold_min = (max_vals.unsqueeze(D::Minus1)? * min_p)?;
475            let mask_minp = probs.broadcast_gt(&threshold_min)?;
476            probs = mask_minp.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
477        }
478
479        // Sample using the Gumbel-max trick fully on-device.
480        let log_probs = probs.log()?;
481        // Generate cached Gumbel noise (-log(-log(u))) once.
482        let gumbel = {
483            let mut guard = self.gumbel_cache.lock().unwrap();
484            if guard.is_none() {
485                let uniform = Tensor::rand(0f32, 1f32, log_probs.shape(), log_probs.device())?;
486                let noise = uniform
487                    .clamp(1e-20, 1.0)?
488                    .log()? // ln(u)
489                    .neg()? // -ln(u)
490                    .log()? // ln(-ln(u))
491                    .neg()?; // -ln(-ln(u))
492                *guard = Some(noise);
493            }
494            guard.as_ref().unwrap().clone()
495        };
496
497        let gumbel_logits = (&log_probs + &gumbel)?;
498        let next_token = gumbel_logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
499
500        // Extract the top‑n log‑probs if the caller asked for them.
501        let (top_logprobs, logprob) = if return_logprobs {
502            let k = self.top_n_logprobs;
503
504            let sorted_values = probs.fast_sort_asc(D::Minus1)?;
505            let topk_values = sorted_values
506                .narrow(
507                    D::Minus1,
508                    sorted_values.dim(D::Minus1)? - top_k as usize,
509                    top_k as usize,
510                )?
511                .to_vec1::<f32>()?;
512
513            let sorted_idxs = probs.fast_argsort_asc(D::Minus1)?;
514            let topk_idxs = sorted_idxs
515                .narrow(
516                    D::Minus1,
517                    sorted_values.dim(D::Minus1)? - top_k as usize,
518                    top_k as usize,
519                )?
520                .to_vec1::<u32>()?;
521
522            let mut result = Vec::with_capacity(k);
523            if let Some(tokenizer) = &self.tokenizer {
524                for (prob, token) in topk_values.iter().zip(topk_idxs) {
525                    let decoded = tokenizer
526                        .decode(&[token], false)
527                        .map_err(|e| Error::Msg(e.to_string()))?;
528                    result.push(TopLogprob {
529                        token,
530                        logprob: prob.log(10.0),
531                        bytes: Some(decoded),
532                    });
533                }
534            } else {
535                for (prob, token) in topk_values.iter().zip(topk_idxs) {
536                    result.push(TopLogprob {
537                        token,
538                        logprob: prob.log(10.0),
539                        bytes: None,
540                    });
541                }
542            }
543
544            let logprob = result.last().map(|res| res.logprob).unwrap_or(1.);
545
546            (Some(result), logprob)
547        } else {
548            (None, 1.)
549        };
550
551        let bytes = if let Some(tokenizer) = &self.tokenizer {
552            Some(
553                tokenizer
554                    .decode(&[next_token], false)
555                    .map_err(|x| Error::Msg(x.to_string()))?,
556            )
557        } else {
558            None
559        };
560
561        Ok(Logprobs {
562            token: next_token,
563            logprob,
564            top_logprobs,
565            bytes,
566        })
567    }
568    fn sample_speculative_top_kp_min_p(
569        &self,
570        logits: Tensor,
571        return_logprobs: bool,
572        top_k: i64,
573        top_p: f32,
574        min_p: f32,
575    ) -> Result<Logprobs> {
576        let mut probs: Vec<f32> = logits.to_vec1()?;
577
578        // Determine how many elements we need for partial sort
579        let k = if top_k > 0 {
580            top_k as usize
581        } else {
582            probs.len()
583        };
584
585        // Get sorted top-k indices with partial sort, zeroing out rest
586        let idx_probs = partial_sort_top_k(&mut probs, k, true);
587
588        // TOP P
589        // top-p sampling (or "nucleus sampling") samples from the smallest set of
590        // tokens that exceed probability top_p. This way we never sample tokens that
591        // have very low probabilities and are less likely to go "off the rails".
592
593        // Clamp smaller probabilities to zero.
594        let mut cumsum = 0.;
595        for (index, prob) in &idx_probs {
596            if cumsum >= top_p {
597                probs[*index as usize] = 0.0;
598            } else {
599                cumsum += prob;
600            }
601        }
602
603        // Get max_p from first sorted element
604        let max_p = idx_probs.first().map(|(_, p)| *p).unwrap_or(0.0);
605
606        // MIN P
607        // min-p sampling samples from the tokens whose prob are greater than
608        // (max prob of token in dist) * min_p
609
610        // Clamp smaller probabilities to zero.
611        let min_p_threshold = max_p * min_p;
612        for (index, prob) in &idx_probs {
613            if min_p_threshold >= *prob {
614                probs[*index as usize] = 0.0;
615            }
616        }
617
618        // Find argmax directly on the Vec (O(n) scan, no Tensor creation)
619        let next_token = argmax_f32(&probs);
620        let logprob = probs[next_token as usize].log(10.0);
621
622        let top_logprobs = if return_logprobs {
623            Some(self.get_top_logprobs(&probs)?)
624        } else {
625            None
626        };
627
628        let bytes = if let Some(tokenizer) = &self.tokenizer {
629            Some(
630                tokenizer
631                    .decode(&[next_token], false)
632                    .map_err(|x| Error::Msg(x.to_string()))?,
633            )
634        } else {
635            None
636        };
637
638        Ok(Logprobs {
639            token: next_token,
640            logprob,
641            top_logprobs,
642            bytes,
643        })
644    }
645
646    fn sample_multinomial(
647        &self,
648        probs: &[f32],
649        return_logprobs: bool,
650        rng: Arc<Mutex<Isaac64Rng>>,
651    ) -> Result<Logprobs> {
652        let distr = WeightedIndex::new(probs).map_err(Error::wrap)?;
653
654        let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
655        let next_token = distr.sample(&mut mut_ref_rng); // "Find the first item which has a weight *higher* than the chosen weight."
656        let logprob = probs[next_token].log(10.0);
657
658        let top_logprobs = if return_logprobs {
659            Some(self.get_top_logprobs(probs)?)
660        } else {
661            None
662        };
663
664        let bytes = if let Some(tokenizer) = &self.tokenizer {
665            Some(
666                tokenizer
667                    .decode(&[next_token.try_into().unwrap()], false)
668                    .map_err(|x| Error::Msg(x.to_string()))?,
669            )
670        } else {
671            None
672        };
673
674        Ok(Logprobs {
675            token: next_token as u32,
676            logprob,
677            top_logprobs,
678            bytes,
679        })
680    }
681
682    #[allow(clippy::too_many_arguments)]
683    fn sample_top_kp_min_p(
684        &self,
685        probs: &mut [f32],
686        top_k: i64,
687        top_p: f32,
688        min_p: f32,
689        return_logprobs: bool,
690        rng: Arc<Mutex<Isaac64Rng>>,
691    ) -> Result<Logprobs> {
692        // Determine how many elements we need for partial sort
693        let k = if top_k > 0 {
694            top_k as usize
695        } else {
696            probs.len()
697        };
698
699        // Get sorted top-k indices with partial sort, zeroing out rest
700        let idx_probs = partial_sort_top_k(probs, k, true);
701
702        if top_p <= 0.0 || top_p >= 1.0 {
703            return self.sample_multinomial(probs, return_logprobs, rng);
704        }
705
706        // TOP P
707
708        // top-p sampling (or "nucleus sampling") samples from the smallest set of
709        // tokens that exceed probability top_p. This way we never sample tokens that
710        // have very low probabilities and are less likely to go "off the rails".
711
712        // Clamp smaller probabilities to zero.
713        let mut cumsum = 0.;
714        for (index, prob) in &idx_probs {
715            if cumsum >= top_p {
716                probs[*index as usize] = 0.0;
717            } else {
718                cumsum += prob;
719            }
720        }
721
722        if min_p <= 0.0 || min_p >= 1.0 {
723            return self.sample_multinomial(probs, return_logprobs, rng);
724        }
725
726        // Get max_p from first sorted element
727        let max_p = idx_probs.first().map(|(_, p)| *p).unwrap_or(0.0);
728
729        // MIN P
730
731        // min-p sampling samples from the tokens whose prob are greater than
732        // (max prob of token in dist) * min_p
733
734        // Clamp smaller probabilities to zero.
735        let min_p_threshold = max_p * min_p;
736        for (index, prob) in &idx_probs {
737            if min_p_threshold >= *prob {
738                probs[*index as usize] = 0.0;
739            }
740        }
741
742        // Sample with clamped probabilities.
743        self.sample_multinomial(probs, return_logprobs, rng)
744    }
745
746    fn apply_penalties(&self, mut logits: Vec<f32>, context: &[u32]) -> Result<Tensor> {
747        if context.is_empty() {
748            candle_core::bail!("Penalty context is empty, this should not happen.");
749        }
750
751        // Dry penalty
752        self.apply_dry_penalty(&mut logits, context)?;
753
754        // Frequency, presence, repetition penalty
755        self.apply_freq_pres_rep_penalty(&mut logits, context)?;
756
757        let vocab_size = logits.len();
758        Tensor::from_vec(logits, vocab_size, &Device::Cpu)
759    }
760
761    fn apply_freq_pres_rep_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
762        if self.frequency_penalty.is_some()
763            || self.presence_penalty.is_some()
764            || self.repetition_penalty.is_some()
765        {
766            let frequency_penalty = self.frequency_penalty.unwrap_or(0.);
767            let presence_penalty = self.presence_penalty.unwrap_or(0.);
768            let repetition_penalty = self.repetition_penalty.unwrap_or(1.);
769
770            //mu[j] -> mu[j] - c[j] * alpha_frequency - float(c[j] > 0) * alpha_presence
771
772            let mut counts = vec![0.0f32; logits.len()];
773            for ctx in context.iter() {
774                // Llama 3.2 uses a hack triggering this error... we wouldn't want a weight on it anyway
775                if *ctx as usize >= logits.len() {
776                    continue;
777                }
778                counts[*ctx as usize] += 1.0;
779            }
780
781            for (token_id, logit) in logits.iter_mut().enumerate() {
782                let count = counts[token_id];
783                *logit = *logit
784                    - count * frequency_penalty
785                    - if count > 0.0 { 1. } else { 0. } * presence_penalty;
786
787                if repetition_penalty != 1.0 && count > 0.0 {
788                    if *logit > 0.0 {
789                        *logit /= repetition_penalty;
790                    } else {
791                        *logit *= repetition_penalty;
792                    }
793                }
794            }
795        }
796        Ok(())
797    }
798
799    /// Threshold for using parallel iteration in dry penalty.
800    /// Below this, sequential is faster due to parallel overhead.
801    const DRY_PENALTY_PAR_THRESHOLD: usize = 1024;
802
803    fn apply_dry_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
804        if let Some(ref params) = self.dry_params {
805            if params.multiplier == 0. {
806                return Ok(());
807            }
808
809            let last_token = *context.last().unwrap();
810
811            // Use parallel iteration only for large contexts
812            let match_indices: Vec<usize> = if context.len() > Self::DRY_PENALTY_PAR_THRESHOLD {
813                context
814                    .par_iter()
815                    .enumerate()
816                    .take(context.len() - 1)
817                    .filter(|(_i, x)| last_token == **x)
818                    .map(|(i, _)| i)
819                    .collect()
820            } else {
821                context
822                    .iter()
823                    .enumerate()
824                    .take(context.len() - 1)
825                    .filter(|(_i, x)| last_token == **x)
826                    .map(|(i, _)| i)
827                    .collect()
828            };
829
830            let mut match_lengths = HashMap::new();
831
832            for i in match_indices {
833                let next_token = context[i + 1];
834
835                if params.sequence_breakers.contains(&next_token) {
836                    continue;
837                }
838
839                let mut match_length = 1;
840
841                // Limit match length to avoid quadratic runtime and potential DoS with adversarial inputs.
842                while match_length < 50 {
843                    if match_length > i {
844                        // Start of input
845                        break;
846                    }
847
848                    let j = i - match_length;
849
850                    let prev_tok = context[context.len() - (match_length + 1)];
851                    if context[j] != prev_tok {
852                        // Start of match reached
853                        break;
854                    }
855
856                    if params.sequence_breakers.contains(&prev_tok) {
857                        // Seq breaking tok reached
858                        break;
859                    }
860
861                    match_length += 1;
862                }
863
864                #[allow(clippy::map_entry)]
865                if match_lengths.contains_key(&next_token) {
866                    match_lengths.insert(next_token, match_length.max(match_lengths[&next_token]));
867                } else {
868                    match_lengths.insert(next_token, match_length);
869                }
870            }
871
872            // Actually apply penalties
873            for (tok, match_len) in match_lengths {
874                if match_len >= params.allowed_length {
875                    // Llama 3.2 uses a hack triggering this error... we wouldn't want a weight on it anyway
876                    if tok as usize >= logits.len() {
877                        continue;
878                    }
879                    let penalty = params.multiplier
880                        * params.base.powf((match_len - params.allowed_length) as f32);
881                    logits[tok as usize] -= penalty;
882                }
883            }
884        }
885        Ok(())
886    }
887
888    #[allow(unused)]
889    /// Sample the provided tokens.
890    ///
891    /// If the temperature is `None`, argmax sampling is used. Otherwise, the selected sampling is used.
892    /// With `top-p` sampling, if the `top-p` value is `<= 0.0` or `>= 1.0`, multinomial sampling is used.
893    pub fn sample(
894        &self,
895        logits: Tensor,
896        context: &[u32],
897        return_logprobs: bool,
898        rng: Arc<Mutex<Isaac64Rng>>,
899        sample_speculative: bool,
900        multiple_sequences: bool,
901    ) -> Result<Logprobs> {
902        // if cfg!(feature = "metal") && !multiple_sequences {
903        //     return self.sample_fast(
904        //         logits,
905        //         context,
906        //         return_logprobs,
907        //         self.top_k,
908        //         self.top_p,
909        //         self.min_p,
910        //     );
911        // }
912
913        let logits = logits.to_vec1()?;
914        let mut logits = self.apply_penalties(logits, context)?;
915        for processor in &self.logits_processors {
916            logits = processor.apply(&logits, context)?;
917        }
918        let next_token = if sample_speculative {
919            match self.temperature {
920                None => self.sample_speculative_top_kp_min_p(
921                    logits,
922                    return_logprobs,
923                    self.top_k,
924                    self.top_p as f32,
925                    self.min_p as f32,
926                )?,
927                Some(temperature) => {
928                    let logits = (&logits / temperature)?;
929                    let probs = candle_nn::ops::softmax_last_dim(&logits)?;
930
931                    self.sample_speculative_top_kp_min_p(
932                        probs,
933                        return_logprobs,
934                        self.top_k,
935                        self.top_p as f32,
936                        self.min_p as f32,
937                    )?
938                }
939            }
940        } else {
941            match self.temperature {
942                None => self.sample_argmax(logits, return_logprobs)?,
943                Some(temperature) => {
944                    let logits = (&logits / temperature)?;
945                    let probs = candle_nn::ops::softmax_last_dim(&logits)?;
946                    let mut probs: Vec<f32> = probs.to_vec1()?;
947
948                    self.sample_top_kp_min_p(
949                        &mut probs,
950                        self.top_k,
951                        self.top_p as f32,
952                        self.min_p as f32,
953                        return_logprobs,
954                        rng,
955                    )?
956                }
957            }
958        };
959        Ok(next_token)
960    }
961}
962
// Gated on `cfg(test)` so the module (and its helper) is only compiled for test builds.
#[cfg(test)]
mod tests {
    use super::Sampler;
    use candle_core::{Device, Tensor};
    use rand::SeedableRng;
    use rand_isaac::Isaac64Rng;
    use std::sync::{Arc, Mutex};

    /// Shared scenario for both tests: sample once from the ascending logits
    /// `0.0, 1.0, ..., 1023.0` and check that the last (largest) entry wins.
    ///
    /// `sample_speculative` is forwarded to `Sampler::sample` to select which sampling
    /// path is exercised. `#[track_caller]` attributes assertion failures to the calling
    /// test rather than to this helper.
    #[track_caller]
    fn assert_picks_largest_logit(sample_speculative: bool) {
        // NOTE(review): the first argument is presumably the temperature (`None` here);
        // confirm against `Sampler::new`.
        let sampler = Sampler::new(
            None,
            10,
            None,
            None,
            None,
            None,
            None,
            32,
            0.1,
            0.05,
            vec![],
        )
        .unwrap();
        let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
        // Fixed seed keeps the run reproducible.
        let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
        let res = sampler
            .sample(
                logits,
                &(0..1024).collect::<Vec<_>>(),
                false,
                rng,
                sample_speculative,
                false,
            )
            .unwrap();

        assert_eq!(res.token, 1023);
        // Log-probs were not requested (`return_logprobs == false`).
        assert_eq!(res.top_logprobs, None);
        // Expected value is the base-10 log of the winning logit.
        assert_eq!(res.logprob, 1023f64.log(10.) as f32);
    }

    #[test]
    fn test_argmax() {
        assert_picks_largest_logit(false);
    }

    #[test]
    fn test_gumbel_speculative() {
        assert_picks_largest_logit(true);
    }
}