mistralrs_core/
sampler.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{
4    collections::{HashMap, HashSet},
5    sync::{Arc, Mutex},
6};
7
8use candle_core::{DType, Device, Error, Result, Tensor, D};
9use mistralrs_quant::{CumSumOp, SortOp};
10#[cfg(feature = "pyo3_macros")]
11use pyo3::pyclass;
12
13use once_cell::sync::Lazy;
14use rand::distr::{weighted::WeightedIndex, Distribution};
15use rand_isaac::Isaac64Rng;
16use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
17use serde::{Deserialize, Serialize};
18use tokenizers::Tokenizer;
19
20static DRY_SEQUENCE_BREAKERS: Lazy<Vec<String>> =
21    Lazy::new(|| ["\n", ":", "\"", "*"].map(String::from).to_vec());
22
#[derive(Clone, Debug, Serialize, Deserialize)]
/// Stop sequences or ids.
pub enum StopTokens {
    /// Stop conditions expressed as strings.
    Seqs(Vec<String>),
    /// Stop conditions expressed as token ids.
    Ids(Vec<u32>),
}
29
#[derive(Clone, Debug, Serialize, Deserialize)]
/// Sampling params are used to control sampling.
///
/// NOTE(review): this file does not show where these fields are consumed; several of them
/// share names with the arguments of `Sampler::new` — confirm the mapping against the callers.
pub struct SamplingParams {
    /// Sampling temperature, if any.
    pub temperature: Option<f64>,
    /// Top-k cutoff, if any.
    pub top_k: Option<usize>,
    /// Top-p (nucleus) cutoff, if any.
    pub top_p: Option<f64>,
    /// Min-p cutoff, if any.
    pub min_p: Option<f64>,
    /// Number of top log-probabilities to report.
    pub top_n_logprobs: usize,
    /// Frequency penalty, if any.
    pub frequency_penalty: Option<f32>,
    /// Presence penalty, if any.
    pub presence_penalty: Option<f32>,
    /// Stop sequences or stop token ids, if any.
    pub stop_toks: Option<StopTokens>,
    /// Maximum length, if any.
    pub max_len: Option<usize>,
    /// Per-token bias values keyed by `u32` id, if any.
    pub logits_bias: Option<HashMap<u32, f32>>,
    /// Number of choices.
    pub n_choices: usize,
    /// DRY sampling parameters, if any.
    pub dry_params: Option<DrySamplingParams>,
}
46
impl SamplingParams {
    /// This sets up the parameters so that there is:
    /// - No temperature, topp, minp; `top_k` is pinned to `1`
    /// - No penalties, stop tokens, logit bias, or DRY parameters
    /// - No maximum length
    /// - No top-n logprobs and a single choice (`n_choices == 1`)
    pub fn deterministic() -> Self {
        Self {
            temperature: None,
            top_k: Some(1),
            top_p: None,
            min_p: None,
            top_n_logprobs: 0,
            frequency_penalty: None,
            presence_penalty: None,
            stop_toks: None,
            max_len: None,
            logits_bias: None,
            n_choices: 1,
            dry_params: None,
        }
    }
}
69
#[derive(Clone, Debug, Serialize, Deserialize)]
/// User-facing parameters of the DRY repetition penalty.
///
/// A token that would extend a repeated sequence of length `match_len >= allowed_length`
/// has `multiplier * base^(match_len - allowed_length)` subtracted from its logit.
pub struct DrySamplingParams {
    /// Strings whose token ids interrupt repetition matching; they are tokenized when the
    /// sampler is constructed.
    pub sequence_breakers: Vec<String>,
    /// Scale of the penalty. A value of `0` disables the DRY penalty entirely.
    pub multiplier: f32,
    /// Base of the exponential growth of the penalty.
    pub base: f32,
    /// Minimum match length at which the penalty starts to apply.
    pub allowed_length: usize,
}
77
78impl DrySamplingParams {
79    pub fn new_with_defaults(
80        multiplier: f32,
81        sequence_breakers: Option<Vec<String>>,
82        base: Option<f32>,
83        allowed_length: Option<usize>,
84    ) -> anyhow::Result<Self> {
85        Ok(Self {
86            base: base.unwrap_or(1.75),
87            allowed_length: allowed_length.unwrap_or(2),
88            sequence_breakers: sequence_breakers.unwrap_or(DRY_SEQUENCE_BREAKERS.clone()),
89            multiplier,
90        })
91    }
92}
93
94impl Default for DrySamplingParams {
95    fn default() -> Self {
96        Self {
97            multiplier: 0.0,
98            base: 1.75,
99            allowed_length: 2,
100            sequence_breakers: DRY_SEQUENCE_BREAKERS.clone(),
101        }
102    }
103}
104
/// DRY parameters in the form used while sampling: sequence breakers are token ids
/// instead of strings.
#[derive(Clone, Debug)]
struct DrySamplingParamsInner {
    /// Token ids that interrupt repetition matching.
    pub sequence_breakers: HashSet<u32>,
    /// Scale of the penalty; `0` disables it.
    pub multiplier: f32,
    /// Base of the exponential growth of the penalty.
    pub base: f32,
    /// Minimum match length at which the penalty starts to apply.
    pub allowed_length: usize,
}
112
113impl DrySamplingParamsInner {
114    pub fn from(other: DrySamplingParams, tokenizer: &Tokenizer) -> anyhow::Result<Self> {
115        Ok(Self {
116            base: other.base,
117            allowed_length: other.allowed_length,
118            sequence_breakers: HashSet::from_iter(
119                other
120                    .sequence_breakers
121                    .into_iter()
122                    .map(|breaker| {
123                        tokenizer
124                            // Prefix with 'a' to get the correct encoding of the token at the end of a text.
125                            //
126                            // FIXME: This is a hack. See https://github.com/LostRuins/koboldcpp/pull/982
127                            //        for the correct solution which covers multi-token sequence breakers
128                            //        and ambiguous encodings.
129                            .encode_fast(["a", &breaker].concat(), true)
130                            .map_err(anyhow::Error::msg)
131                            .map(|enc| {
132                                let ids = enc.get_ids();
133                                if !ids.is_empty() {
134                                    None
135                                } else {
136                                    Some(ids[ids.len() - 1])
137                                }
138                            })
139                    })
140                    .collect::<anyhow::Result<Vec<_>>>()?
141                    .into_iter()
142                    .flatten()
143                    .collect::<Vec<_>>(),
144            ),
145            multiplier: other.multiplier,
146        })
147    }
148}
149
/// Customizable logits processor.
///
/// An implementor receives the current logits together with the sequence context and returns
/// the logits to use from then on. Closures of type
/// `Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync` implement this trait as well.
///
/// # Example
/// ```rust
/// use std::{sync::Arc, ops::Mul};
/// use mistralrs_core::CustomLogitsProcessor;
/// use candle_core::{Result, Tensor};
///
/// struct ThresholdLogitsProcessor;
/// impl CustomLogitsProcessor for ThresholdLogitsProcessor {
///     fn apply(&self, logits: &Tensor, _context: &[u32]) -> Result<Tensor> {
///         // Mask is 1 for true, 0 for false.
///         let mask = logits.ge(0.5)?;
///         logits.broadcast_mul(&mask.to_dtype(logits.dtype())?)
///     }
/// }
/// let processor1: Arc<dyn CustomLogitsProcessor> = Arc::new(|logits: &Tensor, _context: &[u32]| logits * 1.23);
/// let processor2: Arc<dyn CustomLogitsProcessor> = Arc::new(ThresholdLogitsProcessor);
/// ```
pub trait CustomLogitsProcessor: Send + Sync {
    /// Logits and sequence context (prompt and generated tokens), returning modified logits.
    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor>;
}
173
/// Lets any thread-safe closure or function with the matching signature act as a
/// [`CustomLogitsProcessor`].
impl<T: Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync> CustomLogitsProcessor for T {
    /// Calls the wrapped closure with the logits and the context.
    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor> {
        self(logits, context)
    }
}
179
/// Sampler for sampling.
#[derive(Clone)]
pub struct Sampler {
    /// Sampling temperature. `None` (also stored for inputs below `1e-7`) selects the
    /// temperature-free code paths in `sample`.
    temperature: Option<f64>,
    /// Maximum number of entries returned as top log-probabilities.
    top_n_logprobs: usize,
    /// Tokenizer used to decode sampled tokens into `bytes`; DRY parameters are only kept
    /// when it is present.
    tokenizer: Option<Arc<Tokenizer>>,
    /// Amount subtracted from a logit per occurrence of its token in the context.
    frequency_penalty: Option<f32>,
    /// Amount subtracted from a logit if its token occurs in the context at all.
    presence_penalty: Option<f32>,
    /// Tokenized DRY repetition-penalty parameters.
    dry_params: Option<DrySamplingParamsInner>,
    /// Top-k cutoff; only applied when greater than zero.
    top_k: i64,
    /// Top-p (nucleus) cutoff.
    top_p: f64,
    /// Min-p cutoff, relative to the largest remaining probability.
    min_p: f64,
    /// Custom processors applied to the logits, in order, before sampling.
    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
    /// Cached Gumbel noise tensor to avoid reallocating it.
    gumbel_cache: Arc<Mutex<Option<Tensor>>>,
}
196
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
/// Top-n logprobs element
pub struct TopLogprob {
    /// Token id.
    pub token: u32,
    /// Base-10 logarithm of the value associated with the token (the sampler in this file
    /// computes it with `log(10.0)`).
    pub logprob: f32,
    /// Decoded text of the token; `None` when the sampler has no tokenizer.
    pub bytes: Option<String>,
}
206
/// Result of sampling one token.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Logprobs {
    /// Id of the sampled token.
    pub token: u32,
    /// Base-10 logarithm of the value associated with the sampled token.
    pub logprob: f32,
    /// Decoded text of the sampled token; `None` when the sampler has no tokenizer.
    pub bytes: Option<String>,
    /// Top-n entries, present only when the caller asked for log-probabilities.
    pub top_logprobs: Option<Vec<TopLogprob>>,
}
214
/// Returns the index of the largest element along the last dimension of `logits`.
fn argmax_sample_last_dim(logits: &Tensor) -> Result<Tensor> {
    logits.argmax(D::Minus1)
}
218
219impl Sampler {
220    #[allow(clippy::too_many_arguments)]
221    pub fn new(
222        temperature: Option<f64>,
223        top_n_logprobs: usize,
224        tokenizer: Option<Arc<Tokenizer>>,
225        frequency_penalty: Option<f32>,
226        presence_penalty: Option<f32>,
227        dry_params: Option<DrySamplingParams>,
228        top_k: i64,
229        top_p: f64,
230        min_p: f64,
231        logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
232    ) -> anyhow::Result<Self> {
233        let temperature = if temperature.is_none_or(|v| v < 1e-7) {
234            None
235        } else {
236            temperature
237        };
238        let dry_params = if let Some(ref tokenizer) = tokenizer {
239            dry_params.map(|params| DrySamplingParamsInner::from(params, tokenizer))
240        } else {
241            None
242        };
243        let dry_params = match dry_params {
244            Some(fallible) => Some(fallible?),
245            None => None,
246        };
247        Ok(Self {
248            temperature,
249            top_n_logprobs,
250            tokenizer,
251            frequency_penalty,
252            presence_penalty,
253            dry_params,
254            top_k,
255            top_p,
256            min_p,
257            logits_processors,
258            gumbel_cache: Arc::new(Mutex::new(None)),
259        })
260    }
261
262    fn get_top_logprobs(&self, probs: &[f32], _argsort_indices: &[u32]) -> Result<Vec<TopLogprob>> {
263        // Fast top-k selection without sorting the entire vocabulary
264        let k = self.top_n_logprobs.min(probs.len());
265        if k == 0 {
266            return Ok(Vec::new());
267        }
268        // Build (token, probability) pairs
269        let mut idx_probs: Vec<(u32, f32)> = (0..probs.len() as u32)
270            .map(|i| (i, probs[i as usize]))
271            .collect();
272        // Partition so that the top k probabilities are in the first k positions
273        let (top_k_slice, _, _) = idx_probs.select_nth_unstable_by(k, |a, b| {
274            b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
275        });
276        // Copy and sort only the top k elements by descending probability
277        let mut top_k: Vec<(u32, f32)> = top_k_slice.to_vec();
278        top_k.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
279        // Build the result vector with log10 of probabilities and optional decoding
280        let mut result = Vec::with_capacity(k);
281        if let Some(tokenizer) = &self.tokenizer {
282            for (token, prob) in top_k {
283                let decoded = tokenizer
284                    .decode(&[token], false)
285                    .map_err(|e| Error::Msg(e.to_string()))?;
286                result.push(TopLogprob {
287                    token,
288                    logprob: prob.log(10.0),
289                    bytes: Some(decoded),
290                });
291            }
292        } else {
293            for (token, prob) in top_k {
294                result.push(TopLogprob {
295                    token,
296                    logprob: prob.log(10.0),
297                    bytes: None,
298                });
299            }
300        }
301        Ok(result)
302    }
303
304    fn sample_argmax(&self, logits: Tensor, return_logprobs: bool) -> Result<Logprobs> {
305        let next_token = logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
306
307        let probs: Vec<f32> = logits.to_vec1()?;
308
309        let argsort_indices = (0..probs.len() as u32).collect::<Vec<_>>();
310        let logprob = probs[next_token as usize].log(10.0);
311
312        let top_logprobs = if return_logprobs {
313            Some(self.get_top_logprobs(&probs, &argsort_indices)?)
314        } else {
315            None
316        };
317
318        let bytes = if let Some(tokenizer) = &self.tokenizer {
319            Some(
320                tokenizer
321                    .decode(&[next_token], false)
322                    .map_err(|x| Error::Msg(x.to_string()))?,
323            )
324        } else {
325            None
326        };
327
328        Ok(Logprobs {
329            token: next_token,
330            logprob,
331            top_logprobs,
332            bytes,
333        })
334    }
335
336    #[allow(unused)]
337    fn sample_fast(
338        &self,
339        logits: Tensor,
340        context: &[u32],
341        return_logprobs: bool,
342        top_k: i64,
343        top_p: f64,
344        min_p: f64,
345    ) -> Result<Logprobs> {
346        let mut probs = logits.to_dtype(DType::F32)?;
347
348        for processor in &self.logits_processors {
349            probs = processor.apply(&probs, context)?;
350        }
351
352        let context = Tensor::new(context, logits.device())?;
353        let mut counts = logits.zeros_like()?;
354        counts = counts.scatter_add(
355            &context,
356            &context.ones_like()?.to_dtype(counts.dtype())?,
357            D::Minus1,
358        )?;
359
360        let presence = counts
361            .gt(0.)?
362            .where_cond(&counts.ones_like()?, &counts.zeros_like()?)?;
363
364        match self.frequency_penalty {
365            Some(freq_penalty) if freq_penalty != 0. => {
366                probs = (probs - (freq_penalty as f64 * counts)?)?;
367            }
368            _ => (),
369        }
370
371        match self.presence_penalty {
372            Some(pres_penalty) if pres_penalty != 0. => {
373                probs = (probs - (pres_penalty as f64 * presence)?)?;
374            }
375            _ => (),
376        }
377
378        probs = candle_nn::ops::softmax_last_dim(&(probs / self.temperature.unwrap_or(1.))?)?;
379
380        // Top-K
381        if top_k > 0 {
382            let sorted_values = probs.fast_sort_asc(D::Minus1)?;
383            let topk_values = sorted_values.narrow(
384                D::Minus1,
385                sorted_values.dim(D::Minus1)? - top_k as usize,
386                top_k as usize,
387            )?;
388
389            // select the kth largest value as threshold
390            let threshold = topk_values.get_on_dim(D::Minus1, 0)?.unsqueeze(0)?;
391            let mask_topk = probs.broadcast_ge(&threshold)?;
392            probs = mask_topk.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
393        }
394
395        // Top-P (nucleus)
396        if top_p > 0.0 && top_p < 1.0 {
397            let sorted_probs = probs.fast_sort_asc(D::Minus1)?;
398
399            let cumsum = sorted_probs.fast_cumsum(D::Minus1)?;
400
401            let mask_topp = cumsum.le(top_p)?;
402
403            let masked_sorted =
404                mask_topp.where_cond(&sorted_probs, &Tensor::zeros_like(&sorted_probs)?)?;
405
406            let threshold = masked_sorted.max(D::Minus1)?;
407            let threshold = threshold.unsqueeze(D::Minus1)?;
408            let mask_full = probs.broadcast_ge(&threshold)?;
409            probs = mask_full.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
410        }
411
412        // Min-P
413        if min_p > 0.0 && min_p < 1.0 {
414            let max_vals = probs.max(D::Minus1)?;
415            let threshold_min = (max_vals.unsqueeze(D::Minus1)? * min_p)?;
416            let mask_minp = probs.broadcast_gt(&threshold_min)?;
417            probs = mask_minp.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
418        }
419
420        // Sample using the Gumbel-max trick fully on-device.
421        let log_probs = probs.log()?;
422        // Generate cached Gumbel noise (-log(-log(u))) once.
423        let gumbel = {
424            let mut guard = self.gumbel_cache.lock().unwrap();
425            if guard.is_none() {
426                let uniform = Tensor::rand(0f32, 1f32, log_probs.shape(), log_probs.device())?;
427                let noise = uniform
428                    .clamp(1e-20, 1.0)?
429                    .log()? // ln(u)
430                    .neg()? // -ln(u)
431                    .log()? // ln(-ln(u))
432                    .neg()?; // -ln(-ln(u))
433                *guard = Some(noise);
434            }
435            guard.as_ref().unwrap().clone()
436        };
437
438        let gumbel_logits = (&log_probs + &gumbel)?;
439        let next_token = gumbel_logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
440
441        // Extract the top‑n log‑probs if the caller asked for them.
442        let (top_logprobs, logprob) = if return_logprobs {
443            let k = self.top_n_logprobs;
444
445            let sorted_values = probs.fast_sort_asc(D::Minus1)?;
446            let topk_values = sorted_values
447                .narrow(
448                    D::Minus1,
449                    sorted_values.dim(D::Minus1)? - top_k as usize,
450                    top_k as usize,
451                )?
452                .to_vec1::<f32>()?;
453
454            let sorted_idxs = probs.fast_argsort_asc(D::Minus1)?;
455            let topk_idxs = sorted_idxs
456                .narrow(
457                    D::Minus1,
458                    sorted_values.dim(D::Minus1)? - top_k as usize,
459                    top_k as usize,
460                )?
461                .to_vec1::<u32>()?;
462
463            let mut result = Vec::with_capacity(k);
464            if let Some(tokenizer) = &self.tokenizer {
465                for (prob, token) in topk_values.iter().zip(topk_idxs) {
466                    let decoded = tokenizer
467                        .decode(&[token], false)
468                        .map_err(|e| Error::Msg(e.to_string()))?;
469                    result.push(TopLogprob {
470                        token,
471                        logprob: prob.log(10.0),
472                        bytes: Some(decoded),
473                    });
474                }
475            } else {
476                for (prob, token) in topk_values.iter().zip(topk_idxs) {
477                    result.push(TopLogprob {
478                        token,
479                        logprob: prob.log(10.0),
480                        bytes: None,
481                    });
482                }
483            }
484
485            let logprob = result.last().map(|res| res.logprob).unwrap_or(1.);
486
487            (Some(result), logprob)
488        } else {
489            (None, 1.)
490        };
491
492        let bytes = if let Some(tokenizer) = &self.tokenizer {
493            Some(
494                tokenizer
495                    .decode(&[next_token], false)
496                    .map_err(|x| Error::Msg(x.to_string()))?,
497            )
498        } else {
499            None
500        };
501
502        Ok(Logprobs {
503            token: next_token,
504            logprob,
505            top_logprobs,
506            bytes,
507        })
508    }
509    fn sample_speculative_top_kp_min_p(
510        &self,
511        logits: Tensor,
512        return_logprobs: bool,
513        top_k: i64,
514        top_p: f32,
515        min_p: f32,
516    ) -> Result<Logprobs> {
517        let mut probs: Vec<f32> = logits.to_vec1()?;
518        let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
519
520        if top_k > 0 {
521            // Clamp smaller probabilities to zero.
522            for (index, val) in argsort_indices.iter().enumerate() {
523                if index >= top_k as usize {
524                    probs[*val as usize] = 0.0;
525                }
526            }
527        }
528
529        // TOP P
530
531        // top-p sampling (or "nucleus sampling") samples from the smallest set of
532        // tokens that exceed probability top_p. This way we never sample tokens that
533        // have very low probabilities and are less likely to go "off the rails".
534
535        // Clamp smaller probabilities to zero.
536        let mut cumsum = 0.;
537        for index in &argsort_indices {
538            if cumsum >= top_p {
539                probs[*index as usize] = 0.0;
540            } else {
541                cumsum += probs[*index as usize];
542            }
543        }
544
545        let max_p = probs[argsort_indices[0] as usize];
546
547        // MIN P
548
549        // min-p sampling samples from the tokens whose prob are greater than
550        // (max prob of token in dist) * min_p
551
552        // Clamp smaller probabilities to zero.
553        for index in &argsort_indices {
554            if max_p * min_p >= probs[*index as usize] {
555                probs[*index as usize] = 0.0;
556            }
557        }
558
559        let logits = Tensor::from_slice(&probs, logits.shape(), &Device::Cpu)?;
560
561        let next_token = argmax_sample_last_dim(&logits)?.to_scalar::<u32>()?;
562
563        let logprob = probs[next_token as usize].log(10.0);
564
565        let top_logprobs = if return_logprobs {
566            Some(self.get_top_logprobs(&probs, &argsort_indices)?)
567        } else {
568            None
569        };
570
571        let bytes = if let Some(tokenizer) = &self.tokenizer {
572            Some(
573                tokenizer
574                    .decode(&[next_token], false)
575                    .map_err(|x| Error::Msg(x.to_string()))?,
576            )
577        } else {
578            None
579        };
580
581        Ok(Logprobs {
582            token: next_token,
583            logprob,
584            top_logprobs,
585            bytes,
586        })
587    }
588
589    fn sample_multinomial(
590        &self,
591        probs: &mut Vec<f32>,
592        argsort_indices: Vec<u32>,
593        return_logprobs: bool,
594        rng: Arc<Mutex<Isaac64Rng>>,
595    ) -> Result<Logprobs> {
596        let distr = WeightedIndex::new(&*probs).map_err(Error::wrap)?;
597
598        let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
599        let next_token = distr.sample(&mut mut_ref_rng); // "Find the first item which has a weight *higher* than the chosen weight."
600        let logprob = probs[next_token].log(10.0);
601
602        let top_logprobs = if return_logprobs {
603            Some(self.get_top_logprobs(probs, &argsort_indices)?)
604        } else {
605            None
606        };
607
608        let bytes = if let Some(tokenizer) = &self.tokenizer {
609            Some(
610                tokenizer
611                    .decode(&[next_token.try_into().unwrap()], false)
612                    .map_err(|x| Error::Msg(x.to_string()))?,
613            )
614        } else {
615            None
616        };
617
618        Ok(Logprobs {
619            token: next_token as u32,
620            logprob,
621            top_logprobs,
622            bytes,
623        })
624    }
625
626    #[allow(clippy::too_many_arguments)]
627    fn sample_top_kp_min_p(
628        &self,
629        probs: &mut Vec<f32>,
630        logits: &Tensor,
631        top_k: i64,
632        top_p: f32,
633        min_p: f32,
634        return_logprobs: bool,
635        rng: Arc<Mutex<Isaac64Rng>>,
636    ) -> Result<Logprobs> {
637        let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
638
639        if top_k > 0 {
640            // Clamp smaller probabilities to zero.
641            for (index, val) in argsort_indices.iter().enumerate() {
642                if index >= top_k as usize {
643                    probs[*val as usize] = 0.0;
644                }
645            }
646        }
647
648        if top_p <= 0.0 || top_p >= 1.0 {
649            return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
650        }
651
652        // TOP P
653
654        // top-p sampling (or "nucleus sampling") samples from the smallest set of
655        // tokens that exceed probability top_p. This way we never sample tokens that
656        // have very low probabilities and are less likely to go "off the rails".
657
658        // Clamp smaller probabilities to zero.
659        let mut cumsum = 0.;
660        for index in &argsort_indices {
661            if cumsum >= top_p {
662                probs[*index as usize] = 0.0;
663            } else {
664                cumsum += probs[*index as usize];
665            }
666        }
667
668        if min_p <= 0.0 || min_p >= 1.0 {
669            return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
670        }
671
672        let max_p = probs[argsort_indices[0] as usize];
673
674        // MIN P
675
676        // min-p sampling samples from the tokens whose prob are greater than
677        // (max prob of token in dist) * min_p
678
679        // Clamp smaller probabilities to zero.
680        for index in &argsort_indices {
681            if max_p * min_p >= probs[*index as usize] {
682                probs[*index as usize] = 0.0;
683            }
684        }
685
686        // Sample with clamped probabilities.
687        self.sample_multinomial(probs, argsort_indices, return_logprobs, rng)
688    }
689
690    fn apply_penalties(&self, mut logits: Vec<f32>, context: &[u32]) -> Result<Tensor> {
691        if context.is_empty() {
692            candle_core::bail!("Penalty context is empty, this should not happen.");
693        }
694
695        // Dry penalty
696        self.apply_dry_penalty(&mut logits, context)?;
697
698        // Frequency and Presence penalty
699        self.apply_freq_presc_penalty(&mut logits, context)?;
700
701        let vocab_size = logits.len();
702        Tensor::from_vec(logits, vocab_size, &Device::Cpu)
703    }
704
705    fn apply_freq_presc_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
706        if self.frequency_penalty.is_some() || self.presence_penalty.is_some() {
707            let frequency_penalty = self.frequency_penalty.unwrap_or(0.);
708            let presence_penalty = self.presence_penalty.unwrap_or(0.);
709
710            //mu[j] -> mu[j] - c[j] * alpha_frequency - float(c[j] > 0) * alpha_presence
711
712            let mut counts = vec![0.0f32; logits.len()];
713            for ctx in context.iter() {
714                // Llama 3.2 uses a hack triggering this error... we wouldn't want a weight on it anyway
715                if *ctx as usize >= logits.len() {
716                    continue;
717                }
718                counts[*ctx as usize] += 1.0;
719            }
720
721            for (token_id, logit) in logits.iter_mut().enumerate() {
722                let count = counts[token_id];
723                *logit = *logit
724                    - count * frequency_penalty
725                    - if count > 0.0 { 1. } else { 0. } * presence_penalty;
726            }
727        }
728        Ok(())
729    }
730
731    fn apply_dry_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
732        if let Some(ref params) = self.dry_params {
733            if params.multiplier == 0. {
734                return Ok(());
735            }
736
737            let match_indices = context
738                .par_iter()
739                .enumerate()
740                .take(context.len() - 1)
741                .filter(|(_i, x)| *context.last().unwrap() == **x)
742                .map(|(i, _)| i)
743                .collect::<Vec<_>>();
744
745            let mut match_lengths = HashMap::new();
746
747            for i in match_indices {
748                let next_token = context[i + 1];
749
750                if params.sequence_breakers.contains(&next_token) {
751                    continue;
752                }
753
754                let mut match_length = 1;
755
756                // Limit match length to avoid quadratic runtime and potential DoS with adversarial inputs.
757                while match_length < 50 {
758                    if match_length > i {
759                        // Start of input
760                        break;
761                    }
762
763                    let j = i - match_length;
764
765                    let prev_tok = context[context.len() - (match_length + 1)];
766                    if context[j] != prev_tok {
767                        // Start of match reached
768                        break;
769                    }
770
771                    if params.sequence_breakers.contains(&prev_tok) {
772                        // Seq breaking tok reached
773                        break;
774                    }
775
776                    match_length += 1;
777                }
778
779                #[allow(clippy::map_entry)]
780                if match_lengths.contains_key(&next_token) {
781                    match_lengths.insert(next_token, match_length.max(match_lengths[&next_token]));
782                } else {
783                    match_lengths.insert(next_token, match_length);
784                }
785            }
786
787            // Actually apply penalties
788            for (tok, match_len) in match_lengths {
789                if match_len >= params.allowed_length {
790                    // Llama 3.2 uses a hack triggering this error... we wouldn't want a weight on it anyway
791                    if tok as usize >= logits.len() {
792                        continue;
793                    }
794                    let penalty = params.multiplier
795                        * params.base.powf((match_len - params.allowed_length) as f32);
796                    logits[tok as usize] -= penalty;
797                }
798            }
799        }
800        Ok(())
801    }
802
803    #[allow(unused)]
804    /// Sample the provided tokens.
805    ///
806    /// If the temperature is `None`, argmax sampling is used. Otherwise, the selected sampling is used.
807    /// With `top-p` sampling, if the `top-p` value is `<= 0.0` or `>= 1.0`, multinomial sampling is used.
808    pub fn sample(
809        &self,
810        logits: Tensor,
811        context: &[u32],
812        return_logprobs: bool,
813        rng: Arc<Mutex<Isaac64Rng>>,
814        sample_speculative: bool,
815        multiple_sequences: bool,
816    ) -> Result<Logprobs> {
817        if cfg!(feature = "metal") && !multiple_sequences {
818            return self.sample_fast(
819                logits,
820                context,
821                return_logprobs,
822                self.top_k,
823                self.top_p,
824                self.min_p,
825            );
826        }
827
828        let logits = logits.to_vec1()?;
829        let mut logits = self.apply_penalties(logits, context)?;
830        for processor in &self.logits_processors {
831            logits = processor.apply(&logits, context)?;
832        }
833        let next_token = if sample_speculative {
834            match self.temperature {
835                None => self.sample_speculative_top_kp_min_p(
836                    logits,
837                    return_logprobs,
838                    self.top_k,
839                    self.top_p as f32,
840                    self.min_p as f32,
841                )?,
842                Some(temperature) => {
843                    let logits = (&logits / temperature)?;
844                    let probs = candle_nn::ops::softmax_last_dim(&logits)?;
845
846                    self.sample_speculative_top_kp_min_p(
847                        probs,
848                        return_logprobs,
849                        self.top_k,
850                        self.top_p as f32,
851                        self.min_p as f32,
852                    )?
853                }
854            }
855        } else {
856            match self.temperature {
857                None => self.sample_argmax(logits, return_logprobs)?,
858                Some(temperature) => {
859                    let logits = (&logits / temperature)?;
860                    let logits = candle_nn::ops::softmax_last_dim(&logits)?;
861                    let mut probs: Vec<f32> = logits.to_vec1()?;
862
863                    self.sample_top_kp_min_p(
864                        &mut probs,
865                        &logits,
866                        self.top_k,
867                        self.top_p as f32,
868                        self.min_p as f32,
869                        return_logprobs,
870                        rng,
871                    )?
872                }
873            }
874        };
875        Ok(next_token)
876    }
877}
878
// Compile the test module only for test builds.
#[cfg(test)]
mod tests {
    use super::{Logprobs, Sampler};
    use candle_core::{Device, Tensor};
    use rand::SeedableRng;
    use rand_isaac::Isaac64Rng;
    use std::sync::{Arc, Mutex};

    /// Samples once, without a temperature, from the ascending logits `0..1024` using the full
    /// range as context.
    fn sample_ascending_logits(sample_speculative: bool) -> Logprobs {
        let sampler =
            Sampler::new(None, 10, None, None, None, None, 32, 0.1, 0.05, vec![]).unwrap();
        let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
        let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
        sampler
            .sample(
                logits,
                &(0..1024).collect::<Vec<_>>(),
                false,
                rng,
                sample_speculative,
                false,
            )
            .unwrap()
    }

    /// The largest logit (token 1023) must win, with no top logprobs requested.
    fn assert_picks_largest_logit(res: &Logprobs) {
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }

    #[test]
    fn test_argmax() {
        assert_picks_largest_logit(&sample_ascending_logits(false));
    }

    #[test]
    fn test_gumbel_speculative() {
        assert_picks_largest_logit(&sample_ascending_logits(true));
    }
}
935}