mistralrs_core/
sampler.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{
4    collections::{HashMap, HashSet},
5    sync::{Arc, Mutex},
6};
7
8use candle_core::{DType, Device, Error, Result, Tensor, D};
9use mistralrs_quant::{CumSumOp, SortOp};
10#[cfg(feature = "pyo3_macros")]
11use pyo3::pyclass;
12
13use once_cell::sync::Lazy;
14use rand::distr::{weighted::WeightedIndex, Distribution};
15use rand_isaac::Isaac64Rng;
16use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
17use serde::{Deserialize, Serialize};
18use tokenizers::Tokenizer;
19
20static DRY_SEQUENCE_BREAKERS: Lazy<Vec<String>> =
21    Lazy::new(|| ["\n", ":", "\"", "*"].map(String::from).to_vec());
22
23#[derive(Clone, Debug, Serialize, Deserialize)]
24/// Stop sequences or ids.
25pub enum StopTokens {
26    Seqs(Vec<String>),
27    Ids(Vec<u32>),
28}
29
30#[derive(Clone, Debug, Serialize, Deserialize)]
31/// Sampling params are used to control sampling.
32pub struct SamplingParams {
33    pub temperature: Option<f64>,
34    pub top_k: Option<usize>,
35    pub top_p: Option<f64>,
36    pub min_p: Option<f64>,
37    pub top_n_logprobs: usize,
38    pub frequency_penalty: Option<f32>,
39    pub presence_penalty: Option<f32>,
40    pub repetition_penalty: Option<f32>,
41    pub stop_toks: Option<StopTokens>,
42    pub max_len: Option<usize>,
43    pub logits_bias: Option<HashMap<u32, f32>>,
44    pub n_choices: usize,
45    pub dry_params: Option<DrySamplingParams>,
46}
47
48impl SamplingParams {
49    /// This sets up the parameters so that there is:
50    /// - No temperature, topk, topp, minp
51    /// - No penalties, stop tokens, or logit bias
52    /// - No maximum length
53    pub fn deterministic() -> Self {
54        Self {
55            temperature: None,
56            top_k: Some(1),
57            top_p: None,
58            min_p: None,
59            top_n_logprobs: 0,
60            frequency_penalty: None,
61            presence_penalty: None,
62            repetition_penalty: None,
63            stop_toks: None,
64            max_len: None,
65            logits_bias: None,
66            n_choices: 1,
67            dry_params: None,
68        }
69    }
70}
71
72#[derive(Clone, Debug, Serialize, Deserialize)]
73pub struct DrySamplingParams {
74    pub sequence_breakers: Vec<String>,
75    pub multiplier: f32,
76    pub base: f32,
77    pub allowed_length: usize,
78}
79
80impl DrySamplingParams {
81    pub fn new_with_defaults(
82        multiplier: f32,
83        sequence_breakers: Option<Vec<String>>,
84        base: Option<f32>,
85        allowed_length: Option<usize>,
86    ) -> anyhow::Result<Self> {
87        Ok(Self {
88            base: base.unwrap_or(1.75),
89            allowed_length: allowed_length.unwrap_or(2),
90            sequence_breakers: sequence_breakers.unwrap_or(DRY_SEQUENCE_BREAKERS.clone()),
91            multiplier,
92        })
93    }
94}
95
96impl Default for DrySamplingParams {
97    fn default() -> Self {
98        Self {
99            multiplier: 0.0,
100            base: 1.75,
101            allowed_length: 2,
102            sequence_breakers: DRY_SEQUENCE_BREAKERS.clone(),
103        }
104    }
105}
106
/// DRY sampling parameters with the sequence breakers resolved to token ids.
#[derive(Clone, Debug)]
struct DrySamplingParamsInner {
    /// Token ids that interrupt a repeated-sequence match.
    pub sequence_breakers: HashSet<u32>,
    /// Scale of the penalty; `0` disables it.
    pub multiplier: f32,
    /// Base of the exponential penalty growth.
    pub base: f32,
    /// Match length from which the penalty starts to apply.
    pub allowed_length: usize,
}
114
115impl DrySamplingParamsInner {
116    pub fn from(other: DrySamplingParams, tokenizer: &Tokenizer) -> anyhow::Result<Self> {
117        Ok(Self {
118            base: other.base,
119            allowed_length: other.allowed_length,
120            sequence_breakers: HashSet::from_iter(
121                other
122                    .sequence_breakers
123                    .into_iter()
124                    .map(|breaker| {
125                        tokenizer
126                            // Prefix with 'a' to get the correct encoding of the token at the end of a text.
127                            //
128                            // FIXME: This is a hack. See https://github.com/LostRuins/koboldcpp/pull/982
129                            //        for the correct solution which covers multi-token sequence breakers
130                            //        and ambiguous encodings.
131                            .encode_fast(["a", &breaker].concat(), true)
132                            .map_err(anyhow::Error::msg)
133                            .map(|enc| {
134                                let ids = enc.get_ids();
135                                if !ids.is_empty() {
136                                    None
137                                } else {
138                                    Some(ids[ids.len() - 1])
139                                }
140                            })
141                    })
142                    .collect::<anyhow::Result<Vec<_>>>()?
143                    .into_iter()
144                    .flatten()
145                    .collect::<Vec<_>>(),
146            ),
147            multiplier: other.multiplier,
148        })
149    }
150}
151
152/// Customizable logits processor.
153///
154/// # Example
155/// ```rust
156/// use std::{sync::Arc, ops::Mul};
157/// use mistralrs_core::CustomLogitsProcessor;
158/// use candle_core::{Result, Tensor};
159///
160/// struct ThresholdLogitsProcessor;
161/// impl CustomLogitsProcessor for ThresholdLogitsProcessor {
162///     fn apply(&self, logits: &Tensor, _context: &[u32]) -> Result<Tensor> {
163///         // Mask is 1 for true, 0 for false.
164///         let mask = logits.ge(0.5)?;
165///         logits.broadcast_mul(&mask.to_dtype(logits.dtype())?)
166///     }
167/// }
168/// let processor1: Arc<dyn CustomLogitsProcessor> = Arc::new(|logits: &Tensor, _context: &[u32]| logits * 1.23);
169/// let processor2: Arc<dyn CustomLogitsProcessor> = Arc::new(ThresholdLogitsProcessor);
170/// ```
171pub trait CustomLogitsProcessor: Send + Sync {
172    /// Logits and sequence context (prompt and generated tokens), returning modified tokens.
173    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor>;
174}
175
176impl<T: Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync> CustomLogitsProcessor for T {
177    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor> {
178        self(logits, context)
179    }
180}
181
182/// Sampler for sampling.
183#[derive(Clone)]
184pub struct Sampler {
185    temperature: Option<f64>,
186    top_n_logprobs: usize,
187    tokenizer: Option<Arc<Tokenizer>>,
188    frequency_penalty: Option<f32>,
189    presence_penalty: Option<f32>,
190    repetition_penalty: Option<f32>,
191    dry_params: Option<DrySamplingParamsInner>,
192    top_k: i64,
193    top_p: f64,
194    min_p: f64,
195    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
196    /// Cached Gumbel noise tensor to avoid reallocating it.
197    gumbel_cache: Arc<Mutex<Option<Tensor>>>,
198}
199
200#[cfg_attr(feature = "pyo3_macros", pyclass)]
201#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
202#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
203/// Top-n logprobs element
204pub struct TopLogprob {
205    pub token: u32,
206    pub logprob: f32,
207    pub bytes: Option<String>,
208}
209
210#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct Logprobs {
212    pub token: u32,
213    pub logprob: f32,
214    pub bytes: Option<String>,
215    pub top_logprobs: Option<Vec<TopLogprob>>,
216}
217
218fn argmax_sample_last_dim(logits: &Tensor) -> Result<Tensor> {
219    logits.argmax(D::Minus1)
220}
221
222impl Sampler {
223    #[allow(clippy::too_many_arguments)]
224    pub fn new(
225        temperature: Option<f64>,
226        top_n_logprobs: usize,
227        tokenizer: Option<Arc<Tokenizer>>,
228        frequency_penalty: Option<f32>,
229        presence_penalty: Option<f32>,
230        repetition_penalty: Option<f32>,
231        dry_params: Option<DrySamplingParams>,
232        top_k: i64,
233        top_p: f64,
234        min_p: f64,
235        logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
236    ) -> anyhow::Result<Self> {
237        let temperature = if temperature.is_none_or(|v| v < 1e-7) {
238            None
239        } else {
240            temperature
241        };
242        let dry_params = if let Some(ref tokenizer) = tokenizer {
243            dry_params.map(|params| DrySamplingParamsInner::from(params, tokenizer))
244        } else {
245            None
246        };
247        let dry_params = match dry_params {
248            Some(fallible) => Some(fallible?),
249            None => None,
250        };
251        Ok(Self {
252            temperature,
253            top_n_logprobs,
254            tokenizer,
255            frequency_penalty,
256            presence_penalty,
257            repetition_penalty,
258            dry_params,
259            top_k,
260            top_p,
261            min_p,
262            logits_processors,
263            gumbel_cache: Arc::new(Mutex::new(None)),
264        })
265    }
266
267    fn get_top_logprobs(&self, probs: &[f32], _argsort_indices: &[u32]) -> Result<Vec<TopLogprob>> {
268        // Fast top-k selection without sorting the entire vocabulary
269        let k = self.top_n_logprobs.min(probs.len());
270        if k == 0 {
271            return Ok(Vec::new());
272        }
273        // Build (token, probability) pairs
274        let mut idx_probs: Vec<(u32, f32)> = (0..probs.len() as u32)
275            .map(|i| (i, probs[i as usize]))
276            .collect();
277        // Partition so that the top k probabilities are in the first k positions
278        let (top_k_slice, _, _) = idx_probs.select_nth_unstable_by(k, |a, b| {
279            b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
280        });
281        // Copy and sort only the top k elements by descending probability
282        let mut top_k: Vec<(u32, f32)> = top_k_slice.to_vec();
283        top_k.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
284        // Build the result vector with log10 of probabilities and optional decoding
285        let mut result = Vec::with_capacity(k);
286        if let Some(tokenizer) = &self.tokenizer {
287            for (token, prob) in top_k {
288                let decoded = tokenizer
289                    .decode(&[token], false)
290                    .map_err(|e| Error::Msg(e.to_string()))?;
291                result.push(TopLogprob {
292                    token,
293                    logprob: prob.log(10.0),
294                    bytes: Some(decoded),
295                });
296            }
297        } else {
298            for (token, prob) in top_k {
299                result.push(TopLogprob {
300                    token,
301                    logprob: prob.log(10.0),
302                    bytes: None,
303                });
304            }
305        }
306        Ok(result)
307    }
308
309    fn sample_argmax(&self, logits: Tensor, return_logprobs: bool) -> Result<Logprobs> {
310        let next_token = logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
311
312        let probs: Vec<f32> = logits.to_vec1()?;
313
314        let argsort_indices = (0..probs.len() as u32).collect::<Vec<_>>();
315        let logprob = probs[next_token as usize].log(10.0);
316
317        let top_logprobs = if return_logprobs {
318            Some(self.get_top_logprobs(&probs, &argsort_indices)?)
319        } else {
320            None
321        };
322
323        let bytes = if let Some(tokenizer) = &self.tokenizer {
324            Some(
325                tokenizer
326                    .decode(&[next_token], false)
327                    .map_err(|x| Error::Msg(x.to_string()))?,
328            )
329        } else {
330            None
331        };
332
333        Ok(Logprobs {
334            token: next_token,
335            logprob,
336            top_logprobs,
337            bytes,
338        })
339    }
340
341    #[allow(unused)]
342    fn sample_fast(
343        &self,
344        logits: Tensor,
345        context: &[u32],
346        return_logprobs: bool,
347        top_k: i64,
348        top_p: f64,
349        min_p: f64,
350    ) -> Result<Logprobs> {
351        let mut probs = logits.to_dtype(DType::F32)?;
352
353        for processor in &self.logits_processors {
354            probs = processor.apply(&probs, context)?;
355        }
356
357        let context = Tensor::new(context, logits.device())?;
358        let mut counts = logits.zeros_like()?;
359        counts = counts.scatter_add(
360            &context,
361            &context.ones_like()?.to_dtype(counts.dtype())?,
362            D::Minus1,
363        )?;
364
365        let presence = counts
366            .gt(0.)?
367            .where_cond(&counts.ones_like()?, &counts.zeros_like()?)?;
368
369        match self.frequency_penalty {
370            Some(freq_penalty) if freq_penalty != 0. => {
371                probs = (probs - (freq_penalty as f64 * counts)?)?;
372            }
373            _ => (),
374        }
375
376        match self.presence_penalty {
377            Some(pres_penalty) if pres_penalty != 0. => {
378                probs = (probs - (pres_penalty as f64 * &presence)?)?;
379            }
380            _ => (),
381        }
382
383        match self.repetition_penalty {
384            Some(rep_penalty) if rep_penalty != 1. => {
385                let pos_mask = probs.gt(0.)?;
386                let scaled_pos = (&probs / (rep_penalty as f64))?;
387                let scaled_neg = (&probs * (rep_penalty as f64))?;
388                let modified = pos_mask.where_cond(&scaled_pos, &scaled_neg)?;
389
390                let pres_mask = presence.gt(0.)?;
391                probs = pres_mask.where_cond(&modified, &probs)?;
392            }
393            _ => (),
394        }
395
396        probs = candle_nn::ops::softmax_last_dim(&(probs / self.temperature.unwrap_or(1.))?)?;
397
398        // Top-K
399        if top_k > 0 {
400            let sorted_values = probs.fast_sort_asc(D::Minus1)?;
401            let topk_values = sorted_values.narrow(
402                D::Minus1,
403                sorted_values.dim(D::Minus1)? - top_k as usize,
404                top_k as usize,
405            )?;
406
407            // select the kth largest value as threshold
408            let threshold = topk_values.get_on_dim(D::Minus1, 0)?.unsqueeze(0)?;
409            let mask_topk = probs.broadcast_ge(&threshold)?;
410            probs = mask_topk.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
411        }
412
413        // Top-P (nucleus)
414        if top_p > 0.0 && top_p < 1.0 {
415            let sorted_probs = probs.fast_sort_asc(D::Minus1)?;
416
417            let cumsum = sorted_probs.fast_cumsum(D::Minus1)?;
418
419            let mask_topp = cumsum.le(top_p)?;
420
421            let masked_sorted =
422                mask_topp.where_cond(&sorted_probs, &Tensor::zeros_like(&sorted_probs)?)?;
423
424            let threshold = masked_sorted.max(D::Minus1)?;
425            let threshold = threshold.unsqueeze(D::Minus1)?;
426            let mask_full = probs.broadcast_ge(&threshold)?;
427            probs = mask_full.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
428        }
429
430        // Min-P
431        if min_p > 0.0 && min_p < 1.0 {
432            let max_vals = probs.max(D::Minus1)?;
433            let threshold_min = (max_vals.unsqueeze(D::Minus1)? * min_p)?;
434            let mask_minp = probs.broadcast_gt(&threshold_min)?;
435            probs = mask_minp.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
436        }
437
438        // Sample using the Gumbel-max trick fully on-device.
439        let log_probs = probs.log()?;
440        // Generate cached Gumbel noise (-log(-log(u))) once.
441        let gumbel = {
442            let mut guard = self.gumbel_cache.lock().unwrap();
443            if guard.is_none() {
444                let uniform = Tensor::rand(0f32, 1f32, log_probs.shape(), log_probs.device())?;
445                let noise = uniform
446                    .clamp(1e-20, 1.0)?
447                    .log()? // ln(u)
448                    .neg()? // -ln(u)
449                    .log()? // ln(-ln(u))
450                    .neg()?; // -ln(-ln(u))
451                *guard = Some(noise);
452            }
453            guard.as_ref().unwrap().clone()
454        };
455
456        let gumbel_logits = (&log_probs + &gumbel)?;
457        let next_token = gumbel_logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
458
459        // Extract the top‑n log‑probs if the caller asked for them.
460        let (top_logprobs, logprob) = if return_logprobs {
461            let k = self.top_n_logprobs;
462
463            let sorted_values = probs.fast_sort_asc(D::Minus1)?;
464            let topk_values = sorted_values
465                .narrow(
466                    D::Minus1,
467                    sorted_values.dim(D::Minus1)? - top_k as usize,
468                    top_k as usize,
469                )?
470                .to_vec1::<f32>()?;
471
472            let sorted_idxs = probs.fast_argsort_asc(D::Minus1)?;
473            let topk_idxs = sorted_idxs
474                .narrow(
475                    D::Minus1,
476                    sorted_values.dim(D::Minus1)? - top_k as usize,
477                    top_k as usize,
478                )?
479                .to_vec1::<u32>()?;
480
481            let mut result = Vec::with_capacity(k);
482            if let Some(tokenizer) = &self.tokenizer {
483                for (prob, token) in topk_values.iter().zip(topk_idxs) {
484                    let decoded = tokenizer
485                        .decode(&[token], false)
486                        .map_err(|e| Error::Msg(e.to_string()))?;
487                    result.push(TopLogprob {
488                        token,
489                        logprob: prob.log(10.0),
490                        bytes: Some(decoded),
491                    });
492                }
493            } else {
494                for (prob, token) in topk_values.iter().zip(topk_idxs) {
495                    result.push(TopLogprob {
496                        token,
497                        logprob: prob.log(10.0),
498                        bytes: None,
499                    });
500                }
501            }
502
503            let logprob = result.last().map(|res| res.logprob).unwrap_or(1.);
504
505            (Some(result), logprob)
506        } else {
507            (None, 1.)
508        };
509
510        let bytes = if let Some(tokenizer) = &self.tokenizer {
511            Some(
512                tokenizer
513                    .decode(&[next_token], false)
514                    .map_err(|x| Error::Msg(x.to_string()))?,
515            )
516        } else {
517            None
518        };
519
520        Ok(Logprobs {
521            token: next_token,
522            logprob,
523            top_logprobs,
524            bytes,
525        })
526    }
527    fn sample_speculative_top_kp_min_p(
528        &self,
529        logits: Tensor,
530        return_logprobs: bool,
531        top_k: i64,
532        top_p: f32,
533        min_p: f32,
534    ) -> Result<Logprobs> {
535        let mut probs: Vec<f32> = logits.to_vec1()?;
536        let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
537
538        if top_k > 0 {
539            // Clamp smaller probabilities to zero.
540            for (index, val) in argsort_indices.iter().enumerate() {
541                if index >= top_k as usize {
542                    probs[*val as usize] = 0.0;
543                }
544            }
545        }
546
547        // TOP P
548
549        // top-p sampling (or "nucleus sampling") samples from the smallest set of
550        // tokens that exceed probability top_p. This way we never sample tokens that
551        // have very low probabilities and are less likely to go "off the rails".
552
553        // Clamp smaller probabilities to zero.
554        let mut cumsum = 0.;
555        for index in &argsort_indices {
556            if cumsum >= top_p {
557                probs[*index as usize] = 0.0;
558            } else {
559                cumsum += probs[*index as usize];
560            }
561        }
562
563        let max_p = probs[argsort_indices[0] as usize];
564
565        // MIN P
566
567        // min-p sampling samples from the tokens whose prob are greater than
568        // (max prob of token in dist) * min_p
569
570        // Clamp smaller probabilities to zero.
571        for index in &argsort_indices {
572            if max_p * min_p >= probs[*index as usize] {
573                probs[*index as usize] = 0.0;
574            }
575        }
576
577        let logits = Tensor::from_slice(&probs, logits.shape(), &Device::Cpu)?;
578
579        let next_token = argmax_sample_last_dim(&logits)?.to_scalar::<u32>()?;
580
581        let logprob = probs[next_token as usize].log(10.0);
582
583        let top_logprobs = if return_logprobs {
584            Some(self.get_top_logprobs(&probs, &argsort_indices)?)
585        } else {
586            None
587        };
588
589        let bytes = if let Some(tokenizer) = &self.tokenizer {
590            Some(
591                tokenizer
592                    .decode(&[next_token], false)
593                    .map_err(|x| Error::Msg(x.to_string()))?,
594            )
595        } else {
596            None
597        };
598
599        Ok(Logprobs {
600            token: next_token,
601            logprob,
602            top_logprobs,
603            bytes,
604        })
605    }
606
607    fn sample_multinomial(
608        &self,
609        probs: &mut Vec<f32>,
610        argsort_indices: Vec<u32>,
611        return_logprobs: bool,
612        rng: Arc<Mutex<Isaac64Rng>>,
613    ) -> Result<Logprobs> {
614        let distr = WeightedIndex::new(&*probs).map_err(Error::wrap)?;
615
616        let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
617        let next_token = distr.sample(&mut mut_ref_rng); // "Find the first item which has a weight *higher* than the chosen weight."
618        let logprob = probs[next_token].log(10.0);
619
620        let top_logprobs = if return_logprobs {
621            Some(self.get_top_logprobs(probs, &argsort_indices)?)
622        } else {
623            None
624        };
625
626        let bytes = if let Some(tokenizer) = &self.tokenizer {
627            Some(
628                tokenizer
629                    .decode(&[next_token.try_into().unwrap()], false)
630                    .map_err(|x| Error::Msg(x.to_string()))?,
631            )
632        } else {
633            None
634        };
635
636        Ok(Logprobs {
637            token: next_token as u32,
638            logprob,
639            top_logprobs,
640            bytes,
641        })
642    }
643
644    #[allow(clippy::too_many_arguments)]
645    fn sample_top_kp_min_p(
646        &self,
647        probs: &mut Vec<f32>,
648        logits: &Tensor,
649        top_k: i64,
650        top_p: f32,
651        min_p: f32,
652        return_logprobs: bool,
653        rng: Arc<Mutex<Isaac64Rng>>,
654    ) -> Result<Logprobs> {
655        let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
656
657        if top_k > 0 {
658            // Clamp smaller probabilities to zero.
659            for (index, val) in argsort_indices.iter().enumerate() {
660                if index >= top_k as usize {
661                    probs[*val as usize] = 0.0;
662                }
663            }
664        }
665
666        if top_p <= 0.0 || top_p >= 1.0 {
667            return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
668        }
669
670        // TOP P
671
672        // top-p sampling (or "nucleus sampling") samples from the smallest set of
673        // tokens that exceed probability top_p. This way we never sample tokens that
674        // have very low probabilities and are less likely to go "off the rails".
675
676        // Clamp smaller probabilities to zero.
677        let mut cumsum = 0.;
678        for index in &argsort_indices {
679            if cumsum >= top_p {
680                probs[*index as usize] = 0.0;
681            } else {
682                cumsum += probs[*index as usize];
683            }
684        }
685
686        if min_p <= 0.0 || min_p >= 1.0 {
687            return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
688        }
689
690        let max_p = probs[argsort_indices[0] as usize];
691
692        // MIN P
693
694        // min-p sampling samples from the tokens whose prob are greater than
695        // (max prob of token in dist) * min_p
696
697        // Clamp smaller probabilities to zero.
698        for index in &argsort_indices {
699            if max_p * min_p >= probs[*index as usize] {
700                probs[*index as usize] = 0.0;
701            }
702        }
703
704        // Sample with clamped probabilities.
705        self.sample_multinomial(probs, argsort_indices, return_logprobs, rng)
706    }
707
708    fn apply_penalties(&self, mut logits: Vec<f32>, context: &[u32]) -> Result<Tensor> {
709        if context.is_empty() {
710            candle_core::bail!("Penalty context is empty, this should not happen.");
711        }
712
713        // Dry penalty
714        self.apply_dry_penalty(&mut logits, context)?;
715
716        // Frequency, presence, repetition penalty
717        self.apply_freq_pres_rep_penalty(&mut logits, context)?;
718
719        let vocab_size = logits.len();
720        Tensor::from_vec(logits, vocab_size, &Device::Cpu)
721    }
722
723    fn apply_freq_pres_rep_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
724        if self.frequency_penalty.is_some()
725            || self.presence_penalty.is_some()
726            || self.repetition_penalty.is_some()
727        {
728            let frequency_penalty = self.frequency_penalty.unwrap_or(0.);
729            let presence_penalty = self.presence_penalty.unwrap_or(0.);
730            let repetition_penalty = self.repetition_penalty.unwrap_or(1.);
731
732            //mu[j] -> mu[j] - c[j] * alpha_frequency - float(c[j] > 0) * alpha_presence
733
734            let mut counts = vec![0.0f32; logits.len()];
735            for ctx in context.iter() {
736                // Llama 3.2 uses a hack triggering this error... we wouldn't want a weight on it anyway
737                if *ctx as usize >= logits.len() {
738                    continue;
739                }
740                counts[*ctx as usize] += 1.0;
741            }
742
743            for (token_id, logit) in logits.iter_mut().enumerate() {
744                let count = counts[token_id];
745                *logit = *logit
746                    - count * frequency_penalty
747                    - if count > 0.0 { 1. } else { 0. } * presence_penalty;
748
749                if repetition_penalty != 1.0 && count > 0.0 {
750                    if *logit > 0.0 {
751                        *logit /= repetition_penalty;
752                    } else {
753                        *logit *= repetition_penalty;
754                    }
755                }
756            }
757        }
758        Ok(())
759    }
760
761    fn apply_dry_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
762        if let Some(ref params) = self.dry_params {
763            if params.multiplier == 0. {
764                return Ok(());
765            }
766
767            let match_indices = context
768                .par_iter()
769                .enumerate()
770                .take(context.len() - 1)
771                .filter(|(_i, x)| *context.last().unwrap() == **x)
772                .map(|(i, _)| i)
773                .collect::<Vec<_>>();
774
775            let mut match_lengths = HashMap::new();
776
777            for i in match_indices {
778                let next_token = context[i + 1];
779
780                if params.sequence_breakers.contains(&next_token) {
781                    continue;
782                }
783
784                let mut match_length = 1;
785
786                // Limit match length to avoid quadratic runtime and potential DoS with adversarial inputs.
787                while match_length < 50 {
788                    if match_length > i {
789                        // Start of input
790                        break;
791                    }
792
793                    let j = i - match_length;
794
795                    let prev_tok = context[context.len() - (match_length + 1)];
796                    if context[j] != prev_tok {
797                        // Start of match reached
798                        break;
799                    }
800
801                    if params.sequence_breakers.contains(&prev_tok) {
802                        // Seq breaking tok reached
803                        break;
804                    }
805
806                    match_length += 1;
807                }
808
809                #[allow(clippy::map_entry)]
810                if match_lengths.contains_key(&next_token) {
811                    match_lengths.insert(next_token, match_length.max(match_lengths[&next_token]));
812                } else {
813                    match_lengths.insert(next_token, match_length);
814                }
815            }
816
817            // Actually apply penalties
818            for (tok, match_len) in match_lengths {
819                if match_len >= params.allowed_length {
820                    // Llama 3.2 uses a hack triggering this error... we wouldn't want a weight on it anyway
821                    if tok as usize >= logits.len() {
822                        continue;
823                    }
824                    let penalty = params.multiplier
825                        * params.base.powf((match_len - params.allowed_length) as f32);
826                    logits[tok as usize] -= penalty;
827                }
828            }
829        }
830        Ok(())
831    }
832
833    #[allow(unused)]
834    /// Sample the provided tokens.
835    ///
836    /// If the temperature is `None`, argmax sampling is used. Otherwise, the selected sampling is used.
837    /// With `top-p` sampling, if the `top-p` value is `<= 0.0` or `>= 1.0`, multinomial sampling is used.
838    pub fn sample(
839        &self,
840        logits: Tensor,
841        context: &[u32],
842        return_logprobs: bool,
843        rng: Arc<Mutex<Isaac64Rng>>,
844        sample_speculative: bool,
845        multiple_sequences: bool,
846    ) -> Result<Logprobs> {
847        // if cfg!(feature = "metal") && !multiple_sequences {
848        //     return self.sample_fast(
849        //         logits,
850        //         context,
851        //         return_logprobs,
852        //         self.top_k,
853        //         self.top_p,
854        //         self.min_p,
855        //     );
856        // }
857
858        let logits = logits.to_vec1()?;
859        let mut logits = self.apply_penalties(logits, context)?;
860        for processor in &self.logits_processors {
861            logits = processor.apply(&logits, context)?;
862        }
863        let next_token = if sample_speculative {
864            match self.temperature {
865                None => self.sample_speculative_top_kp_min_p(
866                    logits,
867                    return_logprobs,
868                    self.top_k,
869                    self.top_p as f32,
870                    self.min_p as f32,
871                )?,
872                Some(temperature) => {
873                    let logits = (&logits / temperature)?;
874                    let probs = candle_nn::ops::softmax_last_dim(&logits)?;
875
876                    self.sample_speculative_top_kp_min_p(
877                        probs,
878                        return_logprobs,
879                        self.top_k,
880                        self.top_p as f32,
881                        self.min_p as f32,
882                    )?
883                }
884            }
885        } else {
886            match self.temperature {
887                None => self.sample_argmax(logits, return_logprobs)?,
888                Some(temperature) => {
889                    let logits = (&logits / temperature)?;
890                    let logits = candle_nn::ops::softmax_last_dim(&logits)?;
891                    let mut probs: Vec<f32> = logits.to_vec1()?;
892
893                    self.sample_top_kp_min_p(
894                        &mut probs,
895                        &logits,
896                        self.top_k,
897                        self.top_p as f32,
898                        self.min_p as f32,
899                        return_logprobs,
900                        rng,
901                    )?
902                }
903            }
904        };
905        Ok(next_token)
906    }
907}
908
#[cfg(test)]
mod tests {
    use super::{Logprobs, Sampler};
    use candle_core::{Device, Tensor};
    use rand::SeedableRng;
    use rand_isaac::Isaac64Rng;
    use std::sync::{Arc, Mutex};

    /// Samples once from the logits `0, 1, ..., 1023` with a temperature-less sampler
    /// (top-k 32, top-p 0.1, min-p 0.05), without requesting log-probabilities.
    fn sample_ramp(sample_speculative: bool) -> Logprobs {
        let sampler = Sampler::new(
            None,
            10,
            None,
            None,
            None,
            None,
            None,
            32,
            0.1,
            0.05,
            vec![],
        )
        .unwrap();
        let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
        let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
        sampler
            .sample(
                logits,
                &(0..1024).collect::<Vec<_>>(),
                false,
                rng,
                sample_speculative,
                false,
            )
            .unwrap()
    }

    #[test]
    fn test_argmax() {
        let res = sample_ramp(false);
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }

    #[test]
    fn test_gumbel_speculative() {
        let res = sample_ramp(true);
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }
}
989}