mistralrs_core/
sampler.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{
4    collections::{HashMap, HashSet},
5    sync::{Arc, LazyLock, Mutex},
6};
7
8use candle_core::{DType, Device, Error, Result, Tensor, D};
9use mistralrs_quant::{CumSumOp, SortOp};
10#[cfg(feature = "pyo3_macros")]
11use pyo3::pyclass;
12
13use rand::distr::{weighted::WeightedIndex, Distribution};
14use rand_isaac::Isaac64Rng;
15use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
16use serde::{Deserialize, Serialize};
17use tokenizers::Tokenizer;
18
/// Default DRY sequence breakers: newline, colon, double quote and asterisk.
static DRY_SEQUENCE_BREAKERS: LazyLock<Vec<String>> = LazyLock::new(|| {
    let defaults: [&str; 4] = ["\n", ":", "\"", "*"];
    defaults.iter().map(|&s| String::from(s)).collect()
});
21
22#[derive(Clone, Debug, Serialize, Deserialize)]
23/// Stop sequences or ids.
24pub enum StopTokens {
25    Seqs(Vec<String>),
26    Ids(Vec<u32>),
27}
28
29#[derive(Clone, Debug, Serialize, Deserialize)]
30/// Sampling params are used to control sampling.
31pub struct SamplingParams {
32    pub temperature: Option<f64>,
33    pub top_k: Option<usize>,
34    pub top_p: Option<f64>,
35    pub min_p: Option<f64>,
36    pub top_n_logprobs: usize,
37    pub frequency_penalty: Option<f32>,
38    pub presence_penalty: Option<f32>,
39    pub repetition_penalty: Option<f32>,
40    pub stop_toks: Option<StopTokens>,
41    pub max_len: Option<usize>,
42    pub logits_bias: Option<HashMap<u32, f32>>,
43    pub n_choices: usize,
44    pub dry_params: Option<DrySamplingParams>,
45}
46
47impl SamplingParams {
48    /// This sets up the parameters so that there is:
49    /// - No temperature, topk, topp, minp
50    /// - No penalties, stop tokens, or logit bias
51    /// - No maximum length
52    pub fn deterministic() -> Self {
53        Self {
54            temperature: None,
55            top_k: Some(1),
56            top_p: None,
57            min_p: None,
58            top_n_logprobs: 0,
59            frequency_penalty: None,
60            presence_penalty: None,
61            repetition_penalty: None,
62            stop_toks: None,
63            max_len: None,
64            logits_bias: None,
65            n_choices: 1,
66            dry_params: None,
67        }
68    }
69}
70
71#[derive(Clone, Debug, Serialize, Deserialize)]
72pub struct DrySamplingParams {
73    pub sequence_breakers: Vec<String>,
74    pub multiplier: f32,
75    pub base: f32,
76    pub allowed_length: usize,
77}
78
79impl DrySamplingParams {
80    pub fn new_with_defaults(
81        multiplier: f32,
82        sequence_breakers: Option<Vec<String>>,
83        base: Option<f32>,
84        allowed_length: Option<usize>,
85    ) -> anyhow::Result<Self> {
86        Ok(Self {
87            base: base.unwrap_or(1.75),
88            allowed_length: allowed_length.unwrap_or(2),
89            sequence_breakers: sequence_breakers.unwrap_or(DRY_SEQUENCE_BREAKERS.clone()),
90            multiplier,
91        })
92    }
93}
94
95impl Default for DrySamplingParams {
96    fn default() -> Self {
97        Self {
98            multiplier: 0.0,
99            base: 1.75,
100            allowed_length: 2,
101            sequence_breakers: DRY_SEQUENCE_BREAKERS.clone(),
102        }
103    }
104}
105
/// DRY settings with the sequence breakers resolved to token ids.
#[derive(Clone, Debug)]
struct DrySamplingParamsInner {
    /// Token ids that end a repeated-sequence match.
    pub sequence_breakers: HashSet<u32>,
    /// Penalty scale; `0.0` disables DRY.
    pub multiplier: f32,
    /// Exponential growth factor of the penalty.
    pub base: f32,
    /// Shortest repeated sequence that is penalized.
    pub allowed_length: usize,
}
113
114impl DrySamplingParamsInner {
115    pub fn from(other: DrySamplingParams, tokenizer: &Tokenizer) -> anyhow::Result<Self> {
116        Ok(Self {
117            base: other.base,
118            allowed_length: other.allowed_length,
119            sequence_breakers: HashSet::from_iter(
120                other
121                    .sequence_breakers
122                    .into_iter()
123                    .map(|breaker| {
124                        tokenizer
125                            // Prefix with 'a' to get the correct encoding of the token at the end of a text.
126                            //
127                            // FIXME: This is a hack. See https://github.com/LostRuins/koboldcpp/pull/982
128                            //        for the correct solution which covers multi-token sequence breakers
129                            //        and ambiguous encodings.
130                            .encode_fast(["a", &breaker].concat(), true)
131                            .map_err(anyhow::Error::msg)
132                            .map(|enc| {
133                                let ids = enc.get_ids();
134                                if !ids.is_empty() {
135                                    Some(ids[ids.len() - 1])
136                                } else {
137                                    None
138                                }
139                            })
140                    })
141                    .collect::<anyhow::Result<Vec<_>>>()?
142                    .into_iter()
143                    .flatten()
144                    .collect::<Vec<_>>(),
145            ),
146            multiplier: other.multiplier,
147        })
148    }
149}
150
151/// Customizable logits processor.
152///
153/// # Example
154/// ```rust
155/// use std::{sync::Arc, ops::Mul};
156/// use mistralrs_core::CustomLogitsProcessor;
157/// use candle_core::{Result, Tensor};
158///
159/// struct ThresholdLogitsProcessor;
160/// impl CustomLogitsProcessor for ThresholdLogitsProcessor {
161///     fn apply(&self, logits: &Tensor, _context: &[u32]) -> Result<Tensor> {
162///         // Mask is 1 for true, 0 for false.
163///         let mask = logits.ge(0.5)?;
164///         logits.broadcast_mul(&mask.to_dtype(logits.dtype())?)
165///     }
166/// }
167/// let processor1: Arc<dyn CustomLogitsProcessor> = Arc::new(|logits: &Tensor, _context: &[u32]| logits * 1.23);
168/// let processor2: Arc<dyn CustomLogitsProcessor> = Arc::new(ThresholdLogitsProcessor);
169/// ```
170pub trait CustomLogitsProcessor: Send + Sync {
171    /// Logits and sequence context (prompt and generated tokens), returning modified tokens.
172    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor>;
173}
174
175impl<T: Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync> CustomLogitsProcessor for T {
176    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor> {
177        self(logits, context)
178    }
179}
180
181/// Sampler for sampling.
182#[derive(Clone)]
183pub struct Sampler {
184    temperature: Option<f64>,
185    top_n_logprobs: usize,
186    tokenizer: Option<Arc<Tokenizer>>,
187    frequency_penalty: Option<f32>,
188    presence_penalty: Option<f32>,
189    repetition_penalty: Option<f32>,
190    dry_params: Option<DrySamplingParamsInner>,
191    top_k: i64,
192    top_p: f64,
193    min_p: f64,
194    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
195    /// Cached Gumbel noise tensor to avoid reallocating it.
196    gumbel_cache: Arc<Mutex<Option<Tensor>>>,
197}
198
199#[cfg_attr(feature = "pyo3_macros", pyclass)]
200#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
201#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
202/// Top-n logprobs element
203pub struct TopLogprob {
204    pub token: u32,
205    pub logprob: f32,
206    pub bytes: Option<String>,
207}
208
209#[derive(Debug, Clone, Serialize, Deserialize)]
210pub struct Logprobs {
211    pub token: u32,
212    pub logprob: f32,
213    pub bytes: Option<String>,
214    pub top_logprobs: Option<Vec<TopLogprob>>,
215}
216
217fn argmax_sample_last_dim(logits: &Tensor) -> Result<Tensor> {
218    logits.argmax(D::Minus1)
219}
220
221impl Sampler {
222    #[allow(clippy::too_many_arguments)]
223    pub fn new(
224        temperature: Option<f64>,
225        top_n_logprobs: usize,
226        tokenizer: Option<Arc<Tokenizer>>,
227        frequency_penalty: Option<f32>,
228        presence_penalty: Option<f32>,
229        repetition_penalty: Option<f32>,
230        dry_params: Option<DrySamplingParams>,
231        top_k: i64,
232        top_p: f64,
233        min_p: f64,
234        logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
235    ) -> anyhow::Result<Self> {
236        let temperature = if temperature.is_none_or(|v| v < 1e-7) {
237            None
238        } else {
239            temperature
240        };
241        let dry_params = if let Some(ref tokenizer) = tokenizer {
242            dry_params.map(|params| DrySamplingParamsInner::from(params, tokenizer))
243        } else {
244            None
245        };
246        let dry_params = match dry_params {
247            Some(fallible) => Some(fallible?),
248            None => None,
249        };
250        Ok(Self {
251            temperature,
252            top_n_logprobs,
253            tokenizer,
254            frequency_penalty,
255            presence_penalty,
256            repetition_penalty,
257            dry_params,
258            top_k,
259            top_p,
260            min_p,
261            logits_processors,
262            gumbel_cache: Arc::new(Mutex::new(None)),
263        })
264    }
265
266    fn get_top_logprobs(&self, probs: &[f32], _argsort_indices: &[u32]) -> Result<Vec<TopLogprob>> {
267        // Fast top-k selection without sorting the entire vocabulary
268        let k = self.top_n_logprobs.min(probs.len());
269        if k == 0 {
270            return Ok(Vec::new());
271        }
272        // Build (token, probability) pairs
273        let mut idx_probs: Vec<(u32, f32)> = (0..probs.len() as u32)
274            .map(|i| (i, probs[i as usize]))
275            .collect();
276        // Partition so that the top k probabilities are in the first k positions
277        let (top_k_slice, _, _) = idx_probs.select_nth_unstable_by(k, |a, b| {
278            b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
279        });
280        // Copy and sort only the top k elements by descending probability
281        let mut top_k: Vec<(u32, f32)> = top_k_slice.to_vec();
282        top_k.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
283        // Build the result vector with log10 of probabilities and optional decoding
284        let mut result = Vec::with_capacity(k);
285        if let Some(tokenizer) = &self.tokenizer {
286            for (token, prob) in top_k {
287                let decoded = tokenizer
288                    .decode(&[token], false)
289                    .map_err(|e| Error::Msg(e.to_string()))?;
290                result.push(TopLogprob {
291                    token,
292                    logprob: prob.log(10.0),
293                    bytes: Some(decoded),
294                });
295            }
296        } else {
297            for (token, prob) in top_k {
298                result.push(TopLogprob {
299                    token,
300                    logprob: prob.log(10.0),
301                    bytes: None,
302                });
303            }
304        }
305        Ok(result)
306    }
307
308    fn sample_argmax(&self, logits: Tensor, return_logprobs: bool) -> Result<Logprobs> {
309        let next_token = logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
310
311        let probs: Vec<f32> = logits.to_vec1()?;
312
313        let argsort_indices = (0..probs.len() as u32).collect::<Vec<_>>();
314        let logprob = probs[next_token as usize].log(10.0);
315
316        let top_logprobs = if return_logprobs {
317            Some(self.get_top_logprobs(&probs, &argsort_indices)?)
318        } else {
319            None
320        };
321
322        let bytes = if let Some(tokenizer) = &self.tokenizer {
323            Some(
324                tokenizer
325                    .decode(&[next_token], false)
326                    .map_err(|x| Error::Msg(x.to_string()))?,
327            )
328        } else {
329            None
330        };
331
332        Ok(Logprobs {
333            token: next_token,
334            logprob,
335            top_logprobs,
336            bytes,
337        })
338    }
339
340    #[allow(unused)]
341    fn sample_fast(
342        &self,
343        logits: Tensor,
344        context: &[u32],
345        return_logprobs: bool,
346        top_k: i64,
347        top_p: f64,
348        min_p: f64,
349    ) -> Result<Logprobs> {
350        let mut probs = logits.to_dtype(DType::F32)?;
351
352        for processor in &self.logits_processors {
353            probs = processor.apply(&probs, context)?;
354        }
355
356        let context = Tensor::new(context, logits.device())?;
357        let mut counts = logits.zeros_like()?;
358        counts = counts.scatter_add(
359            &context,
360            &context.ones_like()?.to_dtype(counts.dtype())?,
361            D::Minus1,
362        )?;
363
364        let presence = counts
365            .gt(0.)?
366            .where_cond(&counts.ones_like()?, &counts.zeros_like()?)?;
367
368        match self.frequency_penalty {
369            Some(freq_penalty) if freq_penalty != 0. => {
370                probs = (probs - (freq_penalty as f64 * counts)?)?;
371            }
372            _ => (),
373        }
374
375        match self.presence_penalty {
376            Some(pres_penalty) if pres_penalty != 0. => {
377                probs = (probs - (pres_penalty as f64 * &presence)?)?;
378            }
379            _ => (),
380        }
381
382        match self.repetition_penalty {
383            Some(rep_penalty) if rep_penalty != 1. => {
384                let pos_mask = probs.gt(0.)?;
385                let scaled_pos = (&probs / (rep_penalty as f64))?;
386                let scaled_neg = (&probs * (rep_penalty as f64))?;
387                let modified = pos_mask.where_cond(&scaled_pos, &scaled_neg)?;
388
389                let pres_mask = presence.gt(0.)?;
390                probs = pres_mask.where_cond(&modified, &probs)?;
391            }
392            _ => (),
393        }
394
395        probs = candle_nn::ops::softmax_last_dim(&(probs / self.temperature.unwrap_or(1.))?)?;
396
397        // Top-K
398        if top_k > 0 {
399            let sorted_values = probs.fast_sort_asc(D::Minus1)?;
400            let topk_values = sorted_values.narrow(
401                D::Minus1,
402                sorted_values.dim(D::Minus1)? - top_k as usize,
403                top_k as usize,
404            )?;
405
406            // select the kth largest value as threshold
407            let threshold = topk_values.get_on_dim(D::Minus1, 0)?.unsqueeze(0)?;
408            let mask_topk = probs.broadcast_ge(&threshold)?;
409            probs = mask_topk.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
410        }
411
412        // Top-P (nucleus)
413        if top_p > 0.0 && top_p < 1.0 {
414            let sorted_probs = probs.fast_sort_asc(D::Minus1)?;
415
416            let cumsum = sorted_probs.fast_cumsum(D::Minus1)?;
417
418            let mask_topp = cumsum.le(top_p)?;
419
420            let masked_sorted =
421                mask_topp.where_cond(&sorted_probs, &Tensor::zeros_like(&sorted_probs)?)?;
422
423            let threshold = masked_sorted.max(D::Minus1)?;
424            let threshold = threshold.unsqueeze(D::Minus1)?;
425            let mask_full = probs.broadcast_ge(&threshold)?;
426            probs = mask_full.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
427        }
428
429        // Min-P
430        if min_p > 0.0 && min_p < 1.0 {
431            let max_vals = probs.max(D::Minus1)?;
432            let threshold_min = (max_vals.unsqueeze(D::Minus1)? * min_p)?;
433            let mask_minp = probs.broadcast_gt(&threshold_min)?;
434            probs = mask_minp.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
435        }
436
437        // Sample using the Gumbel-max trick fully on-device.
438        let log_probs = probs.log()?;
439        // Generate cached Gumbel noise (-log(-log(u))) once.
440        let gumbel = {
441            let mut guard = self.gumbel_cache.lock().unwrap();
442            if guard.is_none() {
443                let uniform = Tensor::rand(0f32, 1f32, log_probs.shape(), log_probs.device())?;
444                let noise = uniform
445                    .clamp(1e-20, 1.0)?
446                    .log()? // ln(u)
447                    .neg()? // -ln(u)
448                    .log()? // ln(-ln(u))
449                    .neg()?; // -ln(-ln(u))
450                *guard = Some(noise);
451            }
452            guard.as_ref().unwrap().clone()
453        };
454
455        let gumbel_logits = (&log_probs + &gumbel)?;
456        let next_token = gumbel_logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
457
458        // Extract the top‑n log‑probs if the caller asked for them.
459        let (top_logprobs, logprob) = if return_logprobs {
460            let k = self.top_n_logprobs;
461
462            let sorted_values = probs.fast_sort_asc(D::Minus1)?;
463            let topk_values = sorted_values
464                .narrow(
465                    D::Minus1,
466                    sorted_values.dim(D::Minus1)? - top_k as usize,
467                    top_k as usize,
468                )?
469                .to_vec1::<f32>()?;
470
471            let sorted_idxs = probs.fast_argsort_asc(D::Minus1)?;
472            let topk_idxs = sorted_idxs
473                .narrow(
474                    D::Minus1,
475                    sorted_values.dim(D::Minus1)? - top_k as usize,
476                    top_k as usize,
477                )?
478                .to_vec1::<u32>()?;
479
480            let mut result = Vec::with_capacity(k);
481            if let Some(tokenizer) = &self.tokenizer {
482                for (prob, token) in topk_values.iter().zip(topk_idxs) {
483                    let decoded = tokenizer
484                        .decode(&[token], false)
485                        .map_err(|e| Error::Msg(e.to_string()))?;
486                    result.push(TopLogprob {
487                        token,
488                        logprob: prob.log(10.0),
489                        bytes: Some(decoded),
490                    });
491                }
492            } else {
493                for (prob, token) in topk_values.iter().zip(topk_idxs) {
494                    result.push(TopLogprob {
495                        token,
496                        logprob: prob.log(10.0),
497                        bytes: None,
498                    });
499                }
500            }
501
502            let logprob = result.last().map(|res| res.logprob).unwrap_or(1.);
503
504            (Some(result), logprob)
505        } else {
506            (None, 1.)
507        };
508
509        let bytes = if let Some(tokenizer) = &self.tokenizer {
510            Some(
511                tokenizer
512                    .decode(&[next_token], false)
513                    .map_err(|x| Error::Msg(x.to_string()))?,
514            )
515        } else {
516            None
517        };
518
519        Ok(Logprobs {
520            token: next_token,
521            logprob,
522            top_logprobs,
523            bytes,
524        })
525    }
526    fn sample_speculative_top_kp_min_p(
527        &self,
528        logits: Tensor,
529        return_logprobs: bool,
530        top_k: i64,
531        top_p: f32,
532        min_p: f32,
533    ) -> Result<Logprobs> {
534        let mut probs: Vec<f32> = logits.to_vec1()?;
535        let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
536
537        if top_k > 0 {
538            // Clamp smaller probabilities to zero.
539            for (index, val) in argsort_indices.iter().enumerate() {
540                if index >= top_k as usize {
541                    probs[*val as usize] = 0.0;
542                }
543            }
544        }
545
546        // TOP P
547
548        // top-p sampling (or "nucleus sampling") samples from the smallest set of
549        // tokens that exceed probability top_p. This way we never sample tokens that
550        // have very low probabilities and are less likely to go "off the rails".
551
552        // Clamp smaller probabilities to zero.
553        let mut cumsum = 0.;
554        for index in &argsort_indices {
555            if cumsum >= top_p {
556                probs[*index as usize] = 0.0;
557            } else {
558                cumsum += probs[*index as usize];
559            }
560        }
561
562        let max_p = probs[argsort_indices[0] as usize];
563
564        // MIN P
565
566        // min-p sampling samples from the tokens whose prob are greater than
567        // (max prob of token in dist) * min_p
568
569        // Clamp smaller probabilities to zero.
570        for index in &argsort_indices {
571            if max_p * min_p >= probs[*index as usize] {
572                probs[*index as usize] = 0.0;
573            }
574        }
575
576        let logits = Tensor::from_slice(&probs, logits.shape(), &Device::Cpu)?;
577
578        let next_token = argmax_sample_last_dim(&logits)?.to_scalar::<u32>()?;
579
580        let logprob = probs[next_token as usize].log(10.0);
581
582        let top_logprobs = if return_logprobs {
583            Some(self.get_top_logprobs(&probs, &argsort_indices)?)
584        } else {
585            None
586        };
587
588        let bytes = if let Some(tokenizer) = &self.tokenizer {
589            Some(
590                tokenizer
591                    .decode(&[next_token], false)
592                    .map_err(|x| Error::Msg(x.to_string()))?,
593            )
594        } else {
595            None
596        };
597
598        Ok(Logprobs {
599            token: next_token,
600            logprob,
601            top_logprobs,
602            bytes,
603        })
604    }
605
606    fn sample_multinomial(
607        &self,
608        probs: &mut Vec<f32>,
609        argsort_indices: Vec<u32>,
610        return_logprobs: bool,
611        rng: Arc<Mutex<Isaac64Rng>>,
612    ) -> Result<Logprobs> {
613        let distr = WeightedIndex::new(&*probs).map_err(Error::wrap)?;
614
615        let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
616        let next_token = distr.sample(&mut mut_ref_rng); // "Find the first item which has a weight *higher* than the chosen weight."
617        let logprob = probs[next_token].log(10.0);
618
619        let top_logprobs = if return_logprobs {
620            Some(self.get_top_logprobs(probs, &argsort_indices)?)
621        } else {
622            None
623        };
624
625        let bytes = if let Some(tokenizer) = &self.tokenizer {
626            Some(
627                tokenizer
628                    .decode(&[next_token.try_into().unwrap()], false)
629                    .map_err(|x| Error::Msg(x.to_string()))?,
630            )
631        } else {
632            None
633        };
634
635        Ok(Logprobs {
636            token: next_token as u32,
637            logprob,
638            top_logprobs,
639            bytes,
640        })
641    }
642
643    #[allow(clippy::too_many_arguments)]
644    fn sample_top_kp_min_p(
645        &self,
646        probs: &mut Vec<f32>,
647        logits: &Tensor,
648        top_k: i64,
649        top_p: f32,
650        min_p: f32,
651        return_logprobs: bool,
652        rng: Arc<Mutex<Isaac64Rng>>,
653    ) -> Result<Logprobs> {
654        let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
655
656        if top_k > 0 {
657            // Clamp smaller probabilities to zero.
658            for (index, val) in argsort_indices.iter().enumerate() {
659                if index >= top_k as usize {
660                    probs[*val as usize] = 0.0;
661                }
662            }
663        }
664
665        if top_p <= 0.0 || top_p >= 1.0 {
666            return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
667        }
668
669        // TOP P
670
671        // top-p sampling (or "nucleus sampling") samples from the smallest set of
672        // tokens that exceed probability top_p. This way we never sample tokens that
673        // have very low probabilities and are less likely to go "off the rails".
674
675        // Clamp smaller probabilities to zero.
676        let mut cumsum = 0.;
677        for index in &argsort_indices {
678            if cumsum >= top_p {
679                probs[*index as usize] = 0.0;
680            } else {
681                cumsum += probs[*index as usize];
682            }
683        }
684
685        if min_p <= 0.0 || min_p >= 1.0 {
686            return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
687        }
688
689        let max_p = probs[argsort_indices[0] as usize];
690
691        // MIN P
692
693        // min-p sampling samples from the tokens whose prob are greater than
694        // (max prob of token in dist) * min_p
695
696        // Clamp smaller probabilities to zero.
697        for index in &argsort_indices {
698            if max_p * min_p >= probs[*index as usize] {
699                probs[*index as usize] = 0.0;
700            }
701        }
702
703        // Sample with clamped probabilities.
704        self.sample_multinomial(probs, argsort_indices, return_logprobs, rng)
705    }
706
707    fn apply_penalties(&self, mut logits: Vec<f32>, context: &[u32]) -> Result<Tensor> {
708        if context.is_empty() {
709            candle_core::bail!("Penalty context is empty, this should not happen.");
710        }
711
712        // Dry penalty
713        self.apply_dry_penalty(&mut logits, context)?;
714
715        // Frequency, presence, repetition penalty
716        self.apply_freq_pres_rep_penalty(&mut logits, context)?;
717
718        let vocab_size = logits.len();
719        Tensor::from_vec(logits, vocab_size, &Device::Cpu)
720    }
721
722    fn apply_freq_pres_rep_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
723        if self.frequency_penalty.is_some()
724            || self.presence_penalty.is_some()
725            || self.repetition_penalty.is_some()
726        {
727            let frequency_penalty = self.frequency_penalty.unwrap_or(0.);
728            let presence_penalty = self.presence_penalty.unwrap_or(0.);
729            let repetition_penalty = self.repetition_penalty.unwrap_or(1.);
730
731            //mu[j] -> mu[j] - c[j] * alpha_frequency - float(c[j] > 0) * alpha_presence
732
733            let mut counts = vec![0.0f32; logits.len()];
734            for ctx in context.iter() {
735                // Llama 3.2 uses a hack triggering this error... we wouldn't want a weight on it anyway
736                if *ctx as usize >= logits.len() {
737                    continue;
738                }
739                counts[*ctx as usize] += 1.0;
740            }
741
742            for (token_id, logit) in logits.iter_mut().enumerate() {
743                let count = counts[token_id];
744                *logit = *logit
745                    - count * frequency_penalty
746                    - if count > 0.0 { 1. } else { 0. } * presence_penalty;
747
748                if repetition_penalty != 1.0 && count > 0.0 {
749                    if *logit > 0.0 {
750                        *logit /= repetition_penalty;
751                    } else {
752                        *logit *= repetition_penalty;
753                    }
754                }
755            }
756        }
757        Ok(())
758    }
759
760    fn apply_dry_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
761        if let Some(ref params) = self.dry_params {
762            if params.multiplier == 0. {
763                return Ok(());
764            }
765
766            let match_indices = context
767                .par_iter()
768                .enumerate()
769                .take(context.len() - 1)
770                .filter(|(_i, x)| *context.last().unwrap() == **x)
771                .map(|(i, _)| i)
772                .collect::<Vec<_>>();
773
774            let mut match_lengths = HashMap::new();
775
776            for i in match_indices {
777                let next_token = context[i + 1];
778
779                if params.sequence_breakers.contains(&next_token) {
780                    continue;
781                }
782
783                let mut match_length = 1;
784
785                // Limit match length to avoid quadratic runtime and potential DoS with adversarial inputs.
786                while match_length < 50 {
787                    if match_length > i {
788                        // Start of input
789                        break;
790                    }
791
792                    let j = i - match_length;
793
794                    let prev_tok = context[context.len() - (match_length + 1)];
795                    if context[j] != prev_tok {
796                        // Start of match reached
797                        break;
798                    }
799
800                    if params.sequence_breakers.contains(&prev_tok) {
801                        // Seq breaking tok reached
802                        break;
803                    }
804
805                    match_length += 1;
806                }
807
808                #[allow(clippy::map_entry)]
809                if match_lengths.contains_key(&next_token) {
810                    match_lengths.insert(next_token, match_length.max(match_lengths[&next_token]));
811                } else {
812                    match_lengths.insert(next_token, match_length);
813                }
814            }
815
816            // Actually apply penalties
817            for (tok, match_len) in match_lengths {
818                if match_len >= params.allowed_length {
819                    // Llama 3.2 uses a hack triggering this error... we wouldn't want a weight on it anyway
820                    if tok as usize >= logits.len() {
821                        continue;
822                    }
823                    let penalty = params.multiplier
824                        * params.base.powf((match_len - params.allowed_length) as f32);
825                    logits[tok as usize] -= penalty;
826                }
827            }
828        }
829        Ok(())
830    }
831
832    #[allow(unused)]
833    /// Sample the provided tokens.
834    ///
835    /// If the temperature is `None`, argmax sampling is used. Otherwise, the selected sampling is used.
836    /// With `top-p` sampling, if the `top-p` value is `<= 0.0` or `>= 1.0`, multinomial sampling is used.
837    pub fn sample(
838        &self,
839        logits: Tensor,
840        context: &[u32],
841        return_logprobs: bool,
842        rng: Arc<Mutex<Isaac64Rng>>,
843        sample_speculative: bool,
844        multiple_sequences: bool,
845    ) -> Result<Logprobs> {
846        // if cfg!(feature = "metal") && !multiple_sequences {
847        //     return self.sample_fast(
848        //         logits,
849        //         context,
850        //         return_logprobs,
851        //         self.top_k,
852        //         self.top_p,
853        //         self.min_p,
854        //     );
855        // }
856
857        let logits = logits.to_vec1()?;
858        let mut logits = self.apply_penalties(logits, context)?;
859        for processor in &self.logits_processors {
860            logits = processor.apply(&logits, context)?;
861        }
862        let next_token = if sample_speculative {
863            match self.temperature {
864                None => self.sample_speculative_top_kp_min_p(
865                    logits,
866                    return_logprobs,
867                    self.top_k,
868                    self.top_p as f32,
869                    self.min_p as f32,
870                )?,
871                Some(temperature) => {
872                    let logits = (&logits / temperature)?;
873                    let probs = candle_nn::ops::softmax_last_dim(&logits)?;
874
875                    self.sample_speculative_top_kp_min_p(
876                        probs,
877                        return_logprobs,
878                        self.top_k,
879                        self.top_p as f32,
880                        self.min_p as f32,
881                    )?
882                }
883            }
884        } else {
885            match self.temperature {
886                None => self.sample_argmax(logits, return_logprobs)?,
887                Some(temperature) => {
888                    let logits = (&logits / temperature)?;
889                    let logits = candle_nn::ops::softmax_last_dim(&logits)?;
890                    let mut probs: Vec<f32> = logits.to_vec1()?;
891
892                    self.sample_top_kp_min_p(
893                        &mut probs,
894                        &logits,
895                        self.top_k,
896                        self.top_p as f32,
897                        self.min_p as f32,
898                        return_logprobs,
899                        rng,
900                    )?
901                }
902            }
903        };
904        Ok(next_token)
905    }
906}
907
#[cfg(test)]
mod tests {
    use super::{Logprobs, Sampler};
    use candle_core::{Device, Tensor};
    use rand::SeedableRng;
    use rand_isaac::Isaac64Rng;
    use std::sync::{Arc, Mutex};

    /// Sample once from logits `0..1024`, with the whole range as context.
    ///
    /// The sampler has no temperature, `top_n_logprobs = 10`, `top_k = 32`, `top_p = 0.1`
    /// and `min_p = 0.05`, and no tokenizer, penalties, DRY or processors.
    fn sample_arange(sample_speculative: bool) -> Logprobs {
        let sampler = Sampler::new(
            None,
            10,
            None,
            None,
            None,
            None,
            None,
            32,
            0.1,
            0.05,
            vec![],
        )
        .unwrap();
        let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
        let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
        sampler
            .sample(
                logits,
                &(0..1024).collect::<Vec<_>>(),
                false,
                rng,
                sample_speculative,
                false,
            )
            .unwrap()
    }

    #[test]
    fn test_argmax() {
        let res = sample_arange(false);
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }

    #[test]
    fn test_gumbel_speculative() {
        let res = sample_arange(true);
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }
}