1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{
4 collections::{HashMap, HashSet},
5 sync::{Arc, Mutex},
6};
7
8use candle_core::{DType, Device, Error, Result, Tensor, D};
9use mistralrs_quant::{CumSumOp, SortOp};
10#[cfg(feature = "pyo3_macros")]
11use pyo3::pyclass;
12
13use once_cell::sync::Lazy;
14use rand::distr::{weighted::WeightedIndex, Distribution};
15use rand_isaac::Isaac64Rng;
16use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
17use serde::{Deserialize, Serialize};
18use tokenizers::Tokenizer;
19
20static DRY_SEQUENCE_BREAKERS: Lazy<Vec<String>> =
21 Lazy::new(|| ["\n", ":", "\"", "*"].map(String::from).to_vec());
22
/// Stop conditions for generation, given either as strings or as token ids.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum StopTokens {
    /// Stop conditions expressed as text sequences.
    Seqs(Vec<String>),
    /// Stop conditions expressed as token ids.
    Ids(Vec<u32>),
}
29
30#[derive(Clone, Debug, Serialize, Deserialize)]
31pub struct SamplingParams {
33 pub temperature: Option<f64>,
34 pub top_k: Option<usize>,
35 pub top_p: Option<f64>,
36 pub min_p: Option<f64>,
37 pub top_n_logprobs: usize,
38 pub frequency_penalty: Option<f32>,
39 pub presence_penalty: Option<f32>,
40 pub repetition_penalty: Option<f32>,
41 pub stop_toks: Option<StopTokens>,
42 pub max_len: Option<usize>,
43 pub logits_bias: Option<HashMap<u32, f32>>,
44 pub n_choices: usize,
45 pub dry_params: Option<DrySamplingParams>,
46}
47
48impl SamplingParams {
49 pub fn deterministic() -> Self {
54 Self {
55 temperature: None,
56 top_k: Some(1),
57 top_p: None,
58 min_p: None,
59 top_n_logprobs: 0,
60 frequency_penalty: None,
61 presence_penalty: None,
62 repetition_penalty: None,
63 stop_toks: None,
64 max_len: None,
65 logits_bias: None,
66 n_choices: 1,
67 dry_params: None,
68 }
69 }
70}
71
72#[derive(Clone, Debug, Serialize, Deserialize)]
73pub struct DrySamplingParams {
74 pub sequence_breakers: Vec<String>,
75 pub multiplier: f32,
76 pub base: f32,
77 pub allowed_length: usize,
78}
79
80impl DrySamplingParams {
81 pub fn new_with_defaults(
82 multiplier: f32,
83 sequence_breakers: Option<Vec<String>>,
84 base: Option<f32>,
85 allowed_length: Option<usize>,
86 ) -> anyhow::Result<Self> {
87 Ok(Self {
88 base: base.unwrap_or(1.75),
89 allowed_length: allowed_length.unwrap_or(2),
90 sequence_breakers: sequence_breakers.unwrap_or(DRY_SEQUENCE_BREAKERS.clone()),
91 multiplier,
92 })
93 }
94}
95
96impl Default for DrySamplingParams {
97 fn default() -> Self {
98 Self {
99 multiplier: 0.0,
100 base: 1.75,
101 allowed_length: 2,
102 sequence_breakers: DRY_SEQUENCE_BREAKERS.clone(),
103 }
104 }
105}
106
107#[derive(Clone, Debug)]
108struct DrySamplingParamsInner {
109 pub sequence_breakers: HashSet<u32>,
110 pub multiplier: f32,
111 pub base: f32,
112 pub allowed_length: usize,
113}
114
115impl DrySamplingParamsInner {
116 pub fn from(other: DrySamplingParams, tokenizer: &Tokenizer) -> anyhow::Result<Self> {
117 Ok(Self {
118 base: other.base,
119 allowed_length: other.allowed_length,
120 sequence_breakers: HashSet::from_iter(
121 other
122 .sequence_breakers
123 .into_iter()
124 .map(|breaker| {
125 tokenizer
126 .encode_fast(["a", &breaker].concat(), true)
132 .map_err(anyhow::Error::msg)
133 .map(|enc| {
134 let ids = enc.get_ids();
135 if !ids.is_empty() {
136 None
137 } else {
138 Some(ids[ids.len() - 1])
139 }
140 })
141 })
142 .collect::<anyhow::Result<Vec<_>>>()?
143 .into_iter()
144 .flatten()
145 .collect::<Vec<_>>(),
146 ),
147 multiplier: other.multiplier,
148 })
149 }
150}
151
152pub trait CustomLogitsProcessor: Send + Sync {
172 fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor>;
174}
175
176impl<T: Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync> CustomLogitsProcessor for T {
177 fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor> {
178 self(logits, context)
179 }
180}
181
182#[derive(Clone)]
184pub struct Sampler {
185 temperature: Option<f64>,
186 top_n_logprobs: usize,
187 tokenizer: Option<Arc<Tokenizer>>,
188 frequency_penalty: Option<f32>,
189 presence_penalty: Option<f32>,
190 repetition_penalty: Option<f32>,
191 dry_params: Option<DrySamplingParamsInner>,
192 top_k: i64,
193 top_p: f64,
194 min_p: f64,
195 logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
196 gumbel_cache: Arc<Mutex<Option<Tensor>>>,
198}
199
/// One entry of the "top alternatives" list reported alongside a sampled token.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TopLogprob {
    /// Token id of the alternative.
    pub token: u32,
    /// Log-probability of the alternative; the sampler in this module computes it with `log10`.
    pub logprob: f32,
    /// Decoded text of the token, or `None` when the sampler has no tokenizer.
    pub bytes: Option<String>,
}

