1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{
4 collections::{HashMap, HashSet},
5 sync::{Arc, LazyLock, Mutex},
6};
7
8use candle_core::{DType, Device, Error, Result, Tensor, D};
9use mistralrs_quant::{CumSumOp, SortOp};
10#[cfg(feature = "pyo3_macros")]
11use pyo3::pyclass;
12
13use rand::distr::{weighted::WeightedIndex, Distribution};
14use rand_isaac::Isaac64Rng;
15use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
16use serde::{Deserialize, Serialize};
17use tokenizers::Tokenizer;
18
/// Default DRY sequence breakers: newline, colon, double quote and asterisk.
///
/// While measuring a repeated sequence, the DRY penalty stops extending a match when it
/// reaches one of these (once converted to token ids).
static DRY_SEQUENCE_BREAKERS: LazyLock<Vec<String>> = LazyLock::new(|| {
    vec![
        "\n".to_owned(),
        ":".to_owned(),
        "\"".to_owned(),
        "*".to_owned(),
    ]
});
21
/// Stop criteria for a request, given either as strings or as token ids.
///
/// NOTE(review): these are interpreted outside this file; presumably generation stops once
/// any listed entry is produced — confirm against the consumer.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum StopTokens {
    /// Stop sequences given as text.
    Seqs(Vec<String>),
    /// Stop sequences given as token ids.
    Ids(Vec<u32>),
}
28
/// User-facing sampling configuration for one request.
///
/// NOTE(review): the conversion into a [`Sampler`] happens outside this file; the per-field
/// notes below give the presumed meaning and should be confirmed against that code.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SamplingParams {
    /// Softmax temperature; `None` presumably selects greedy decoding.
    pub temperature: Option<f64>,
    /// Keep only the `k` highest-ranked tokens.
    pub top_k: Option<usize>,
    /// Nucleus (top-p) probability mass to keep.
    pub top_p: Option<f64>,
    /// Drop tokens whose probability does not exceed `min_p` times the maximum probability.
    pub min_p: Option<f64>,
    /// Number of alternative tokens to report log-probabilities for.
    pub top_n_logprobs: usize,
    /// Penalty proportional to how often a token already occurred.
    pub frequency_penalty: Option<f32>,
    /// Flat penalty for tokens that already occurred at least once.
    pub presence_penalty: Option<f32>,
    /// Multiplicative penalty for tokens that already occurred.
    pub repetition_penalty: Option<f32>,
    /// Optional stop criteria.
    pub stop_toks: Option<StopTokens>,
    /// Presumably the maximum number of tokens to generate.
    pub max_len: Option<usize>,
    /// Presumably additive logit biases keyed by token id.
    pub logits_bias: Option<HashMap<u32, f32>>,
    /// Presumably the number of completions to generate.
    pub n_choices: usize,
    /// Optional DRY ("don't repeat yourself") penalty configuration.
    pub dry_params: Option<DrySamplingParams>,
}
46
47impl SamplingParams {
48 pub fn deterministic() -> Self {
53 Self {
54 temperature: None,
55 top_k: Some(1),
56 top_p: None,
57 min_p: None,
58 top_n_logprobs: 0,
59 frequency_penalty: None,
60 presence_penalty: None,
61 repetition_penalty: None,
62 stop_toks: None,
63 max_len: None,
64 logits_bias: None,
65 n_choices: 1,
66 dry_params: None,
67 }
68 }
69}
70
/// User-facing DRY ("don't repeat yourself") penalty settings.
///
/// A token that would continue a repeat of length `n >= allowed_length` gets
/// `multiplier * base^(n - allowed_length)` subtracted from its logit; a `multiplier` of 0
/// disables the penalty.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DrySamplingParams {
    /// Strings that stop a repeated-sequence match from extending further.
    pub sequence_breakers: Vec<String>,
    /// Overall penalty scale; 0 disables DRY.
    pub multiplier: f32,
    /// Exponential growth base of the penalty with match length.
    pub base: f32,
    /// Shortest repeat length that is penalized.
    pub allowed_length: usize,
}
78
79impl DrySamplingParams {
80 pub fn new_with_defaults(
81 multiplier: f32,
82 sequence_breakers: Option<Vec<String>>,
83 base: Option<f32>,
84 allowed_length: Option<usize>,
85 ) -> anyhow::Result<Self> {
86 Ok(Self {
87 base: base.unwrap_or(1.75),
88 allowed_length: allowed_length.unwrap_or(2),
89 sequence_breakers: sequence_breakers.unwrap_or(DRY_SEQUENCE_BREAKERS.clone()),
90 multiplier,
91 })
92 }
93}
94
95impl Default for DrySamplingParams {
96 fn default() -> Self {
97 Self {
98 multiplier: 0.0,
99 base: 1.75,
100 allowed_length: 2,
101 sequence_breakers: DRY_SEQUENCE_BREAKERS.clone(),
102 }
103 }
104}
105
/// DRY settings with the sequence breakers resolved to token ids, as used by the sampler.
#[derive(Clone, Debug)]
struct DrySamplingParamsInner {
    /// Token ids that stop a repeated-sequence match from extending.
    pub sequence_breakers: HashSet<u32>,
    /// Overall penalty scale; 0 disables DRY.
    pub multiplier: f32,
    /// Exponential growth base of the penalty with match length.
    pub base: f32,
    /// Shortest repeat length that is penalized.
    pub allowed_length: usize,
}
113
114impl DrySamplingParamsInner {
115 pub fn from(other: DrySamplingParams, tokenizer: &Tokenizer) -> anyhow::Result<Self> {
116 Ok(Self {
117 base: other.base,
118 allowed_length: other.allowed_length,
119 sequence_breakers: HashSet::from_iter(
120 other
121 .sequence_breakers
122 .into_iter()
123 .map(|breaker| {
124 tokenizer
125 .encode_fast(["a", &breaker].concat(), true)
131 .map_err(anyhow::Error::msg)
132 .map(|enc| {
133 let ids = enc.get_ids();
134 if !ids.is_empty() {
135 Some(ids[ids.len() - 1])
136 } else {
137 None
138 }
139 })
140 })
141 .collect::<anyhow::Result<Vec<_>>>()?
142 .into_iter()
143 .flatten()
144 .collect::<Vec<_>>(),
145 ),
146 multiplier: other.multiplier,
147 })
148 }
149}
150
/// A user-supplied transformation of the logits, run before a token is chosen.
///
/// In [`Sampler::sample`] processors run after the built-in DRY and
/// frequency/presence/repetition penalties. Any
/// `Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync` closure implements this trait.
pub trait CustomLogitsProcessor: Send + Sync {
    /// Returns the transformed logits.
    ///
    /// `context` is the token-id history passed to the sampler (presumably prompt plus
    /// generated tokens — confirm against the caller).
    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor>;
}
174
175impl<T: Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync> CustomLogitsProcessor for T {
176 fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor> {
177 self(logits, context)
178 }
179}
180
/// Token sampler holding all per-request sampling settings.
///
/// The tokenizer, custom processors and Gumbel-noise cache sit behind `Arc`s, so clones
/// share them.
#[derive(Clone)]
pub struct Sampler {
    /// Softmax temperature; `None` selects greedy (arg-max) decoding.
    temperature: Option<f64>,
    /// Number of alternatives reported when logprobs are requested.
    top_n_logprobs: usize,
    /// Decodes sampled tokens to text; also needed to resolve DRY sequence breakers.
    tokenizer: Option<Arc<Tokenizer>>,
    /// Subtracted from a logit once per prior occurrence of the token.
    frequency_penalty: Option<f32>,
    /// Subtracted once from the logit of any token that already occurred.
    presence_penalty: Option<f32>,
    /// Divides positive / multiplies non-positive logits of tokens that already occurred.
    repetition_penalty: Option<f32>,
    /// DRY settings with sequence breakers resolved to token ids.
    dry_params: Option<DrySamplingParamsInner>,
    /// Number of highest-ranked tokens kept; values `<= 0` disable top-k.
    top_k: i64,
    /// Nucleus (top-p) threshold.
    top_p: f64,
    /// Min-p threshold relative to the highest-ranked token.
    min_p: f64,
    /// Custom logits transformations.
    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
    /// Gumbel noise reused by the tensor-only `sample_fast` path.
    gumbel_cache: Arc<Mutex<Option<Tensor>>>,
}
198
/// One alternative token reported alongside a sampled token.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TopLogprob {
    /// Token id.
    pub token: u32,
    /// Base-10 logarithm of the token's score as computed by the sampler.
    pub logprob: f32,
    /// The token decoded to text (despite the field name), or `None` without a tokenizer.
    pub bytes: Option<String>,
}
208
/// Result of one sampling step: the chosen token plus optional alternatives.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Logprobs {
    /// Sampled token id.
    pub token: u32,
    /// Base-10 logarithm of the sampled token's score.
    /// NOTE(review): on the greedy paths the score is the raw logit, not a probability.
    pub logprob: f32,
    /// The sampled token decoded to text, or `None` without a tokenizer.
    pub bytes: Option<String>,
    /// Highest-scoring alternatives; `None` unless logprobs were requested.
    pub top_logprobs: Option<Vec<TopLogprob>>,
}
216
217fn argmax_sample_last_dim(logits: &Tensor) -> Result<Tensor> {
218 logits.argmax(D::Minus1)
219}
220
221impl Sampler {
222 #[allow(clippy::too_many_arguments)]
223 pub fn new(
224 temperature: Option<f64>,
225 top_n_logprobs: usize,
226 tokenizer: Option<Arc<Tokenizer>>,
227 frequency_penalty: Option<f32>,
228 presence_penalty: Option<f32>,
229 repetition_penalty: Option<f32>,
230 dry_params: Option<DrySamplingParams>,
231 top_k: i64,
232 top_p: f64,
233 min_p: f64,
234 logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
235 ) -> anyhow::Result<Self> {
236 let temperature = if temperature.is_none_or(|v| v < 1e-7) {
237 None
238 } else {
239 temperature
240 };
241 let dry_params = if let Some(ref tokenizer) = tokenizer {
242 dry_params.map(|params| DrySamplingParamsInner::from(params, tokenizer))
243 } else {
244 None
245 };
246 let dry_params = match dry_params {
247 Some(fallible) => Some(fallible?),
248 None => None,
249 };
250 Ok(Self {
251 temperature,
252 top_n_logprobs,
253 tokenizer,
254 frequency_penalty,
255 presence_penalty,
256 repetition_penalty,
257 dry_params,
258 top_k,
259 top_p,
260 min_p,
261 logits_processors,
262 gumbel_cache: Arc::new(Mutex::new(None)),
263 })
264 }
265
266 fn get_top_logprobs(&self, probs: &[f32], _argsort_indices: &[u32]) -> Result<Vec<TopLogprob>> {
267 let k = self.top_n_logprobs.min(probs.len());
269 if k == 0 {
270 return Ok(Vec::new());
271 }
272 let mut idx_probs: Vec<(u32, f32)> = (0..probs.len() as u32)
274 .map(|i| (i, probs[i as usize]))
275 .collect();
276 let (top_k_slice, _, _) = idx_probs.select_nth_unstable_by(k, |a, b| {
278 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
279 });
280 let mut top_k: Vec<(u32, f32)> = top_k_slice.to_vec();
282 top_k.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
283 let mut result = Vec::with_capacity(k);
285 if let Some(tokenizer) = &self.tokenizer {
286 for (token, prob) in top_k {
287 let decoded = tokenizer
288 .decode(&[token], false)
289 .map_err(|e| Error::Msg(e.to_string()))?;
290 result.push(TopLogprob {
291 token,
292 logprob: prob.log(10.0),
293 bytes: Some(decoded),
294 });
295 }
296 } else {
297 for (token, prob) in top_k {
298 result.push(TopLogprob {
299 token,
300 logprob: prob.log(10.0),
301 bytes: None,
302 });
303 }
304 }
305 Ok(result)
306 }
307
308 fn sample_argmax(&self, logits: Tensor, return_logprobs: bool) -> Result<Logprobs> {
309 let next_token = logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
310
311 let probs: Vec<f32> = logits.to_vec1()?;
312
313 let argsort_indices = (0..probs.len() as u32).collect::<Vec<_>>();
314 let logprob = probs[next_token as usize].log(10.0);
315
316 let top_logprobs = if return_logprobs {
317 Some(self.get_top_logprobs(&probs, &argsort_indices)?)
318 } else {
319 None
320 };
321
322 let bytes = if let Some(tokenizer) = &self.tokenizer {
323 Some(
324 tokenizer
325 .decode(&[next_token], false)
326 .map_err(|x| Error::Msg(x.to_string()))?,
327 )
328 } else {
329 None
330 };
331
332 Ok(Logprobs {
333 token: next_token,
334 logprob,
335 top_logprobs,
336 bytes,
337 })
338 }
339
340 #[allow(unused)]
341 fn sample_fast(
342 &self,
343 logits: Tensor,
344 context: &[u32],
345 return_logprobs: bool,
346 top_k: i64,
347 top_p: f64,
348 min_p: f64,
349 ) -> Result<Logprobs> {
350 let mut probs = logits.to_dtype(DType::F32)?;
351
352 for processor in &self.logits_processors {
353 probs = processor.apply(&probs, context)?;
354 }
355
356 let context = Tensor::new(context, logits.device())?;
357 let mut counts = logits.zeros_like()?;
358 counts = counts.scatter_add(
359 &context,
360 &context.ones_like()?.to_dtype(counts.dtype())?,
361 D::Minus1,
362 )?;
363
364 let presence = counts
365 .gt(0.)?
366 .where_cond(&counts.ones_like()?, &counts.zeros_like()?)?;
367
368 match self.frequency_penalty {
369 Some(freq_penalty) if freq_penalty != 0. => {
370 probs = (probs - (freq_penalty as f64 * counts)?)?;
371 }
372 _ => (),
373 }
374
375 match self.presence_penalty {
376 Some(pres_penalty) if pres_penalty != 0. => {
377 probs = (probs - (pres_penalty as f64 * &presence)?)?;
378 }
379 _ => (),
380 }
381
382 match self.repetition_penalty {
383 Some(rep_penalty) if rep_penalty != 1. => {
384 let pos_mask = probs.gt(0.)?;
385 let scaled_pos = (&probs / (rep_penalty as f64))?;
386 let scaled_neg = (&probs * (rep_penalty as f64))?;
387 let modified = pos_mask.where_cond(&scaled_pos, &scaled_neg)?;
388
389 let pres_mask = presence.gt(0.)?;
390 probs = pres_mask.where_cond(&modified, &probs)?;
391 }
392 _ => (),
393 }
394
395 probs = candle_nn::ops::softmax_last_dim(&(probs / self.temperature.unwrap_or(1.))?)?;
396
397 if top_k > 0 {
399 let sorted_values = probs.fast_sort_asc(D::Minus1)?;
400 let topk_values = sorted_values.narrow(
401 D::Minus1,
402 sorted_values.dim(D::Minus1)? - top_k as usize,
403 top_k as usize,
404 )?;
405
406 let threshold = topk_values.get_on_dim(D::Minus1, 0)?.unsqueeze(0)?;
408 let mask_topk = probs.broadcast_ge(&threshold)?;
409 probs = mask_topk.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
410 }
411
412 if top_p > 0.0 && top_p < 1.0 {
414 let sorted_probs = probs.fast_sort_asc(D::Minus1)?;
415
416 let cumsum = sorted_probs.fast_cumsum(D::Minus1)?;
417
418 let mask_topp = cumsum.le(top_p)?;
419
420 let masked_sorted =
421 mask_topp.where_cond(&sorted_probs, &Tensor::zeros_like(&sorted_probs)?)?;
422
423 let threshold = masked_sorted.max(D::Minus1)?;
424 let threshold = threshold.unsqueeze(D::Minus1)?;
425 let mask_full = probs.broadcast_ge(&threshold)?;
426 probs = mask_full.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
427 }
428
429 if min_p > 0.0 && min_p < 1.0 {
431 let max_vals = probs.max(D::Minus1)?;
432 let threshold_min = (max_vals.unsqueeze(D::Minus1)? * min_p)?;
433 let mask_minp = probs.broadcast_gt(&threshold_min)?;
434 probs = mask_minp.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
435 }
436
437 let log_probs = probs.log()?;
439 let gumbel = {
441 let mut guard = self.gumbel_cache.lock().unwrap();
442 if guard.is_none() {
443 let uniform = Tensor::rand(0f32, 1f32, log_probs.shape(), log_probs.device())?;
444 let noise = uniform
445 .clamp(1e-20, 1.0)?
446 .log()? .neg()? .log()? .neg()?; *guard = Some(noise);
451 }
452 guard.as_ref().unwrap().clone()
453 };
454
455 let gumbel_logits = (&log_probs + &gumbel)?;
456 let next_token = gumbel_logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
457
458 let (top_logprobs, logprob) = if return_logprobs {
460 let k = self.top_n_logprobs;
461
462 let sorted_values = probs.fast_sort_asc(D::Minus1)?;
463 let topk_values = sorted_values
464 .narrow(
465 D::Minus1,
466 sorted_values.dim(D::Minus1)? - top_k as usize,
467 top_k as usize,
468 )?
469 .to_vec1::<f32>()?;
470
471 let sorted_idxs = probs.fast_argsort_asc(D::Minus1)?;
472 let topk_idxs = sorted_idxs
473 .narrow(
474 D::Minus1,
475 sorted_values.dim(D::Minus1)? - top_k as usize,
476 top_k as usize,
477 )?
478 .to_vec1::<u32>()?;
479
480 let mut result = Vec::with_capacity(k);
481 if let Some(tokenizer) = &self.tokenizer {
482 for (prob, token) in topk_values.iter().zip(topk_idxs) {
483 let decoded = tokenizer
484 .decode(&[token], false)
485 .map_err(|e| Error::Msg(e.to_string()))?;
486 result.push(TopLogprob {
487 token,
488 logprob: prob.log(10.0),
489 bytes: Some(decoded),
490 });
491 }
492 } else {
493 for (prob, token) in topk_values.iter().zip(topk_idxs) {
494 result.push(TopLogprob {
495 token,
496 logprob: prob.log(10.0),
497 bytes: None,
498 });
499 }
500 }
501
502 let logprob = result.last().map(|res| res.logprob).unwrap_or(1.);
503
504 (Some(result), logprob)
505 } else {
506 (None, 1.)
507 };
508
509 let bytes = if let Some(tokenizer) = &self.tokenizer {
510 Some(
511 tokenizer
512 .decode(&[next_token], false)
513 .map_err(|x| Error::Msg(x.to_string()))?,
514 )
515 } else {
516 None
517 };
518
519 Ok(Logprobs {
520 token: next_token,
521 logprob,
522 top_logprobs,
523 bytes,
524 })
525 }
526 fn sample_speculative_top_kp_min_p(
527 &self,
528 logits: Tensor,
529 return_logprobs: bool,
530 top_k: i64,
531 top_p: f32,
532 min_p: f32,
533 ) -> Result<Logprobs> {
534 let mut probs: Vec<f32> = logits.to_vec1()?;
535 let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
536
537 if top_k > 0 {
538 for (index, val) in argsort_indices.iter().enumerate() {
540 if index >= top_k as usize {
541 probs[*val as usize] = 0.0;
542 }
543 }
544 }
545
546 let mut cumsum = 0.;
554 for index in &argsort_indices {
555 if cumsum >= top_p {
556 probs[*index as usize] = 0.0;
557 } else {
558 cumsum += probs[*index as usize];
559 }
560 }
561
562 let max_p = probs[argsort_indices[0] as usize];
563
564 for index in &argsort_indices {
571 if max_p * min_p >= probs[*index as usize] {
572 probs[*index as usize] = 0.0;
573 }
574 }
575
576 let logits = Tensor::from_slice(&probs, logits.shape(), &Device::Cpu)?;
577
578 let next_token = argmax_sample_last_dim(&logits)?.to_scalar::<u32>()?;
579
580 let logprob = probs[next_token as usize].log(10.0);
581
582 let top_logprobs = if return_logprobs {
583 Some(self.get_top_logprobs(&probs, &argsort_indices)?)
584 } else {
585 None
586 };
587
588 let bytes = if let Some(tokenizer) = &self.tokenizer {
589 Some(
590 tokenizer
591 .decode(&[next_token], false)
592 .map_err(|x| Error::Msg(x.to_string()))?,
593 )
594 } else {
595 None
596 };
597
598 Ok(Logprobs {
599 token: next_token,
600 logprob,
601 top_logprobs,
602 bytes,
603 })
604 }
605
606 fn sample_multinomial(
607 &self,
608 probs: &mut Vec<f32>,
609 argsort_indices: Vec<u32>,
610 return_logprobs: bool,
611 rng: Arc<Mutex<Isaac64Rng>>,
612 ) -> Result<Logprobs> {
613 let distr = WeightedIndex::new(&*probs).map_err(Error::wrap)?;
614
615 let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
616 let next_token = distr.sample(&mut mut_ref_rng); let logprob = probs[next_token].log(10.0);
618
619 let top_logprobs = if return_logprobs {
620 Some(self.get_top_logprobs(probs, &argsort_indices)?)
621 } else {
622 None
623 };
624
625 let bytes = if let Some(tokenizer) = &self.tokenizer {
626 Some(
627 tokenizer
628 .decode(&[next_token.try_into().unwrap()], false)
629 .map_err(|x| Error::Msg(x.to_string()))?,
630 )
631 } else {
632 None
633 };
634
635 Ok(Logprobs {
636 token: next_token as u32,
637 logprob,
638 top_logprobs,
639 bytes,
640 })
641 }
642
643 #[allow(clippy::too_many_arguments)]
644 fn sample_top_kp_min_p(
645 &self,
646 probs: &mut Vec<f32>,
647 logits: &Tensor,
648 top_k: i64,
649 top_p: f32,
650 min_p: f32,
651 return_logprobs: bool,
652 rng: Arc<Mutex<Isaac64Rng>>,
653 ) -> Result<Logprobs> {
654 let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
655
656 if top_k > 0 {
657 for (index, val) in argsort_indices.iter().enumerate() {
659 if index >= top_k as usize {
660 probs[*val as usize] = 0.0;
661 }
662 }
663 }
664
665 if top_p <= 0.0 || top_p >= 1.0 {
666 return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
667 }
668
669 let mut cumsum = 0.;
677 for index in &argsort_indices {
678 if cumsum >= top_p {
679 probs[*index as usize] = 0.0;
680 } else {
681 cumsum += probs[*index as usize];
682 }
683 }
684
685 if min_p <= 0.0 || min_p >= 1.0 {
686 return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
687 }
688
689 let max_p = probs[argsort_indices[0] as usize];
690
691 for index in &argsort_indices {
698 if max_p * min_p >= probs[*index as usize] {
699 probs[*index as usize] = 0.0;
700 }
701 }
702
703 self.sample_multinomial(probs, argsort_indices, return_logprobs, rng)
705 }
706
707 fn apply_penalties(&self, mut logits: Vec<f32>, context: &[u32]) -> Result<Tensor> {
708 if context.is_empty() {
709 candle_core::bail!("Penalty context is empty, this should not happen.");
710 }
711
712 self.apply_dry_penalty(&mut logits, context)?;
714
715 self.apply_freq_pres_rep_penalty(&mut logits, context)?;
717
718 let vocab_size = logits.len();
719 Tensor::from_vec(logits, vocab_size, &Device::Cpu)
720 }
721
722 fn apply_freq_pres_rep_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
723 if self.frequency_penalty.is_some()
724 || self.presence_penalty.is_some()
725 || self.repetition_penalty.is_some()
726 {
727 let frequency_penalty = self.frequency_penalty.unwrap_or(0.);
728 let presence_penalty = self.presence_penalty.unwrap_or(0.);
729 let repetition_penalty = self.repetition_penalty.unwrap_or(1.);
730
731 let mut counts = vec![0.0f32; logits.len()];
734 for ctx in context.iter() {
735 if *ctx as usize >= logits.len() {
737 continue;
738 }
739 counts[*ctx as usize] += 1.0;
740 }
741
742 for (token_id, logit) in logits.iter_mut().enumerate() {
743 let count = counts[token_id];
744 *logit = *logit
745 - count * frequency_penalty
746 - if count > 0.0 { 1. } else { 0. } * presence_penalty;
747
748 if repetition_penalty != 1.0 && count > 0.0 {
749 if *logit > 0.0 {
750 *logit /= repetition_penalty;
751 } else {
752 *logit *= repetition_penalty;
753 }
754 }
755 }
756 }
757 Ok(())
758 }
759
760 fn apply_dry_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
761 if let Some(ref params) = self.dry_params {
762 if params.multiplier == 0. {
763 return Ok(());
764 }
765
766 let match_indices = context
767 .par_iter()
768 .enumerate()
769 .take(context.len() - 1)
770 .filter(|(_i, x)| *context.last().unwrap() == **x)
771 .map(|(i, _)| i)
772 .collect::<Vec<_>>();
773
774 let mut match_lengths = HashMap::new();
775
776 for i in match_indices {
777 let next_token = context[i + 1];
778
779 if params.sequence_breakers.contains(&next_token) {
780 continue;
781 }
782
783 let mut match_length = 1;
784
785 while match_length < 50 {
787 if match_length > i {
788 break;
790 }
791
792 let j = i - match_length;
793
794 let prev_tok = context[context.len() - (match_length + 1)];
795 if context[j] != prev_tok {
796 break;
798 }
799
800 if params.sequence_breakers.contains(&prev_tok) {
801 break;
803 }
804
805 match_length += 1;
806 }
807
808 #[allow(clippy::map_entry)]
809 if match_lengths.contains_key(&next_token) {
810 match_lengths.insert(next_token, match_length.max(match_lengths[&next_token]));
811 } else {
812 match_lengths.insert(next_token, match_length);
813 }
814 }
815
816 for (tok, match_len) in match_lengths {
818 if match_len >= params.allowed_length {
819 if tok as usize >= logits.len() {
821 continue;
822 }
823 let penalty = params.multiplier
824 * params.base.powf((match_len - params.allowed_length) as f32);
825 logits[tok as usize] -= penalty;
826 }
827 }
828 }
829 Ok(())
830 }
831
832 #[allow(unused)]
833 pub fn sample(
838 &self,
839 logits: Tensor,
840 context: &[u32],
841 return_logprobs: bool,
842 rng: Arc<Mutex<Isaac64Rng>>,
843 sample_speculative: bool,
844 multiple_sequences: bool,
845 ) -> Result<Logprobs> {
846 let logits = logits.to_vec1()?;
858 let mut logits = self.apply_penalties(logits, context)?;
859 for processor in &self.logits_processors {
860 logits = processor.apply(&logits, context)?;
861 }
862 let next_token = if sample_speculative {
863 match self.temperature {
864 None => self.sample_speculative_top_kp_min_p(
865 logits,
866 return_logprobs,
867 self.top_k,
868 self.top_p as f32,
869 self.min_p as f32,
870 )?,
871 Some(temperature) => {
872 let logits = (&logits / temperature)?;
873 let probs = candle_nn::ops::softmax_last_dim(&logits)?;
874
875 self.sample_speculative_top_kp_min_p(
876 probs,
877 return_logprobs,
878 self.top_k,
879 self.top_p as f32,
880 self.min_p as f32,
881 )?
882 }
883 }
884 } else {
885 match self.temperature {
886 None => self.sample_argmax(logits, return_logprobs)?,
887 Some(temperature) => {
888 let logits = (&logits / temperature)?;
889 let logits = candle_nn::ops::softmax_last_dim(&logits)?;
890 let mut probs: Vec<f32> = logits.to_vec1()?;
891
892 self.sample_top_kp_min_p(
893 &mut probs,
894 &logits,
895 self.top_k,
896 self.top_p as f32,
897 self.min_p as f32,
898 return_logprobs,
899 rng,
900 )?
901 }
902 }
903 };
904 Ok(next_token)
905 }
906}
907
#[cfg(test)]
mod tests {
    use super::{Logprobs, Sampler};
    use candle_core::{Device, Tensor};
    use rand::SeedableRng;
    use rand_isaac::Isaac64Rng;
    use std::sync::{Arc, Mutex};

    /// Samples once from the logits `0, 1, ..., 1023` with context `0..1024`.
    ///
    /// The sampler is greedy (no temperature), with 10 top logprobs, top-k 32, top-p 0.1 and
    /// min-p 0.05.
    fn sample_arange(sample_speculative: bool) -> Logprobs {
        let sampler = Sampler::new(
            None,
            10,
            None,
            None,
            None,
            None,
            None,
            32,
            0.1,
            0.05,
            vec![],
        )
        .unwrap();
        let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
        let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
        sampler
            .sample(
                logits,
                &(0..1024).collect::<Vec<_>>(),
                false,
                rng,
                sample_speculative,
                false,
            )
            .unwrap()
    }

    #[test]
    fn test_argmax() {
        let res = sample_arange(false);
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32);
    }

    #[test]
    fn test_gumbel_speculative() {
        let res = sample_arange(true);
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32);
    }
}
988}