1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{
4 collections::{HashMap, HashSet},
5 sync::{Arc, Mutex},
6};
7
8use candle_core::{DType, Device, Error, Result, Tensor, D};
9use mistralrs_quant::{CumSumOp, SortOp};
10#[cfg(feature = "pyo3_macros")]
11use pyo3::pyclass;
12
13use once_cell::sync::Lazy;
14use rand::distr::{weighted::WeightedIndex, Distribution};
15use rand_isaac::Isaac64Rng;
16use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
17use serde::{Deserialize, Serialize};
18use tokenizers::Tokenizer;
19
20static DRY_SEQUENCE_BREAKERS: Lazy<Vec<String>> =
21 Lazy::new(|| ["\n", ":", "\"", "*"].map(String::from).to_vec());
22
23#[derive(Clone, Debug, Serialize, Deserialize)]
24pub enum StopTokens {
26 Seqs(Vec<String>),
27 Ids(Vec<u32>),
28}
29
30#[derive(Clone, Debug, Serialize, Deserialize)]
31pub struct SamplingParams {
33 pub temperature: Option<f64>,
34 pub top_k: Option<usize>,
35 pub top_p: Option<f64>,
36 pub min_p: Option<f64>,
37 pub top_n_logprobs: usize,
38 pub frequency_penalty: Option<f32>,
39 pub presence_penalty: Option<f32>,
40 pub stop_toks: Option<StopTokens>,
41 pub max_len: Option<usize>,
42 pub logits_bias: Option<HashMap<u32, f32>>,
43 pub n_choices: usize,
44 pub dry_params: Option<DrySamplingParams>,
45}
46
47impl SamplingParams {
48 pub fn deterministic() -> Self {
53 Self {
54 temperature: None,
55 top_k: Some(1),
56 top_p: None,
57 min_p: None,
58 top_n_logprobs: 0,
59 frequency_penalty: None,
60 presence_penalty: None,
61 stop_toks: None,
62 max_len: None,
63 logits_bias: None,
64 n_choices: 1,
65 dry_params: None,
66 }
67 }
68}
69
70#[derive(Clone, Debug, Serialize, Deserialize)]
71pub struct DrySamplingParams {
72 pub sequence_breakers: Vec<String>,
73 pub multiplier: f32,
74 pub base: f32,
75 pub allowed_length: usize,
76}
77
78impl DrySamplingParams {
79 pub fn new_with_defaults(
80 multiplier: f32,
81 sequence_breakers: Option<Vec<String>>,
82 base: Option<f32>,
83 allowed_length: Option<usize>,
84 ) -> anyhow::Result<Self> {
85 Ok(Self {
86 base: base.unwrap_or(1.75),
87 allowed_length: allowed_length.unwrap_or(2),
88 sequence_breakers: sequence_breakers.unwrap_or(DRY_SEQUENCE_BREAKERS.clone()),
89 multiplier,
90 })
91 }
92}
93
94impl Default for DrySamplingParams {
95 fn default() -> Self {
96 Self {
97 multiplier: 0.0,
98 base: 1.75,
99 allowed_length: 2,
100 sequence_breakers: DRY_SEQUENCE_BREAKERS.clone(),
101 }
102 }
103}
104
/// DRY parameters with the sequence breakers resolved to token ids.
#[derive(Debug, Clone)]
struct DrySamplingParamsInner {
    /// Token ids that stop a repetition match.
    pub sequence_breakers: HashSet<u32>,
    /// Overall penalty scale; `0.0` disables DRY.
    pub multiplier: f32,
    /// Exponential growth base of the penalty.
    pub base: f32,
    /// Shortest repeated run that is penalised.
    pub allowed_length: usize,
}
112
113impl DrySamplingParamsInner {
114 pub fn from(other: DrySamplingParams, tokenizer: &Tokenizer) -> anyhow::Result<Self> {
115 Ok(Self {
116 base: other.base,
117 allowed_length: other.allowed_length,
118 sequence_breakers: HashSet::from_iter(
119 other
120 .sequence_breakers
121 .into_iter()
122 .map(|breaker| {
123 tokenizer
124 .encode_fast(["a", &breaker].concat(), true)
130 .map_err(anyhow::Error::msg)
131 .map(|enc| {
132 let ids = enc.get_ids();
133 if !ids.is_empty() {
134 None
135 } else {
136 Some(ids[ids.len() - 1])
137 }
138 })
139 })
140 .collect::<anyhow::Result<Vec<_>>>()?
141 .into_iter()
142 .flatten()
143 .collect::<Vec<_>>(),
144 ),
145 multiplier: other.multiplier,
146 })
147 }
148}
149
150pub trait CustomLogitsProcessor: Send + Sync {
170 fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor>;
172}
173
174impl<T: Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync> CustomLogitsProcessor for T {
175 fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor> {
176 self(logits, context)
177 }
178}
179
180#[derive(Clone)]
182pub struct Sampler {
183 temperature: Option<f64>,
184 top_n_logprobs: usize,
185 tokenizer: Option<Arc<Tokenizer>>,
186 frequency_penalty: Option<f32>,
187 presence_penalty: Option<f32>,
188 dry_params: Option<DrySamplingParamsInner>,
189 top_k: i64,
190 top_p: f64,
191 min_p: f64,
192 logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
193 gumbel_cache: Arc<Mutex<Option<Tensor>>>,
195}
196
197#[cfg_attr(feature = "pyo3_macros", pyclass)]
198#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
199#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
200pub struct TopLogprob {
202 pub token: u32,
203 pub logprob: f32,
204 pub bytes: Option<String>,
205}
206
207#[derive(Debug, Clone, Serialize, Deserialize)]
208pub struct Logprobs {
209 pub token: u32,
210 pub logprob: f32,
211 pub bytes: Option<String>,
212 pub top_logprobs: Option<Vec<TopLogprob>>,
213}
214
215fn argmax_sample_last_dim(logits: &Tensor) -> Result<Tensor> {
216 logits.argmax(D::Minus1)
217}
218
219impl Sampler {
220 #[allow(clippy::too_many_arguments)]
221 pub fn new(
222 temperature: Option<f64>,
223 top_n_logprobs: usize,
224 tokenizer: Option<Arc<Tokenizer>>,
225 frequency_penalty: Option<f32>,
226 presence_penalty: Option<f32>,
227 dry_params: Option<DrySamplingParams>,
228 top_k: i64,
229 top_p: f64,
230 min_p: f64,
231 logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
232 ) -> anyhow::Result<Self> {
233 let temperature = if temperature.is_none_or(|v| v < 1e-7) {
234 None
235 } else {
236 temperature
237 };
238 let dry_params = if let Some(ref tokenizer) = tokenizer {
239 dry_params.map(|params| DrySamplingParamsInner::from(params, tokenizer))
240 } else {
241 None
242 };
243 let dry_params = match dry_params {
244 Some(fallible) => Some(fallible?),
245 None => None,
246 };
247 Ok(Self {
248 temperature,
249 top_n_logprobs,
250 tokenizer,
251 frequency_penalty,
252 presence_penalty,
253 dry_params,
254 top_k,
255 top_p,
256 min_p,
257 logits_processors,
258 gumbel_cache: Arc::new(Mutex::new(None)),
259 })
260 }
261
262 fn get_top_logprobs(&self, probs: &[f32], _argsort_indices: &[u32]) -> Result<Vec<TopLogprob>> {
263 let k = self.top_n_logprobs.min(probs.len());
265 if k == 0 {
266 return Ok(Vec::new());
267 }
268 let mut idx_probs: Vec<(u32, f32)> = (0..probs.len() as u32)
270 .map(|i| (i, probs[i as usize]))
271 .collect();
272 let (top_k_slice, _, _) = idx_probs.select_nth_unstable_by(k, |a, b| {
274 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
275 });
276 let mut top_k: Vec<(u32, f32)> = top_k_slice.to_vec();
278 top_k.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
279 let mut result = Vec::with_capacity(k);
281 if let Some(tokenizer) = &self.tokenizer {
282 for (token, prob) in top_k {
283 let decoded = tokenizer
284 .decode(&[token], false)
285 .map_err(|e| Error::Msg(e.to_string()))?;
286 result.push(TopLogprob {
287 token,
288 logprob: prob.log(10.0),
289 bytes: Some(decoded),
290 });
291 }
292 } else {
293 for (token, prob) in top_k {
294 result.push(TopLogprob {
295 token,
296 logprob: prob.log(10.0),
297 bytes: None,
298 });
299 }
300 }
301 Ok(result)
302 }
303
304 fn sample_argmax(&self, logits: Tensor, return_logprobs: bool) -> Result<Logprobs> {
305 let next_token = logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
306
307 let probs: Vec<f32> = logits.to_vec1()?;
308
309 let argsort_indices = (0..probs.len() as u32).collect::<Vec<_>>();
310 let logprob = probs[next_token as usize].log(10.0);
311
312 let top_logprobs = if return_logprobs {
313 Some(self.get_top_logprobs(&probs, &argsort_indices)?)
314 } else {
315 None
316 };
317
318 let bytes = if let Some(tokenizer) = &self.tokenizer {
319 Some(
320 tokenizer
321 .decode(&[next_token], false)
322 .map_err(|x| Error::Msg(x.to_string()))?,
323 )
324 } else {
325 None
326 };
327
328 Ok(Logprobs {
329 token: next_token,
330 logprob,
331 top_logprobs,
332 bytes,
333 })
334 }
335
336 #[allow(unused)]
337 fn sample_fast(
338 &self,
339 logits: Tensor,
340 context: &[u32],
341 return_logprobs: bool,
342 top_k: i64,
343 top_p: f64,
344 min_p: f64,
345 ) -> Result<Logprobs> {
346 let mut probs = logits.to_dtype(DType::F32)?;
347
348 for processor in &self.logits_processors {
349 probs = processor.apply(&probs, context)?;
350 }
351
352 let context = Tensor::new(context, logits.device())?;
353 let mut counts = logits.zeros_like()?;
354 counts = counts.scatter_add(
355 &context,
356 &context.ones_like()?.to_dtype(counts.dtype())?,
357 D::Minus1,
358 )?;
359
360 let presence = counts
361 .gt(0.)?
362 .where_cond(&counts.ones_like()?, &counts.zeros_like()?)?;
363
364 match self.frequency_penalty {
365 Some(freq_penalty) if freq_penalty != 0. => {
366 probs = (probs - (freq_penalty as f64 * counts)?)?;
367 }
368 _ => (),
369 }
370
371 match self.presence_penalty {
372 Some(pres_penalty) if pres_penalty != 0. => {
373 probs = (probs - (pres_penalty as f64 * presence)?)?;
374 }
375 _ => (),
376 }
377
378 probs = candle_nn::ops::softmax_last_dim(&(probs / self.temperature.unwrap_or(1.))?)?;
379
380 if top_k > 0 {
382 let sorted_values = probs.fast_sort_asc(D::Minus1)?;
383 let topk_values = sorted_values.narrow(
384 D::Minus1,
385 sorted_values.dim(D::Minus1)? - top_k as usize,
386 top_k as usize,
387 )?;
388
389 let threshold = topk_values.get_on_dim(D::Minus1, 0)?.unsqueeze(0)?;
391 let mask_topk = probs.broadcast_ge(&threshold)?;
392 probs = mask_topk.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
393 }
394
395 if top_p > 0.0 && top_p < 1.0 {
397 let sorted_probs = probs.fast_sort_asc(D::Minus1)?;
398
399 let cumsum = sorted_probs.fast_cumsum(D::Minus1)?;
400
401 let mask_topp = cumsum.le(top_p)?;
402
403 let masked_sorted =
404 mask_topp.where_cond(&sorted_probs, &Tensor::zeros_like(&sorted_probs)?)?;
405
406 let threshold = masked_sorted.max(D::Minus1)?;
407 let threshold = threshold.unsqueeze(D::Minus1)?;
408 let mask_full = probs.broadcast_ge(&threshold)?;
409 probs = mask_full.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
410 }
411
412 if min_p > 0.0 && min_p < 1.0 {
414 let max_vals = probs.max(D::Minus1)?;
415 let threshold_min = (max_vals.unsqueeze(D::Minus1)? * min_p)?;
416 let mask_minp = probs.broadcast_gt(&threshold_min)?;
417 probs = mask_minp.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
418 }
419
420 let log_probs = probs.log()?;
422 let gumbel = {
424 let mut guard = self.gumbel_cache.lock().unwrap();
425 if guard.is_none() {
426 let uniform = Tensor::rand(0f32, 1f32, log_probs.shape(), log_probs.device())?;
427 let noise = uniform
428 .clamp(1e-20, 1.0)?
429 .log()? .neg()? .log()? .neg()?; *guard = Some(noise);
434 }
435 guard.as_ref().unwrap().clone()
436 };
437
438 let gumbel_logits = (&log_probs + &gumbel)?;
439 let next_token = gumbel_logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
440
441 let (top_logprobs, logprob) = if return_logprobs {
443 let k = self.top_n_logprobs;
444
445 let sorted_values = probs.fast_sort_asc(D::Minus1)?;
446 let topk_values = sorted_values
447 .narrow(
448 D::Minus1,
449 sorted_values.dim(D::Minus1)? - top_k as usize,
450 top_k as usize,
451 )?
452 .to_vec1::<f32>()?;
453
454 let sorted_idxs = probs.fast_argsort_asc(D::Minus1)?;
455 let topk_idxs = sorted_idxs
456 .narrow(
457 D::Minus1,
458 sorted_values.dim(D::Minus1)? - top_k as usize,
459 top_k as usize,
460 )?
461 .to_vec1::<u32>()?;
462
463 let mut result = Vec::with_capacity(k);
464 if let Some(tokenizer) = &self.tokenizer {
465 for (prob, token) in topk_values.iter().zip(topk_idxs) {
466 let decoded = tokenizer
467 .decode(&[token], false)
468 .map_err(|e| Error::Msg(e.to_string()))?;
469 result.push(TopLogprob {
470 token,
471 logprob: prob.log(10.0),
472 bytes: Some(decoded),
473 });
474 }
475 } else {
476 for (prob, token) in topk_values.iter().zip(topk_idxs) {
477 result.push(TopLogprob {
478 token,
479 logprob: prob.log(10.0),
480 bytes: None,
481 });
482 }
483 }
484
485 let logprob = result.last().map(|res| res.logprob).unwrap_or(1.);
486
487 (Some(result), logprob)
488 } else {
489 (None, 1.)
490 };
491
492 let bytes = if let Some(tokenizer) = &self.tokenizer {
493 Some(
494 tokenizer
495 .decode(&[next_token], false)
496 .map_err(|x| Error::Msg(x.to_string()))?,
497 )
498 } else {
499 None
500 };
501
502 Ok(Logprobs {
503 token: next_token,
504 logprob,
505 top_logprobs,
506 bytes,
507 })
508 }
509 fn sample_speculative_top_kp_min_p(
510 &self,
511 logits: Tensor,
512 return_logprobs: bool,
513 top_k: i64,
514 top_p: f32,
515 min_p: f32,
516 ) -> Result<Logprobs> {
517 let mut probs: Vec<f32> = logits.to_vec1()?;
518 let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
519
520 if top_k > 0 {
521 for (index, val) in argsort_indices.iter().enumerate() {
523 if index >= top_k as usize {
524 probs[*val as usize] = 0.0;
525 }
526 }
527 }
528
529 let mut cumsum = 0.;
537 for index in &argsort_indices {
538 if cumsum >= top_p {
539 probs[*index as usize] = 0.0;
540 } else {
541 cumsum += probs[*index as usize];
542 }
543 }
544
545 let max_p = probs[argsort_indices[0] as usize];
546
547 for index in &argsort_indices {
554 if max_p * min_p >= probs[*index as usize] {
555 probs[*index as usize] = 0.0;
556 }
557 }
558
559 let logits = Tensor::from_slice(&probs, logits.shape(), &Device::Cpu)?;
560
561 let next_token = argmax_sample_last_dim(&logits)?.to_scalar::<u32>()?;
562
563 let logprob = probs[next_token as usize].log(10.0);
564
565 let top_logprobs = if return_logprobs {
566 Some(self.get_top_logprobs(&probs, &argsort_indices)?)
567 } else {
568 None
569 };
570
571 let bytes = if let Some(tokenizer) = &self.tokenizer {
572 Some(
573 tokenizer
574 .decode(&[next_token], false)
575 .map_err(|x| Error::Msg(x.to_string()))?,
576 )
577 } else {
578 None
579 };
580
581 Ok(Logprobs {
582 token: next_token,
583 logprob,
584 top_logprobs,
585 bytes,
586 })
587 }
588
589 fn sample_multinomial(
590 &self,
591 probs: &mut Vec<f32>,
592 argsort_indices: Vec<u32>,
593 return_logprobs: bool,
594 rng: Arc<Mutex<Isaac64Rng>>,
595 ) -> Result<Logprobs> {
596 let distr = WeightedIndex::new(&*probs).map_err(Error::wrap)?;
597
598 let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
599 let next_token = distr.sample(&mut mut_ref_rng); let logprob = probs[next_token].log(10.0);
601
602 let top_logprobs = if return_logprobs {
603 Some(self.get_top_logprobs(probs, &argsort_indices)?)
604 } else {
605 None
606 };
607
608 let bytes = if let Some(tokenizer) = &self.tokenizer {
609 Some(
610 tokenizer
611 .decode(&[next_token.try_into().unwrap()], false)
612 .map_err(|x| Error::Msg(x.to_string()))?,
613 )
614 } else {
615 None
616 };
617
618 Ok(Logprobs {
619 token: next_token as u32,
620 logprob,
621 top_logprobs,
622 bytes,
623 })
624 }
625
626 #[allow(clippy::too_many_arguments)]
627 fn sample_top_kp_min_p(
628 &self,
629 probs: &mut Vec<f32>,
630 logits: &Tensor,
631 top_k: i64,
632 top_p: f32,
633 min_p: f32,
634 return_logprobs: bool,
635 rng: Arc<Mutex<Isaac64Rng>>,
636 ) -> Result<Logprobs> {
637 let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
638
639 if top_k > 0 {
640 for (index, val) in argsort_indices.iter().enumerate() {
642 if index >= top_k as usize {
643 probs[*val as usize] = 0.0;
644 }
645 }
646 }
647
648 if top_p <= 0.0 || top_p >= 1.0 {
649 return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
650 }
651
652 let mut cumsum = 0.;
660 for index in &argsort_indices {
661 if cumsum >= top_p {
662 probs[*index as usize] = 0.0;
663 } else {
664 cumsum += probs[*index as usize];
665 }
666 }
667
668 if min_p <= 0.0 || min_p >= 1.0 {
669 return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
670 }
671
672 let max_p = probs[argsort_indices[0] as usize];
673
674 for index in &argsort_indices {
681 if max_p * min_p >= probs[*index as usize] {
682 probs[*index as usize] = 0.0;
683 }
684 }
685
686 self.sample_multinomial(probs, argsort_indices, return_logprobs, rng)
688 }
689
690 fn apply_penalties(&self, mut logits: Vec<f32>, context: &[u32]) -> Result<Tensor> {
691 if context.is_empty() {
692 candle_core::bail!("Penalty context is empty, this should not happen.");
693 }
694
695 self.apply_dry_penalty(&mut logits, context)?;
697
698 self.apply_freq_presc_penalty(&mut logits, context)?;
700
701 let vocab_size = logits.len();
702 Tensor::from_vec(logits, vocab_size, &Device::Cpu)
703 }
704
705 fn apply_freq_presc_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
706 if self.frequency_penalty.is_some() || self.presence_penalty.is_some() {
707 let frequency_penalty = self.frequency_penalty.unwrap_or(0.);
708 let presence_penalty = self.presence_penalty.unwrap_or(0.);
709
710 let mut counts = vec![0.0f32; logits.len()];
713 for ctx in context.iter() {
714 if *ctx as usize >= logits.len() {
716 continue;
717 }
718 counts[*ctx as usize] += 1.0;
719 }
720
721 for (token_id, logit) in logits.iter_mut().enumerate() {
722 let count = counts[token_id];
723 *logit = *logit
724 - count * frequency_penalty
725 - if count > 0.0 { 1. } else { 0. } * presence_penalty;
726 }
727 }
728 Ok(())
729 }
730
731 fn apply_dry_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
732 if let Some(ref params) = self.dry_params {
733 if params.multiplier == 0. {
734 return Ok(());
735 }
736
737 let match_indices = context
738 .par_iter()
739 .enumerate()
740 .take(context.len() - 1)
741 .filter(|(_i, x)| *context.last().unwrap() == **x)
742 .map(|(i, _)| i)
743 .collect::<Vec<_>>();
744
745 let mut match_lengths = HashMap::new();
746
747 for i in match_indices {
748 let next_token = context[i + 1];
749
750 if params.sequence_breakers.contains(&next_token) {
751 continue;
752 }
753
754 let mut match_length = 1;
755
756 while match_length < 50 {
758 if match_length > i {
759 break;
761 }
762
763 let j = i - match_length;
764
765 let prev_tok = context[context.len() - (match_length + 1)];
766 if context[j] != prev_tok {
767 break;
769 }
770
771 if params.sequence_breakers.contains(&prev_tok) {
772 break;
774 }
775
776 match_length += 1;
777 }
778
779 #[allow(clippy::map_entry)]
780 if match_lengths.contains_key(&next_token) {
781 match_lengths.insert(next_token, match_length.max(match_lengths[&next_token]));
782 } else {
783 match_lengths.insert(next_token, match_length);
784 }
785 }
786
787 for (tok, match_len) in match_lengths {
789 if match_len >= params.allowed_length {
790 if tok as usize >= logits.len() {
792 continue;
793 }
794 let penalty = params.multiplier
795 * params.base.powf((match_len - params.allowed_length) as f32);
796 logits[tok as usize] -= penalty;
797 }
798 }
799 }
800 Ok(())
801 }
802
803 #[allow(unused)]
804 pub fn sample(
809 &self,
810 logits: Tensor,
811 context: &[u32],
812 return_logprobs: bool,
813 rng: Arc<Mutex<Isaac64Rng>>,
814 sample_speculative: bool,
815 multiple_sequences: bool,
816 ) -> Result<Logprobs> {
817 if cfg!(feature = "metal") && !multiple_sequences {
818 return self.sample_fast(
819 logits,
820 context,
821 return_logprobs,
822 self.top_k,
823 self.top_p,
824 self.min_p,
825 );
826 }
827
828 let logits = logits.to_vec1()?;
829 let mut logits = self.apply_penalties(logits, context)?;
830 for processor in &self.logits_processors {
831 logits = processor.apply(&logits, context)?;
832 }
833 let next_token = if sample_speculative {
834 match self.temperature {
835 None => self.sample_speculative_top_kp_min_p(
836 logits,
837 return_logprobs,
838 self.top_k,
839 self.top_p as f32,
840 self.min_p as f32,
841 )?,
842 Some(temperature) => {
843 let logits = (&logits / temperature)?;
844 let probs = candle_nn::ops::softmax_last_dim(&logits)?;
845
846 self.sample_speculative_top_kp_min_p(
847 probs,
848 return_logprobs,
849 self.top_k,
850 self.top_p as f32,
851 self.min_p as f32,
852 )?
853 }
854 }
855 } else {
856 match self.temperature {
857 None => self.sample_argmax(logits, return_logprobs)?,
858 Some(temperature) => {
859 let logits = (&logits / temperature)?;
860 let logits = candle_nn::ops::softmax_last_dim(&logits)?;
861 let mut probs: Vec<f32> = logits.to_vec1()?;
862
863 self.sample_top_kp_min_p(
864 &mut probs,
865 &logits,
866 self.top_k,
867 self.top_p as f32,
868 self.min_p as f32,
869 return_logprobs,
870 rng,
871 )?
872 }
873 }
874 };
875 Ok(next_token)
876 }
877}
878
#[cfg(test)]
mod tests {
    use super::{Logprobs, Sampler};
    use candle_core::{Device, Tensor};
    use rand::SeedableRng;
    use rand_isaac::Isaac64Rng;
    use std::sync::{Arc, Mutex};

    /// Greedy sampler (no temperature) with 10 top logprobs and top_k = 32,
    /// top_p = 0.1, min_p = 0.05.
    fn greedy_sampler() -> Sampler {
        Sampler::new(None, 10, None, None, None, None, 32, 0.1, 0.05, vec![]).unwrap()
    }

    /// Sample from logits `0..1024` with the context `0..1024` and logprobs disabled.
    fn sample_arange(sample_speculative: bool) -> Logprobs {
        let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
        let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
        greedy_sampler()
            .sample(
                logits,
                &(0..1024).collect::<Vec<_>>(),
                false,
                rng,
                sample_speculative,
                false,
            )
            .unwrap()
    }

    #[test]
    fn test_argmax() {
        let res = sample_arange(false);
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }

    #[test]
    fn test_gumbel_speculative() {
        let res = sample_arange(true);
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }
}