/// Result of one sampling step: the chosen token and, optionally, its top alternatives.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Logprobs {
    /// The sampled token id.
    pub token: u32,
    /// Log-probability reported for `token`.
    ///
    /// NOTE(review): depending on the sampling path in this module, this is log10 of a
    /// probability, log10 of a raw logit (greedy path), or a `1.0` placeholder. Confirm what
    /// consumers expect.
    pub logprob: f32,
    /// Decoded text of `token`, or `None` when the sampler has no tokenizer.
    pub bytes: Option<String>,
    /// Top alternatives when log-probabilities were requested, otherwise `None`.
    pub top_logprobs: Option<Vec<TopLogprob>>,
}
217
218fn argmax_sample_last_dim(logits: &Tensor) -> Result<Tensor> {
219 logits.argmax(D::Minus1)
220}
221
222impl Sampler {
223 #[allow(clippy::too_many_arguments)]
224 pub fn new(
225 temperature: Option<f64>,
226 top_n_logprobs: usize,
227 tokenizer: Option<Arc<Tokenizer>>,
228 frequency_penalty: Option<f32>,
229 presence_penalty: Option<f32>,
230 repetition_penalty: Option<f32>,
231 dry_params: Option<DrySamplingParams>,
232 top_k: i64,
233 top_p: f64,
234 min_p: f64,
235 logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
236 ) -> anyhow::Result<Self> {
237 let temperature = if temperature.is_none_or(|v| v < 1e-7) {
238 None
239 } else {
240 temperature
241 };
242 let dry_params = if let Some(ref tokenizer) = tokenizer {
243 dry_params.map(|params| DrySamplingParamsInner::from(params, tokenizer))
244 } else {
245 None
246 };
247 let dry_params = match dry_params {
248 Some(fallible) => Some(fallible?),
249 None => None,
250 };
251 Ok(Self {
252 temperature,
253 top_n_logprobs,
254 tokenizer,
255 frequency_penalty,
256 presence_penalty,
257 repetition_penalty,
258 dry_params,
259 top_k,
260 top_p,
261 min_p,
262 logits_processors,
263 gumbel_cache: Arc::new(Mutex::new(None)),
264 })
265 }
266
267 fn get_top_logprobs(&self, probs: &[f32], _argsort_indices: &[u32]) -> Result<Vec<TopLogprob>> {
268 let k = self.top_n_logprobs.min(probs.len());
270 if k == 0 {
271 return Ok(Vec::new());
272 }
273 let mut idx_probs: Vec<(u32, f32)> = (0..probs.len() as u32)
275 .map(|i| (i, probs[i as usize]))
276 .collect();
277 let (top_k_slice, _, _) = idx_probs.select_nth_unstable_by(k, |a, b| {
279 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
280 });
281 let mut top_k: Vec<(u32, f32)> = top_k_slice.to_vec();
283 top_k.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
284 let mut result = Vec::with_capacity(k);
286 if let Some(tokenizer) = &self.tokenizer {
287 for (token, prob) in top_k {
288 let decoded = tokenizer
289 .decode(&[token], false)
290 .map_err(|e| Error::Msg(e.to_string()))?;
291 result.push(TopLogprob {
292 token,
293 logprob: prob.log(10.0),
294 bytes: Some(decoded),
295 });
296 }
297 } else {
298 for (token, prob) in top_k {
299 result.push(TopLogprob {
300 token,
301 logprob: prob.log(10.0),
302 bytes: None,
303 });
304 }
305 }
306 Ok(result)
307 }
308
309 fn sample_argmax(&self, logits: Tensor, return_logprobs: bool) -> Result<Logprobs> {
310 let next_token = logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
311
312 let probs: Vec<f32> = logits.to_vec1()?;
313
314 let argsort_indices = (0..probs.len() as u32).collect::<Vec<_>>();
315 let logprob = probs[next_token as usize].log(10.0);
316
317 let top_logprobs = if return_logprobs {
318 Some(self.get_top_logprobs(&probs, &argsort_indices)?)
319 } else {
320 None
321 };
322
323 let bytes = if let Some(tokenizer) = &self.tokenizer {
324 Some(
325 tokenizer
326 .decode(&[next_token], false)
327 .map_err(|x| Error::Msg(x.to_string()))?,
328 )
329 } else {
330 None
331 };
332
333 Ok(Logprobs {
334 token: next_token,
335 logprob,
336 top_logprobs,
337 bytes,
338 })
339 }
340
341 #[allow(unused)]
342 fn sample_fast(
343 &self,
344 logits: Tensor,
345 context: &[u32],
346 return_logprobs: bool,
347 top_k: i64,
348 top_p: f64,
349 min_p: f64,
350 ) -> Result<Logprobs> {
351 let mut probs = logits.to_dtype(DType::F32)?;
352
353 for processor in &self.logits_processors {
354 probs = processor.apply(&probs, context)?;
355 }
356
357 let context = Tensor::new(context, logits.device())?;
358 let mut counts = logits.zeros_like()?;
359 counts = counts.scatter_add(
360 &context,
361 &context.ones_like()?.to_dtype(counts.dtype())?,
362 D::Minus1,
363 )?;
364
365 let presence = counts
366 .gt(0.)?
367 .where_cond(&counts.ones_like()?, &counts.zeros_like()?)?;
368
369 match self.frequency_penalty {
370 Some(freq_penalty) if freq_penalty != 0. => {
371 probs = (probs - (freq_penalty as f64 * counts)?)?;
372 }
373 _ => (),
374 }
375
376 match self.presence_penalty {
377 Some(pres_penalty) if pres_penalty != 0. => {
378 probs = (probs - (pres_penalty as f64 * &presence)?)?;
379 }
380 _ => (),
381 }
382
383 match self.repetition_penalty {
384 Some(rep_penalty) if rep_penalty != 1. => {
385 let pos_mask = probs.gt(0.)?;
386 let scaled_pos = (&probs / (rep_penalty as f64))?;
387 let scaled_neg = (&probs * (rep_penalty as f64))?;
388 let modified = pos_mask.where_cond(&scaled_pos, &scaled_neg)?;
389
390 let pres_mask = presence.gt(0.)?;
391 probs = pres_mask.where_cond(&modified, &probs)?;
392 }
393 _ => (),
394 }
395
396 probs = candle_nn::ops::softmax_last_dim(&(probs / self.temperature.unwrap_or(1.))?)?;
397
398 if top_k > 0 {
400 let sorted_values = probs.fast_sort_asc(D::Minus1)?;
401 let topk_values = sorted_values.narrow(
402 D::Minus1,
403 sorted_values.dim(D::Minus1)? - top_k as usize,
404 top_k as usize,
405 )?;
406
407 let threshold = topk_values.get_on_dim(D::Minus1, 0)?.unsqueeze(0)?;
409 let mask_topk = probs.broadcast_ge(&threshold)?;
410 probs = mask_topk.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
411 }
412
413 if top_p > 0.0 && top_p < 1.0 {
415 let sorted_probs = probs.fast_sort_asc(D::Minus1)?;
416
417 let cumsum = sorted_probs.fast_cumsum(D::Minus1)?;
418
419 let mask_topp = cumsum.le(top_p)?;
420
421 let masked_sorted =
422 mask_topp.where_cond(&sorted_probs, &Tensor::zeros_like(&sorted_probs)?)?;
423
424 let threshold = masked_sorted.max(D::Minus1)?;
425 let threshold = threshold.unsqueeze(D::Minus1)?;
426 let mask_full = probs.broadcast_ge(&threshold)?;
427 probs = mask_full.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
428 }
429
430 if min_p > 0.0 && min_p < 1.0 {
432 let max_vals = probs.max(D::Minus1)?;
433 let threshold_min = (max_vals.unsqueeze(D::Minus1)? * min_p)?;
434 let mask_minp = probs.broadcast_gt(&threshold_min)?;
435 probs = mask_minp.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
436 }
437
438 let log_probs = probs.log()?;
440 let gumbel = {
442 let mut guard = self.gumbel_cache.lock().unwrap();
443 if guard.is_none() {
444 let uniform = Tensor::rand(0f32, 1f32, log_probs.shape(), log_probs.device())?;
445 let noise = uniform
446 .clamp(1e-20, 1.0)?
447 .log()? .neg()? .log()? .neg()?; *guard = Some(noise);
452 }
453 guard.as_ref().unwrap().clone()
454 };
455
456 let gumbel_logits = (&log_probs + &gumbel)?;
457 let next_token = gumbel_logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
458
459 let (top_logprobs, logprob) = if return_logprobs {
461 let k = self.top_n_logprobs;
462
463 let sorted_values = probs.fast_sort_asc(D::Minus1)?;
464 let topk_values = sorted_values
465 .narrow(
466 D::Minus1,
467 sorted_values.dim(D::Minus1)? - top_k as usize,
468 top_k as usize,
469 )?
470 .to_vec1::<f32>()?;
471
472 let sorted_idxs = probs.fast_argsort_asc(D::Minus1)?;
473 let topk_idxs = sorted_idxs
474 .narrow(
475 D::Minus1,
476 sorted_values.dim(D::Minus1)? - top_k as usize,
477 top_k as usize,
478 )?
479 .to_vec1::<u32>()?;
480
481 let mut result = Vec::with_capacity(k);
482 if let Some(tokenizer) = &self.tokenizer {
483 for (prob, token) in topk_values.iter().zip(topk_idxs) {
484 let decoded = tokenizer
485 .decode(&[token], false)
486 .map_err(|e| Error::Msg(e.to_string()))?;
487 result.push(TopLogprob {
488 token,
489 logprob: prob.log(10.0),
490 bytes: Some(decoded),
491 });
492 }
493 } else {
494 for (prob, token) in topk_values.iter().zip(topk_idxs) {
495 result.push(TopLogprob {
496 token,
497 logprob: prob.log(10.0),
498 bytes: None,
499 });
500 }
501 }
502
503 let logprob = result.last().map(|res| res.logprob).unwrap_or(1.);
504
505 (Some(result), logprob)
506 } else {
507 (None, 1.)
508 };
509
510 let bytes = if let Some(tokenizer) = &self.tokenizer {
511 Some(
512 tokenizer
513 .decode(&[next_token], false)
514 .map_err(|x| Error::Msg(x.to_string()))?,
515 )
516 } else {
517 None
518 };
519
520 Ok(Logprobs {
521 token: next_token,
522 logprob,
523 top_logprobs,
524 bytes,
525 })
526 }
527 fn sample_speculative_top_kp_min_p(
528 &self,
529 logits: Tensor,
530 return_logprobs: bool,
531 top_k: i64,
532 top_p: f32,
533 min_p: f32,
534 ) -> Result<Logprobs> {
535 let mut probs: Vec<f32> = logits.to_vec1()?;
536 let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
537
538 if top_k > 0 {
539 for (index, val) in argsort_indices.iter().enumerate() {
541 if index >= top_k as usize {
542 probs[*val as usize] = 0.0;
543 }
544 }
545 }
546
547 let mut cumsum = 0.;
555 for index in &argsort_indices {
556 if cumsum >= top_p {
557 probs[*index as usize] = 0.0;
558 } else {
559 cumsum += probs[*index as usize];
560 }
561 }
562
563 let max_p = probs[argsort_indices[0] as usize];
564
565 for index in &argsort_indices {
572 if max_p * min_p >= probs[*index as usize] {
573 probs[*index as usize] = 0.0;
574 }
575 }
576
577 let logits = Tensor::from_slice(&probs, logits.shape(), &Device::Cpu)?;
578
579 let next_token = argmax_sample_last_dim(&logits)?.to_scalar::<u32>()?;
580
581 let logprob = probs[next_token as usize].log(10.0);
582
583 let top_logprobs = if return_logprobs {
584 Some(self.get_top_logprobs(&probs, &argsort_indices)?)
585 } else {
586 None
587 };
588
589 let bytes = if let Some(tokenizer) = &self.tokenizer {
590 Some(
591 tokenizer
592 .decode(&[next_token], false)
593 .map_err(|x| Error::Msg(x.to_string()))?,
594 )
595 } else {
596 None
597 };
598
599 Ok(Logprobs {
600 token: next_token,
601 logprob,
602 top_logprobs,
603 bytes,
604 })
605 }
606
607 fn sample_multinomial(
608 &self,
609 probs: &mut Vec<f32>,
610 argsort_indices: Vec<u32>,
611 return_logprobs: bool,
612 rng: Arc<Mutex<Isaac64Rng>>,
613 ) -> Result<Logprobs> {
614 let distr = WeightedIndex::new(&*probs).map_err(Error::wrap)?;
615
616 let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
617 let next_token = distr.sample(&mut mut_ref_rng); let logprob = probs[next_token].log(10.0);
619
620 let top_logprobs = if return_logprobs {
621 Some(self.get_top_logprobs(probs, &argsort_indices)?)
622 } else {
623 None
624 };
625
626 let bytes = if let Some(tokenizer) = &self.tokenizer {
627 Some(
628 tokenizer
629 .decode(&[next_token.try_into().unwrap()], false)
630 .map_err(|x| Error::Msg(x.to_string()))?,
631 )
632 } else {
633 None
634 };
635
636 Ok(Logprobs {
637 token: next_token as u32,
638 logprob,
639 top_logprobs,
640 bytes,
641 })
642 }
643
644 #[allow(clippy::too_many_arguments)]
645 fn sample_top_kp_min_p(
646 &self,
647 probs: &mut Vec<f32>,
648 logits: &Tensor,
649 top_k: i64,
650 top_p: f32,
651 min_p: f32,
652 return_logprobs: bool,
653 rng: Arc<Mutex<Isaac64Rng>>,
654 ) -> Result<Logprobs> {
655 let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
656
657 if top_k > 0 {
658 for (index, val) in argsort_indices.iter().enumerate() {
660 if index >= top_k as usize {
661 probs[*val as usize] = 0.0;
662 }
663 }
664 }
665
666 if top_p <= 0.0 || top_p >= 1.0 {
667 return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
668 }
669
670 let mut cumsum = 0.;
678 for index in &argsort_indices {
679 if cumsum >= top_p {
680 probs[*index as usize] = 0.0;
681 } else {
682 cumsum += probs[*index as usize];
683 }
684 }
685
686 if min_p <= 0.0 || min_p >= 1.0 {
687 return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
688 }
689
690 let max_p = probs[argsort_indices[0] as usize];
691
692 for index in &argsort_indices {
699 if max_p * min_p >= probs[*index as usize] {
700 probs[*index as usize] = 0.0;
701 }
702 }
703
704 self.sample_multinomial(probs, argsort_indices, return_logprobs, rng)
706 }
707
708 fn apply_penalties(&self, mut logits: Vec<f32>, context: &[u32]) -> Result<Tensor> {
709 if context.is_empty() {
710 candle_core::bail!("Penalty context is empty, this should not happen.");
711 }
712
713 self.apply_dry_penalty(&mut logits, context)?;
715
716 self.apply_freq_pres_rep_penalty(&mut logits, context)?;
718
719 let vocab_size = logits.len();
720 Tensor::from_vec(logits, vocab_size, &Device::Cpu)
721 }
722
723 fn apply_freq_pres_rep_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
724 if self.frequency_penalty.is_some()
725 || self.presence_penalty.is_some()
726 || self.repetition_penalty.is_some()
727 {
728 let frequency_penalty = self.frequency_penalty.unwrap_or(0.);
729 let presence_penalty = self.presence_penalty.unwrap_or(0.);
730 let repetition_penalty = self.repetition_penalty.unwrap_or(1.);
731
732 let mut counts = vec![0.0f32; logits.len()];
735 for ctx in context.iter() {
736 if *ctx as usize >= logits.len() {
738 continue;
739 }
740 counts[*ctx as usize] += 1.0;
741 }
742
743 for (token_id, logit) in logits.iter_mut().enumerate() {
744 let count = counts[token_id];
745 *logit = *logit
746 - count * frequency_penalty
747 - if count > 0.0 { 1. } else { 0. } * presence_penalty;
748
749 if repetition_penalty != 1.0 && count > 0.0 {
750 if *logit > 0.0 {
751 *logit /= repetition_penalty;
752 } else {
753 *logit *= repetition_penalty;
754 }
755 }
756 }
757 }
758 Ok(())
759 }
760
761 fn apply_dry_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
762 if let Some(ref params) = self.dry_params {
763 if params.multiplier == 0. {
764 return Ok(());
765 }
766
767 let match_indices = context
768 .par_iter()
769 .enumerate()
770 .take(context.len() - 1)
771 .filter(|(_i, x)| *context.last().unwrap() == **x)
772 .map(|(i, _)| i)
773 .collect::<Vec<_>>();
774
775 let mut match_lengths = HashMap::new();
776
777 for i in match_indices {
778 let next_token = context[i + 1];
779
780 if params.sequence_breakers.contains(&next_token) {
781 continue;
782 }
783
784 let mut match_length = 1;
785
786 while match_length < 50 {
788 if match_length > i {
789 break;
791 }
792
793 let j = i - match_length;
794
795 let prev_tok = context[context.len() - (match_length + 1)];
796 if context[j] != prev_tok {
797 break;
799 }
800
801 if params.sequence_breakers.contains(&prev_tok) {
802 break;
804 }
805
806 match_length += 1;
807 }
808
809 #[allow(clippy::map_entry)]
810 if match_lengths.contains_key(&next_token) {
811 match_lengths.insert(next_token, match_length.max(match_lengths[&next_token]));
812 } else {
813 match_lengths.insert(next_token, match_length);
814 }
815 }
816
817 for (tok, match_len) in match_lengths {
819 if match_len >= params.allowed_length {
820 if tok as usize >= logits.len() {
822 continue;
823 }
824 let penalty = params.multiplier
825 * params.base.powf((match_len - params.allowed_length) as f32);
826 logits[tok as usize] -= penalty;
827 }
828 }
829 }
830 Ok(())
831 }
832
833 #[allow(unused)]
834 pub fn sample(
839 &self,
840 logits: Tensor,
841 context: &[u32],
842 return_logprobs: bool,
843 rng: Arc<Mutex<Isaac64Rng>>,
844 sample_speculative: bool,
845 multiple_sequences: bool,
846 ) -> Result<Logprobs> {
847 let logits = logits.to_vec1()?;
859 let mut logits = self.apply_penalties(logits, context)?;
860 for processor in &self.logits_processors {
861 logits = processor.apply(&logits, context)?;
862 }
863 let next_token = if sample_speculative {
864 match self.temperature {
865 None => self.sample_speculative_top_kp_min_p(
866 logits,
867 return_logprobs,
868 self.top_k,
869 self.top_p as f32,
870 self.min_p as f32,
871 )?,
872 Some(temperature) => {
873 let logits = (&logits / temperature)?;
874 let probs = candle_nn::ops::softmax_last_dim(&logits)?;
875
876 self.sample_speculative_top_kp_min_p(
877 probs,
878 return_logprobs,
879 self.top_k,
880 self.top_p as f32,
881 self.min_p as f32,
882 )?
883 }
884 }
885 } else {
886 match self.temperature {
887 None => self.sample_argmax(logits, return_logprobs)?,
888 Some(temperature) => {
889 let logits = (&logits / temperature)?;
890 let logits = candle_nn::ops::softmax_last_dim(&logits)?;
891 let mut probs: Vec<f32> = logits.to_vec1()?;
892
893 self.sample_top_kp_min_p(
894 &mut probs,
895 &logits,
896 self.top_k,
897 self.top_p as f32,
898 self.min_p as f32,
899 return_logprobs,
900 rng,
901 )?
902 }
903 }
904 };
905 Ok(next_token)
906 }
907}
908
#[cfg(test)]
mod tests {
    use super::{Logprobs, Sampler};
    use candle_core::{Device, Tensor};
    use rand::SeedableRng;
    use rand_isaac::Isaac64Rng;
    use std::sync::{Arc, Mutex};

    /// Sample the ramp `0..1024` with a greedy sampler (no temperature, top_k = 32,
    /// top_p = 0.1, min_p = 0.05, 10 top logprobs, no penalties/tokenizer/processors).
    /// Every token id appears exactly once in the context.
    fn sample_ramp(sample_speculative: bool) -> Logprobs {
        let sampler = Sampler::new(
            None,
            10,
            None,
            None,
            None,
            None,
            None,
            32,
            0.1,
            0.05,
            vec![],
        )
        .unwrap();
        let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
        let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
        sampler
            .sample(
                logits,
                &(0..1024).collect::<Vec<_>>(),
                false,
                rng,
                sample_speculative,
                false,
            )
            .unwrap()
    }

    #[test]
    fn test_argmax() {
        let res = sample_ramp(false);
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }

    #[test]
    fn test_gumbel_speculative() {
        let res = sample_ramp(true);
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }
}