1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{
4 collections::{HashMap, HashSet},
5 sync::{Arc, LazyLock, Mutex},
6};
7
8use candle_core::{DType, Device, Error, Result, Tensor, D};
9use mistralrs_quant::{CumSumOp, SortOp};
10#[cfg(feature = "pyo3_macros")]
11use pyo3::pyclass;
12
13use rand::distr::{weighted::WeightedIndex, Distribution};
14use rand_isaac::Isaac64Rng;
15use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
16use serde::{Deserialize, Serialize};
17use tokenizers::Tokenizer;
18
/// Default DRY sequence breakers: newline, colon, double quote and asterisk.
static DRY_SEQUENCE_BREAKERS: LazyLock<Vec<String>> = LazyLock::new(|| {
    ["\n", ":", "\"", "*"]
        .into_iter()
        .map(str::to_owned)
        .collect()
});
21
22#[derive(Clone, Debug, Serialize, Deserialize)]
23pub enum StopTokens {
25 Seqs(Vec<String>),
26 Ids(Vec<u32>),
27}
28
29#[derive(Clone, Debug, Serialize, Deserialize)]
30pub struct SamplingParams {
32 pub temperature: Option<f64>,
33 pub top_k: Option<usize>,
34 pub top_p: Option<f64>,
35 pub min_p: Option<f64>,
36 pub top_n_logprobs: usize,
37 pub frequency_penalty: Option<f32>,
38 pub presence_penalty: Option<f32>,
39 pub repetition_penalty: Option<f32>,
40 pub stop_toks: Option<StopTokens>,
41 pub max_len: Option<usize>,
42 pub logits_bias: Option<HashMap<u32, f32>>,
43 pub n_choices: usize,
44 pub dry_params: Option<DrySamplingParams>,
45}
46
47impl SamplingParams {
48 pub fn deterministic() -> Self {
53 Self {
54 temperature: None,
55 top_k: Some(1),
56 top_p: None,
57 min_p: None,
58 top_n_logprobs: 0,
59 frequency_penalty: None,
60 presence_penalty: None,
61 repetition_penalty: None,
62 stop_toks: None,
63 max_len: None,
64 logits_bias: None,
65 n_choices: 1,
66 dry_params: None,
67 }
68 }
69}
70
71#[derive(Clone, Debug, Serialize, Deserialize)]
72pub struct DrySamplingParams {
73 pub sequence_breakers: Vec<String>,
74 pub multiplier: f32,
75 pub base: f32,
76 pub allowed_length: usize,
77}
78
79impl DrySamplingParams {
80 pub fn new_with_defaults(
81 multiplier: f32,
82 sequence_breakers: Option<Vec<String>>,
83 base: Option<f32>,
84 allowed_length: Option<usize>,
85 ) -> anyhow::Result<Self> {
86 Ok(Self {
87 base: base.unwrap_or(1.75),
88 allowed_length: allowed_length.unwrap_or(2),
89 sequence_breakers: sequence_breakers.unwrap_or(DRY_SEQUENCE_BREAKERS.clone()),
90 multiplier,
91 })
92 }
93}
94
95impl Default for DrySamplingParams {
96 fn default() -> Self {
97 Self {
98 multiplier: 0.0,
99 base: 1.75,
100 allowed_length: 2,
101 sequence_breakers: DRY_SEQUENCE_BREAKERS.clone(),
102 }
103 }
104}
105
/// DRY settings with sequence breakers resolved to token ids.
#[derive(Clone, Debug)]
struct DrySamplingParamsInner {
    /// Token ids that end a repeated-sequence match.
    pub sequence_breakers: HashSet<u32>,
    /// Penalty scale; `0.0` disables DRY.
    pub multiplier: f32,
    /// Exponential growth base of the penalty.
    pub base: f32,
    /// Shortest match length that is penalised.
    pub allowed_length: usize,
}
113
114impl DrySamplingParamsInner {
115 pub fn from(other: DrySamplingParams, tokenizer: &Tokenizer) -> anyhow::Result<Self> {
116 Ok(Self {
117 base: other.base,
118 allowed_length: other.allowed_length,
119 sequence_breakers: HashSet::from_iter(
120 other
121 .sequence_breakers
122 .into_iter()
123 .map(|breaker| {
124 tokenizer
125 .encode_fast(["a", &breaker].concat(), true)
131 .map_err(anyhow::Error::msg)
132 .map(|enc| {
133 let ids = enc.get_ids();
134 if !ids.is_empty() {
135 Some(ids[ids.len() - 1])
136 } else {
137 None
138 }
139 })
140 })
141 .collect::<anyhow::Result<Vec<_>>>()?
142 .into_iter()
143 .flatten()
144 .collect::<Vec<_>>(),
145 ),
146 multiplier: other.multiplier,
147 })
148 }
149}
150
151pub trait CustomLogitsProcessor: Send + Sync {
171 fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor>;
173}
174
175impl<T: Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync> CustomLogitsProcessor for T {
176 fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor> {
177 self(logits, context)
178 }
179}
180
181#[derive(Clone)]
183pub struct Sampler {
184 temperature: Option<f64>,
185 top_n_logprobs: usize,
186 tokenizer: Option<Arc<Tokenizer>>,
187 frequency_penalty: Option<f32>,
188 presence_penalty: Option<f32>,
189 repetition_penalty: Option<f32>,
190 dry_params: Option<DrySamplingParamsInner>,
191 top_k: i64,
192 top_p: f64,
193 min_p: f64,
194 logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
195 gumbel_cache: Arc<Mutex<Option<Tensor>>>,
197}
198
199#[cfg_attr(feature = "pyo3_macros", pyclass)]
200#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
201#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
202pub struct TopLogprob {
204 pub token: u32,
205 pub logprob: f32,
206 pub bytes: Option<String>,
207}
208
209#[derive(Debug, Clone, Serialize, Deserialize)]
210pub struct Logprobs {
211 pub token: u32,
212 pub logprob: f32,
213 pub bytes: Option<String>,
214 pub top_logprobs: Option<Vec<TopLogprob>>,
215}
216
/// Orders `(token, prob)` pairs by descending probability.
///
/// Incomparable values (NaN) compare as equal.
#[inline]
fn cmp_desc_by_prob(a: &(u32, f32), b: &(u32, f32)) -> std::cmp::Ordering {
    let (_, lhs) = *a;
    let (_, rhs) = *b;
    lhs.partial_cmp(&rhs)
        .map_or(std::cmp::Ordering::Equal, std::cmp::Ordering::reverse)
}
222
223fn partial_sort_top_k(probs: &mut [f32], k: usize, zero_rest: bool) -> Vec<(u32, f32)> {
229 let n = probs.len();
230 if n == 0 || k == 0 {
231 return Vec::new();
232 }
233
234 let mut idx_probs: Vec<(u32, f32)> = (0..n as u32).map(|i| (i, probs[i as usize])).collect();
236
237 let k = k.min(n);
238
239 if k < n {
240 idx_probs.select_nth_unstable_by(k - 1, cmp_desc_by_prob);
244
245 if zero_rest {
246 for (idx, _) in idx_probs[k..].iter() {
248 probs[*idx as usize] = 0.0;
249 }
250 }
251
252 idx_probs.truncate(k);
254 }
255
256 idx_probs.sort_unstable_by(cmp_desc_by_prob);
258
259 idx_probs
260}
261
/// Index of the largest value, or `0` for an empty slice.
///
/// Among equal maxima the last one wins. A NaN compares as equal to anything, so it never
/// keeps the current best from being replaced.
#[inline]
fn argmax_f32(values: &[f32]) -> u32 {
    let mut best: Option<(usize, f32)> = None;
    for (idx, &val) in values.iter().enumerate() {
        match best {
            // Keep the current best only if it is strictly greater.
            Some((_, cur)) if cur.partial_cmp(&val) == Some(std::cmp::Ordering::Greater) => {}
            _ => best = Some((idx, val)),
        }
    }
    best.map_or(0, |(idx, _)| idx as u32)
}
272
273impl Sampler {
274 #[allow(clippy::too_many_arguments)]
275 pub fn new(
276 temperature: Option<f64>,
277 top_n_logprobs: usize,
278 tokenizer: Option<Arc<Tokenizer>>,
279 frequency_penalty: Option<f32>,
280 presence_penalty: Option<f32>,
281 repetition_penalty: Option<f32>,
282 dry_params: Option<DrySamplingParams>,
283 top_k: i64,
284 top_p: f64,
285 min_p: f64,
286 logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
287 ) -> anyhow::Result<Self> {
288 let temperature = if temperature.is_none_or(|v| v < 1e-7) {
289 None
290 } else {
291 temperature
292 };
293 let dry_params = if let Some(ref tokenizer) = tokenizer {
294 dry_params.map(|params| DrySamplingParamsInner::from(params, tokenizer))
295 } else {
296 None
297 };
298 let dry_params = match dry_params {
299 Some(fallible) => Some(fallible?),
300 None => None,
301 };
302 Ok(Self {
303 temperature,
304 top_n_logprobs,
305 tokenizer,
306 frequency_penalty,
307 presence_penalty,
308 repetition_penalty,
309 dry_params,
310 top_k,
311 top_p,
312 min_p,
313 logits_processors,
314 gumbel_cache: Arc::new(Mutex::new(None)),
315 })
316 }
317
318 fn get_top_logprobs(&self, probs: &[f32]) -> Result<Vec<TopLogprob>> {
319 let k = self.top_n_logprobs.min(probs.len());
320 if k == 0 {
321 return Ok(Vec::new());
322 }
323
324 let mut probs_copy = probs.to_vec();
326 let top_k = partial_sort_top_k(&mut probs_copy, k, false);
327
328 let mut result = Vec::with_capacity(k);
330 if let Some(tokenizer) = &self.tokenizer {
331 for (token, prob) in top_k {
332 let decoded = tokenizer
333 .decode(&[token], false)
334 .map_err(|e| Error::Msg(e.to_string()))?;
335 result.push(TopLogprob {
336 token,
337 logprob: prob.log(10.0),
338 bytes: Some(decoded),
339 });
340 }
341 } else {
342 for (token, prob) in top_k {
343 result.push(TopLogprob {
344 token,
345 logprob: prob.log(10.0),
346 bytes: None,
347 });
348 }
349 }
350 Ok(result)
351 }
352
353 fn sample_argmax(&self, logits: Tensor, return_logprobs: bool) -> Result<Logprobs> {
354 let probs: Vec<f32> = logits.to_vec1()?;
355 let next_token = argmax_f32(&probs);
356 let logprob = probs[next_token as usize].log(10.0);
357
358 let top_logprobs = if return_logprobs {
359 Some(self.get_top_logprobs(&probs)?)
360 } else {
361 None
362 };
363
364 let bytes = if let Some(tokenizer) = &self.tokenizer {
365 Some(
366 tokenizer
367 .decode(&[next_token], false)
368 .map_err(|x| Error::Msg(x.to_string()))?,
369 )
370 } else {
371 None
372 };
373
374 Ok(Logprobs {
375 token: next_token,
376 logprob,
377 top_logprobs,
378 bytes,
379 })
380 }
381
382 #[allow(unused)]
383 fn sample_fast(
384 &self,
385 logits: Tensor,
386 context: &[u32],
387 return_logprobs: bool,
388 top_k: i64,
389 top_p: f64,
390 min_p: f64,
391 ) -> Result<Logprobs> {
392 let mut probs = logits.to_dtype(DType::F32)?;
393
394 for processor in &self.logits_processors {
395 probs = processor.apply(&probs, context)?;
396 }
397
398 let context = Tensor::new(context, logits.device())?;
399 let mut counts = logits.zeros_like()?;
400 counts = counts.scatter_add(
401 &context,
402 &context.ones_like()?.to_dtype(counts.dtype())?,
403 D::Minus1,
404 )?;
405
406 let presence = counts
407 .gt(0.)?
408 .where_cond(&counts.ones_like()?, &counts.zeros_like()?)?;
409
410 match self.frequency_penalty {
411 Some(freq_penalty) if freq_penalty != 0. => {
412 probs = (probs - (freq_penalty as f64 * counts)?)?;
413 }
414 _ => (),
415 }
416
417 match self.presence_penalty {
418 Some(pres_penalty) if pres_penalty != 0. => {
419 probs = (probs - (pres_penalty as f64 * &presence)?)?;
420 }
421 _ => (),
422 }
423
424 match self.repetition_penalty {
425 Some(rep_penalty) if rep_penalty != 1. => {
426 let pos_mask = probs.gt(0.)?;
427 let scaled_pos = (&probs / (rep_penalty as f64))?;
428 let scaled_neg = (&probs * (rep_penalty as f64))?;
429 let modified = pos_mask.where_cond(&scaled_pos, &scaled_neg)?;
430
431 let pres_mask = presence.gt(0.)?;
432 probs = pres_mask.where_cond(&modified, &probs)?;
433 }
434 _ => (),
435 }
436
437 probs = candle_nn::ops::softmax_last_dim(&(probs / self.temperature.unwrap_or(1.))?)?;
438
439 if top_k > 0 {
441 let sorted_values = probs.fast_sort_asc(D::Minus1)?;
442 let topk_values = sorted_values.narrow(
443 D::Minus1,
444 sorted_values.dim(D::Minus1)? - top_k as usize,
445 top_k as usize,
446 )?;
447
448 let threshold = topk_values.get_on_dim(D::Minus1, 0)?.unsqueeze(0)?;
450 let mask_topk = probs.broadcast_ge(&threshold)?;
451 probs = mask_topk.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
452 }
453
454 if top_p > 0.0 && top_p < 1.0 {
456 let sorted_probs = probs.fast_sort_asc(D::Minus1)?;
457
458 let cumsum = sorted_probs.fast_cumsum(D::Minus1)?;
459
460 let mask_topp = cumsum.le(top_p)?;
461
462 let masked_sorted =
463 mask_topp.where_cond(&sorted_probs, &Tensor::zeros_like(&sorted_probs)?)?;
464
465 let threshold = masked_sorted.max(D::Minus1)?;
466 let threshold = threshold.unsqueeze(D::Minus1)?;
467 let mask_full = probs.broadcast_ge(&threshold)?;
468 probs = mask_full.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
469 }
470
471 if min_p > 0.0 && min_p < 1.0 {
473 let max_vals = probs.max(D::Minus1)?;
474 let threshold_min = (max_vals.unsqueeze(D::Minus1)? * min_p)?;
475 let mask_minp = probs.broadcast_gt(&threshold_min)?;
476 probs = mask_minp.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
477 }
478
479 let log_probs = probs.log()?;
481 let gumbel = {
483 let mut guard = self.gumbel_cache.lock().unwrap();
484 if guard.is_none() {
485 let uniform = Tensor::rand(0f32, 1f32, log_probs.shape(), log_probs.device())?;
486 let noise = uniform
487 .clamp(1e-20, 1.0)?
488 .log()? .neg()? .log()? .neg()?; *guard = Some(noise);
493 }
494 guard.as_ref().unwrap().clone()
495 };
496
497 let gumbel_logits = (&log_probs + &gumbel)?;
498 let next_token = gumbel_logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
499
500 let (top_logprobs, logprob) = if return_logprobs {
502 let k = self.top_n_logprobs;
503
504 let sorted_values = probs.fast_sort_asc(D::Minus1)?;
505 let topk_values = sorted_values
506 .narrow(
507 D::Minus1,
508 sorted_values.dim(D::Minus1)? - top_k as usize,
509 top_k as usize,
510 )?
511 .to_vec1::<f32>()?;
512
513 let sorted_idxs = probs.fast_argsort_asc(D::Minus1)?;
514 let topk_idxs = sorted_idxs
515 .narrow(
516 D::Minus1,
517 sorted_values.dim(D::Minus1)? - top_k as usize,
518 top_k as usize,
519 )?
520 .to_vec1::<u32>()?;
521
522 let mut result = Vec::with_capacity(k);
523 if let Some(tokenizer) = &self.tokenizer {
524 for (prob, token) in topk_values.iter().zip(topk_idxs) {
525 let decoded = tokenizer
526 .decode(&[token], false)
527 .map_err(|e| Error::Msg(e.to_string()))?;
528 result.push(TopLogprob {
529 token,
530 logprob: prob.log(10.0),
531 bytes: Some(decoded),
532 });
533 }
534 } else {
535 for (prob, token) in topk_values.iter().zip(topk_idxs) {
536 result.push(TopLogprob {
537 token,
538 logprob: prob.log(10.0),
539 bytes: None,
540 });
541 }
542 }
543
544 let logprob = result.last().map(|res| res.logprob).unwrap_or(1.);
545
546 (Some(result), logprob)
547 } else {
548 (None, 1.)
549 };
550
551 let bytes = if let Some(tokenizer) = &self.tokenizer {
552 Some(
553 tokenizer
554 .decode(&[next_token], false)
555 .map_err(|x| Error::Msg(x.to_string()))?,
556 )
557 } else {
558 None
559 };
560
561 Ok(Logprobs {
562 token: next_token,
563 logprob,
564 top_logprobs,
565 bytes,
566 })
567 }
568 fn sample_speculative_top_kp_min_p(
569 &self,
570 logits: Tensor,
571 return_logprobs: bool,
572 top_k: i64,
573 top_p: f32,
574 min_p: f32,
575 ) -> Result<Logprobs> {
576 let mut probs: Vec<f32> = logits.to_vec1()?;
577
578 let k = if top_k > 0 {
580 top_k as usize
581 } else {
582 probs.len()
583 };
584
585 let idx_probs = partial_sort_top_k(&mut probs, k, true);
587
588 let mut cumsum = 0.;
595 for (index, prob) in &idx_probs {
596 if cumsum >= top_p {
597 probs[*index as usize] = 0.0;
598 } else {
599 cumsum += prob;
600 }
601 }
602
603 let max_p = idx_probs.first().map(|(_, p)| *p).unwrap_or(0.0);
605
606 let min_p_threshold = max_p * min_p;
612 for (index, prob) in &idx_probs {
613 if min_p_threshold >= *prob {
614 probs[*index as usize] = 0.0;
615 }
616 }
617
618 let next_token = argmax_f32(&probs);
620 let logprob = probs[next_token as usize].log(10.0);
621
622 let top_logprobs = if return_logprobs {
623 Some(self.get_top_logprobs(&probs)?)
624 } else {
625 None
626 };
627
628 let bytes = if let Some(tokenizer) = &self.tokenizer {
629 Some(
630 tokenizer
631 .decode(&[next_token], false)
632 .map_err(|x| Error::Msg(x.to_string()))?,
633 )
634 } else {
635 None
636 };
637
638 Ok(Logprobs {
639 token: next_token,
640 logprob,
641 top_logprobs,
642 bytes,
643 })
644 }
645
646 fn sample_multinomial(
647 &self,
648 probs: &[f32],
649 return_logprobs: bool,
650 rng: Arc<Mutex<Isaac64Rng>>,
651 ) -> Result<Logprobs> {
652 let distr = WeightedIndex::new(probs).map_err(Error::wrap)?;
653
654 let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
655 let next_token = distr.sample(&mut mut_ref_rng); let logprob = probs[next_token].log(10.0);
657
658 let top_logprobs = if return_logprobs {
659 Some(self.get_top_logprobs(probs)?)
660 } else {
661 None
662 };
663
664 let bytes = if let Some(tokenizer) = &self.tokenizer {
665 Some(
666 tokenizer
667 .decode(&[next_token.try_into().unwrap()], false)
668 .map_err(|x| Error::Msg(x.to_string()))?,
669 )
670 } else {
671 None
672 };
673
674 Ok(Logprobs {
675 token: next_token as u32,
676 logprob,
677 top_logprobs,
678 bytes,
679 })
680 }
681
682 #[allow(clippy::too_many_arguments)]
683 fn sample_top_kp_min_p(
684 &self,
685 probs: &mut [f32],
686 top_k: i64,
687 top_p: f32,
688 min_p: f32,
689 return_logprobs: bool,
690 rng: Arc<Mutex<Isaac64Rng>>,
691 ) -> Result<Logprobs> {
692 let k = if top_k > 0 {
694 top_k as usize
695 } else {
696 probs.len()
697 };
698
699 let idx_probs = partial_sort_top_k(probs, k, true);
701
702 if top_p <= 0.0 || top_p >= 1.0 {
703 return self.sample_multinomial(probs, return_logprobs, rng);
704 }
705
706 let mut cumsum = 0.;
714 for (index, prob) in &idx_probs {
715 if cumsum >= top_p {
716 probs[*index as usize] = 0.0;
717 } else {
718 cumsum += prob;
719 }
720 }
721
722 if min_p <= 0.0 || min_p >= 1.0 {
723 return self.sample_multinomial(probs, return_logprobs, rng);
724 }
725
726 let max_p = idx_probs.first().map(|(_, p)| *p).unwrap_or(0.0);
728
729 let min_p_threshold = max_p * min_p;
736 for (index, prob) in &idx_probs {
737 if min_p_threshold >= *prob {
738 probs[*index as usize] = 0.0;
739 }
740 }
741
742 self.sample_multinomial(probs, return_logprobs, rng)
744 }
745
746 fn apply_penalties(&self, mut logits: Vec<f32>, context: &[u32]) -> Result<Tensor> {
747 if context.is_empty() {
748 candle_core::bail!("Penalty context is empty, this should not happen.");
749 }
750
751 self.apply_dry_penalty(&mut logits, context)?;
753
754 self.apply_freq_pres_rep_penalty(&mut logits, context)?;
756
757 let vocab_size = logits.len();
758 Tensor::from_vec(logits, vocab_size, &Device::Cpu)
759 }
760
761 fn apply_freq_pres_rep_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
762 if self.frequency_penalty.is_some()
763 || self.presence_penalty.is_some()
764 || self.repetition_penalty.is_some()
765 {
766 let frequency_penalty = self.frequency_penalty.unwrap_or(0.);
767 let presence_penalty = self.presence_penalty.unwrap_or(0.);
768 let repetition_penalty = self.repetition_penalty.unwrap_or(1.);
769
770 let mut counts = vec![0.0f32; logits.len()];
773 for ctx in context.iter() {
774 if *ctx as usize >= logits.len() {
776 continue;
777 }
778 counts[*ctx as usize] += 1.0;
779 }
780
781 for (token_id, logit) in logits.iter_mut().enumerate() {
782 let count = counts[token_id];
783 *logit = *logit
784 - count * frequency_penalty
785 - if count > 0.0 { 1. } else { 0. } * presence_penalty;
786
787 if repetition_penalty != 1.0 && count > 0.0 {
788 if *logit > 0.0 {
789 *logit /= repetition_penalty;
790 } else {
791 *logit *= repetition_penalty;
792 }
793 }
794 }
795 }
796 Ok(())
797 }
798
799 const DRY_PENALTY_PAR_THRESHOLD: usize = 1024;
802
803 fn apply_dry_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
804 if let Some(ref params) = self.dry_params {
805 if params.multiplier == 0. {
806 return Ok(());
807 }
808
809 let last_token = *context.last().unwrap();
810
811 let match_indices: Vec<usize> = if context.len() > Self::DRY_PENALTY_PAR_THRESHOLD {
813 context
814 .par_iter()
815 .enumerate()
816 .take(context.len() - 1)
817 .filter(|(_i, x)| last_token == **x)
818 .map(|(i, _)| i)
819 .collect()
820 } else {
821 context
822 .iter()
823 .enumerate()
824 .take(context.len() - 1)
825 .filter(|(_i, x)| last_token == **x)
826 .map(|(i, _)| i)
827 .collect()
828 };
829
830 let mut match_lengths = HashMap::new();
831
832 for i in match_indices {
833 let next_token = context[i + 1];
834
835 if params.sequence_breakers.contains(&next_token) {
836 continue;
837 }
838
839 let mut match_length = 1;
840
841 while match_length < 50 {
843 if match_length > i {
844 break;
846 }
847
848 let j = i - match_length;
849
850 let prev_tok = context[context.len() - (match_length + 1)];
851 if context[j] != prev_tok {
852 break;
854 }
855
856 if params.sequence_breakers.contains(&prev_tok) {
857 break;
859 }
860
861 match_length += 1;
862 }
863
864 #[allow(clippy::map_entry)]
865 if match_lengths.contains_key(&next_token) {
866 match_lengths.insert(next_token, match_length.max(match_lengths[&next_token]));
867 } else {
868 match_lengths.insert(next_token, match_length);
869 }
870 }
871
872 for (tok, match_len) in match_lengths {
874 if match_len >= params.allowed_length {
875 if tok as usize >= logits.len() {
877 continue;
878 }
879 let penalty = params.multiplier
880 * params.base.powf((match_len - params.allowed_length) as f32);
881 logits[tok as usize] -= penalty;
882 }
883 }
884 }
885 Ok(())
886 }
887
888 #[allow(unused)]
889 pub fn sample(
894 &self,
895 logits: Tensor,
896 context: &[u32],
897 return_logprobs: bool,
898 rng: Arc<Mutex<Isaac64Rng>>,
899 sample_speculative: bool,
900 multiple_sequences: bool,
901 ) -> Result<Logprobs> {
902 let logits = logits.to_vec1()?;
914 let mut logits = self.apply_penalties(logits, context)?;
915 for processor in &self.logits_processors {
916 logits = processor.apply(&logits, context)?;
917 }
918 let next_token = if sample_speculative {
919 match self.temperature {
920 None => self.sample_speculative_top_kp_min_p(
921 logits,
922 return_logprobs,
923 self.top_k,
924 self.top_p as f32,
925 self.min_p as f32,
926 )?,
927 Some(temperature) => {
928 let logits = (&logits / temperature)?;
929 let probs = candle_nn::ops::softmax_last_dim(&logits)?;
930
931 self.sample_speculative_top_kp_min_p(
932 probs,
933 return_logprobs,
934 self.top_k,
935 self.top_p as f32,
936 self.min_p as f32,
937 )?
938 }
939 }
940 } else {
941 match self.temperature {
942 None => self.sample_argmax(logits, return_logprobs)?,
943 Some(temperature) => {
944 let logits = (&logits / temperature)?;
945 let probs = candle_nn::ops::softmax_last_dim(&logits)?;
946 let mut probs: Vec<f32> = probs.to_vec1()?;
947
948 self.sample_top_kp_min_p(
949 &mut probs,
950 self.top_k,
951 self.top_p as f32,
952 self.min_p as f32,
953 return_logprobs,
954 rng,
955 )?
956 }
957 }
958 };
959 Ok(next_token)
960 }
961}
962
#[cfg(test)]
mod tests {
    use super::{Logprobs, Sampler};
    use candle_core::{Device, Tensor};
    use rand::SeedableRng;
    use rand_isaac::Isaac64Rng;
    use std::sync::{Arc, Mutex};

    /// Samples from the logits `0, 1, …, 1023` without requesting logprobs.
    ///
    /// The sampler has no temperature, top-k 32, top-p 0.1, min-p 0.05, and no tokenizer or
    /// penalties.
    fn sample_ramp(sample_speculative: bool) -> Logprobs {
        let sampler = Sampler::new(
            None,
            10,
            None,
            None,
            None,
            None,
            None,
            32,
            0.1,
            0.05,
            vec![],
        )
        .unwrap();
        let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
        let context: Vec<u32> = (0..1024).collect();
        let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
        sampler
            .sample(logits, &context, false, rng, sample_speculative, false)
            .unwrap()
    }

    /// The largest logit must win, and `logprob` is the base-10 log of that raw value.
    fn assert_picked_last(res: &Logprobs) {
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        // Tolerance: an f32 `log(10.0)` and a rounded f64 `log10` may differ by an ulp.
        let expected = 1023f64.log10() as f32;
        assert!(
            (res.logprob - expected).abs() < 1e-5,
            "logprob {} != {expected}",
            res.logprob
        );
    }

    #[test]
    fn test_argmax() {
        assert_picked_last(&sample_ramp(false));
    }

    #[test]
    fn test_gumbel_speculative() {
        assert_picked_last(&sample_ramp(true));
    }
}