1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{
4 collections::{HashMap, HashSet},
5 sync::{Arc, Mutex},
6};
7
8use candle_core::{DType, Device, Error, Result, Tensor, D};
9use mistralrs_quant::{CumSumOp, SortOp};
10#[cfg(feature = "pyo3_macros")]
11use pyo3::pyclass;
12
13use once_cell::sync::Lazy;
14use rand::distr::{weighted::WeightedIndex, Distribution};
15use rand_isaac::Isaac64Rng;
16use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
17use serde::{Deserialize, Serialize};
18use tokenizers::Tokenizer;
19
20static DRY_SEQUENCE_BREAKERS: Lazy<Vec<String>> =
21 Lazy::new(|| ["\n", ":", "\"", "*"].map(String::from).to_vec());
22
/// Stop criteria, supplied either as strings or as `u32` ids.
///
/// How these are matched against generated output is decided by the consumer of this type,
/// which is not part of this section of the file.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum StopTokens {
    /// Stop criteria given as strings.
    Seqs(Vec<String>),
    /// Stop criteria given as `u32` ids (presumably token ids — confirm against the caller).
    Ids(Vec<u32>),
}
29
/// User-facing sampling parameters for a request.
///
/// NOTE(review): the mapping from these fields onto [`Sampler::new`] happens outside this
/// section; the per-field notes below that mention sampler behaviour assume a direct
/// forwarding — confirm at the call site.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SamplingParams {
    /// Sampling temperature, if any.
    pub temperature: Option<f64>,
    /// Optional top-k cutoff.
    pub top_k: Option<usize>,
    /// Optional top-p cutoff.
    pub top_p: Option<f64>,
    /// Optional min-p cutoff.
    pub min_p: Option<f64>,
    /// Number of top log-probabilities to report per token.
    pub top_n_logprobs: usize,
    /// Optional frequency penalty.
    pub frequency_penalty: Option<f32>,
    /// Optional presence penalty.
    pub presence_penalty: Option<f32>,
    /// Optional stop criteria, see [`StopTokens`].
    pub stop_toks: Option<StopTokens>,
    /// Optional length limit (presumably a cap on generated tokens — confirm against the caller).
    pub max_len: Option<usize>,
    /// Optional per-id bias values keyed by `u32` id.
    pub logits_bias: Option<HashMap<u32, f32>>,
    /// Number of choices requested.
    pub n_choices: usize,
    /// Optional DRY sampling parameters, see [`DrySamplingParams`].
    pub dry_params: Option<DrySamplingParams>,
}
46
47impl SamplingParams {
48 pub fn deterministic() -> Self {
53 Self {
54 temperature: None,
55 top_k: Some(1),
56 top_p: None,
57 min_p: None,
58 top_n_logprobs: 0,
59 frequency_penalty: None,
60 presence_penalty: None,
61 stop_toks: None,
62 max_len: None,
63 logits_bias: None,
64 n_choices: 1,
65 dry_params: None,
66 }
67 }
68}
69
/// User-facing parameters of the DRY repetition penalty.
///
/// For a token that would extend a repeated run of length `match_len`, the sampler subtracts
/// `multiplier * base^(match_len - allowed_length)` from that token's logit, provided
/// `match_len >= allowed_length`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DrySamplingParams {
    /// Strings that interrupt repetition matching. Each one is tokenized when a `Sampler` is
    /// built and only the last token id of its encoding is used.
    pub sequence_breakers: Vec<String>,
    /// Scale of the penalty. A value of `0` disables the DRY penalty entirely.
    pub multiplier: f32,
    /// Base of the exponential growth of the penalty.
    pub base: f32,
    /// Minimum match length at which a penalty is applied.
    pub allowed_length: usize,
}
77
78impl DrySamplingParams {
79 pub fn new_with_defaults(
80 multiplier: f32,
81 sequence_breakers: Option<Vec<String>>,
82 base: Option<f32>,
83 allowed_length: Option<usize>,
84 ) -> anyhow::Result<Self> {
85 Ok(Self {
86 base: base.unwrap_or(1.75),
87 allowed_length: allowed_length.unwrap_or(2),
88 sequence_breakers: sequence_breakers.unwrap_or(DRY_SEQUENCE_BREAKERS.clone()),
89 multiplier,
90 })
91 }
92}
93
94impl Default for DrySamplingParams {
95 fn default() -> Self {
96 Self {
97 multiplier: 0.0,
98 base: 1.75,
99 allowed_length: 2,
100 sequence_breakers: DRY_SEQUENCE_BREAKERS.clone(),
101 }
102 }
103}
104
/// Tokenized form of [`DrySamplingParams`] used internally by the sampler.
///
/// Identical to the user-facing struct except that the sequence breakers are stored as a set
/// of token ids rather than strings.
#[derive(Clone, Debug)]
struct DrySamplingParamsInner {
    /// Token ids that interrupt repetition matching.
    pub sequence_breakers: HashSet<u32>,
    /// Scale of the penalty. A value of `0` disables the DRY penalty entirely.
    pub multiplier: f32,
    /// Base of the exponential growth of the penalty.
    pub base: f32,
    /// Minimum match length at which a penalty is applied.
    pub allowed_length: usize,
}
112
113impl DrySamplingParamsInner {
114 pub fn from(other: DrySamplingParams, tokenizer: &Tokenizer) -> anyhow::Result<Self> {
115 Ok(Self {
116 base: other.base,
117 allowed_length: other.allowed_length,
118 sequence_breakers: HashSet::from_iter(
119 other
120 .sequence_breakers
121 .into_iter()
122 .map(|breaker| {
123 tokenizer
124 .encode_fast(["a", &breaker].concat(), true)
130 .map_err(anyhow::Error::msg)
131 .map(|enc| {
132 let ids = enc.get_ids();
133 if !ids.is_empty() {
134 None
135 } else {
136 Some(ids[ids.len() - 1])
137 }
138 })
139 })
140 .collect::<anyhow::Result<Vec<_>>>()?
141 .into_iter()
142 .flatten()
143 .collect::<Vec<_>>(),
144 ),
145 multiplier: other.multiplier,
146 })
147 }
148}
149
/// A user-supplied transformation of the logits that runs as part of sampling.
///
/// The sampler calls every registered processor in order, feeding each one the output of the
/// previous one. Any closure of type `Fn(&Tensor, &[u32]) -> Result<Tensor>` that is
/// `Send + Sync` implements this trait automatically.
pub trait CustomLogitsProcessor: Send + Sync {
    /// Returns the processed logits.
    ///
    /// * `logits` - the current logits tensor.
    /// * `context` - the token ids passed to the sampler as context.
    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor>;
}
173
174impl<T: Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync> CustomLogitsProcessor for T {
175 fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor> {
176 self(logits, context)
177 }
178}
179
/// Token sampler: holds the temperature, the top-k / top-p / min-p cutoffs, the repetition
/// penalties (frequency, presence, DRY) and any custom logits processors.
#[derive(Clone)]
pub struct Sampler {
    /// Softmax temperature. `None` selects the paths of `sample` that skip temperature
    /// scaling and softmax.
    temperature: Option<f64>,
    /// Maximum number of entries reported in `Logprobs::top_logprobs`.
    top_n_logprobs: usize,
    /// Tokenizer used to decode token ids into text; when absent, `bytes` fields are `None`.
    tokenizer: Option<Arc<Tokenizer>>,
    /// Amount subtracted from a token's logit per occurrence of that token in the context.
    frequency_penalty: Option<f32>,
    /// Amount subtracted from a token's logit if the token occurs in the context at all.
    presence_penalty: Option<f32>,
    /// Tokenized DRY parameters; `None` disables the DRY penalty.
    dry_params: Option<DrySamplingParamsInner>,
    /// Top-k cutoff; values `<= 0` disable top-k filtering.
    top_k: i64,
    /// Top-p cutoff.
    top_p: f64,
    /// Min-p cutoff, relative to the largest probability.
    min_p: f64,
    /// Custom processors applied to the logits, in order.
    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
}
194
/// One entry of the top log-probabilities reported for a sampling step.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TopLogprob {
    /// Token id.
    pub token: u32,
    /// Base-10 logarithm of the value the sampler held for this token.
    pub logprob: f32,
    /// Decoded text of the token, or `None` when the sampler has no tokenizer.
    pub bytes: Option<String>,
}
204
/// Result of one sampling step: the chosen token plus optional top log-probabilities.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Logprobs {
    /// Id of the sampled token.
    pub token: u32,
    /// Base-10 logarithm of the value the sampler held for the sampled token.
    pub logprob: f32,
    /// Decoded text of the sampled token, or `None` when the sampler has no tokenizer.
    pub bytes: Option<String>,
    /// Top log-probabilities, present only when they were requested.
    pub top_logprobs: Option<Vec<TopLogprob>>,
}
212
213fn argmax_sample_last_dim(logits: &Tensor) -> Result<Tensor> {
214 logits.argmax(D::Minus1)
215}
216
217impl Sampler {
218 #[allow(clippy::too_many_arguments)]
219 pub fn new(
220 temperature: Option<f64>,
221 top_n_logprobs: usize,
222 tokenizer: Option<Arc<Tokenizer>>,
223 frequency_penalty: Option<f32>,
224 presence_penalty: Option<f32>,
225 dry_params: Option<DrySamplingParams>,
226 top_k: i64,
227 top_p: f64,
228 min_p: f64,
229 logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
230 ) -> anyhow::Result<Self> {
231 let temperature = if temperature.is_none_or(|v| v < 1e-7) {
232 None
233 } else {
234 temperature
235 };
236 let dry_params = if let Some(ref tokenizer) = tokenizer {
237 dry_params.map(|params| DrySamplingParamsInner::from(params, tokenizer))
238 } else {
239 None
240 };
241 let dry_params = match dry_params {
242 Some(fallible) => Some(fallible?),
243 None => None,
244 };
245 Ok(Self {
246 temperature,
247 top_n_logprobs,
248 tokenizer,
249 frequency_penalty,
250 presence_penalty,
251 dry_params,
252 top_k,
253 top_p,
254 min_p,
255 logits_processors,
256 })
257 }
258
259 fn get_top_logprobs(&self, probs: &[f32], _argsort_indices: &[u32]) -> Result<Vec<TopLogprob>> {
260 let k = self.top_n_logprobs.min(probs.len());
262 if k == 0 {
263 return Ok(Vec::new());
264 }
265 let mut idx_probs: Vec<(u32, f32)> = (0..probs.len() as u32)
267 .map(|i| (i, probs[i as usize]))
268 .collect();
269 let (top_k_slice, _, _) = idx_probs.select_nth_unstable_by(k, |a, b| {
271 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
272 });
273 let mut top_k: Vec<(u32, f32)> = top_k_slice.to_vec();
275 top_k.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
276 let mut result = Vec::with_capacity(k);
278 if let Some(tokenizer) = &self.tokenizer {
279 for (token, prob) in top_k {
280 let decoded = tokenizer
281 .decode(&[token], false)
282 .map_err(|e| Error::Msg(e.to_string()))?;
283 result.push(TopLogprob {
284 token,
285 logprob: prob.log(10.0),
286 bytes: Some(decoded),
287 });
288 }
289 } else {
290 for (token, prob) in top_k {
291 result.push(TopLogprob {
292 token,
293 logprob: prob.log(10.0),
294 bytes: None,
295 });
296 }
297 }
298 Ok(result)
299 }
300
301 fn sample_argmax(&self, logits: Tensor, return_logprobs: bool) -> Result<Logprobs> {
302 let next_token = logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
303
304 let probs: Vec<f32> = logits.to_vec1()?;
305
306 let argsort_indices = (0..probs.len() as u32).collect::<Vec<_>>();
307 let logprob = probs[next_token as usize].log(10.0);
308
309 let top_logprobs = if return_logprobs {
310 Some(self.get_top_logprobs(&probs, &argsort_indices)?)
311 } else {
312 None
313 };
314
315 let bytes = if let Some(tokenizer) = &self.tokenizer {
316 Some(
317 tokenizer
318 .decode(&[next_token], false)
319 .map_err(|x| Error::Msg(x.to_string()))?,
320 )
321 } else {
322 None
323 };
324
325 Ok(Logprobs {
326 token: next_token,
327 logprob,
328 top_logprobs,
329 bytes,
330 })
331 }
332
333 #[allow(unused)]
334 fn sample_fast(
335 &self,
336 logits: Tensor,
337 context: &[u32],
338 return_logprobs: bool,
339 top_k: i64,
340 top_p: f64,
341 min_p: f64,
342 ) -> Result<Logprobs> {
343 let mut probs = logits.to_dtype(DType::F32)?;
344
345 for processor in &self.logits_processors {
346 probs = processor.apply(&probs, context)?;
347 }
348
349 let context = Tensor::new(context, logits.device())?;
350 let mut counts = logits.zeros_like()?;
351 counts = counts.scatter_add(
352 &context,
353 &context.ones_like()?.to_dtype(counts.dtype())?,
354 D::Minus1,
355 )?;
356
357 let presence = counts
358 .gt(0.)?
359 .where_cond(&counts.ones_like()?, &counts.zeros_like()?)?;
360
361 match self.frequency_penalty {
362 Some(freq_penalty) if freq_penalty != 0. => {
363 probs = (probs - (freq_penalty as f64 * counts)?)?;
364 }
365 _ => (),
366 }
367
368 match self.presence_penalty {
369 Some(pres_penalty) if pres_penalty != 0. => {
370 probs = (probs - (pres_penalty as f64 * presence)?)?;
371 }
372 _ => (),
373 }
374
375 probs = candle_nn::ops::softmax_last_dim(&(probs / self.temperature.unwrap_or(1.))?)?;
376
377 if top_k > 0 {
379 let sorted_values = probs.fast_sort_asc(D::Minus1)?;
380 let topk_values = sorted_values.narrow(
381 D::Minus1,
382 sorted_values.dim(D::Minus1)? - top_k as usize,
383 top_k as usize,
384 )?;
385
386 let threshold = topk_values.get_on_dim(D::Minus1, 0)?.unsqueeze(0)?;
388 let mask_topk = probs.broadcast_ge(&threshold)?;
389 probs = mask_topk.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
390 }
391
392 if top_p > 0.0 && top_p < 1.0 {
394 let sorted_probs = probs.fast_sort_asc(D::Minus1)?;
395
396 let cumsum = sorted_probs.fast_cumsum(D::Minus1)?;
397
398 let mask_topp = cumsum.le(top_p)?;
399
400 let masked_sorted =
401 mask_topp.where_cond(&sorted_probs, &Tensor::zeros_like(&sorted_probs)?)?;
402
403 let threshold = masked_sorted.max(D::Minus1)?;
404 let threshold = threshold.unsqueeze(D::Minus1)?;
405 let mask_full = probs.broadcast_ge(&threshold)?;
406 probs = mask_full.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
407 }
408
409 if min_p > 0.0 && min_p < 1.0 {
411 let max_vals = probs.max(D::Minus1)?;
412 let threshold_min = (max_vals.unsqueeze(D::Minus1)? * min_p)?;
413 let mask_minp = probs.broadcast_gt(&threshold_min)?;
414 probs = mask_minp.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
415 }
416
417 let next_token = probs.argmax(D::Minus1)?.to_scalar::<u32>()?;
418
419 let (top_logprobs, logprob) = if return_logprobs {
421 let k = self.top_n_logprobs;
422
423 let sorted_values = probs.fast_sort_asc(D::Minus1)?;
424 let topk_values = sorted_values
425 .narrow(
426 D::Minus1,
427 sorted_values.dim(D::Minus1)? - top_k as usize,
428 top_k as usize,
429 )?
430 .to_vec1::<f32>()?;
431
432 let sorted_idxs = probs.fast_argsort_asc(D::Minus1)?;
433 let topk_idxs = sorted_idxs
434 .narrow(
435 D::Minus1,
436 sorted_values.dim(D::Minus1)? - top_k as usize,
437 top_k as usize,
438 )?
439 .to_vec1::<u32>()?;
440
441 let mut result = Vec::with_capacity(k);
442 if let Some(tokenizer) = &self.tokenizer {
443 for (prob, token) in topk_values.iter().zip(topk_idxs) {
444 let decoded = tokenizer
445 .decode(&[token], false)
446 .map_err(|e| Error::Msg(e.to_string()))?;
447 result.push(TopLogprob {
448 token,
449 logprob: prob.log(10.0),
450 bytes: Some(decoded),
451 });
452 }
453 } else {
454 for (prob, token) in topk_values.iter().zip(topk_idxs) {
455 result.push(TopLogprob {
456 token,
457 logprob: prob.log(10.0),
458 bytes: None,
459 });
460 }
461 }
462
463 let logprob = result.last().map(|res| res.logprob).unwrap_or(1.);
464
465 (Some(result), logprob)
466 } else {
467 (None, 1.)
468 };
469
470 let bytes = if let Some(tokenizer) = &self.tokenizer {
471 Some(
472 tokenizer
473 .decode(&[next_token], false)
474 .map_err(|x| Error::Msg(x.to_string()))?,
475 )
476 } else {
477 None
478 };
479
480 Ok(Logprobs {
481 token: next_token,
482 logprob,
483 top_logprobs,
484 bytes,
485 })
486 }
487 fn sample_speculative_top_kp_min_p(
488 &self,
489 logits: Tensor,
490 return_logprobs: bool,
491 top_k: i64,
492 top_p: f32,
493 min_p: f32,
494 ) -> Result<Logprobs> {
495 let mut probs: Vec<f32> = logits.to_vec1()?;
496 let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
497
498 if top_k > 0 {
499 for (index, val) in argsort_indices.iter().enumerate() {
501 if index >= top_k as usize {
502 probs[*val as usize] = 0.0;
503 }
504 }
505 }
506
507 let mut cumsum = 0.;
515 for index in &argsort_indices {
516 if cumsum >= top_p {
517 probs[*index as usize] = 0.0;
518 } else {
519 cumsum += probs[*index as usize];
520 }
521 }
522
523 let max_p = probs[argsort_indices[0] as usize];
524
525 for index in &argsort_indices {
532 if max_p * min_p >= probs[*index as usize] {
533 probs[*index as usize] = 0.0;
534 }
535 }
536
537 let logits = Tensor::from_slice(&probs, logits.shape(), &Device::Cpu)?;
538
539 let next_token = argmax_sample_last_dim(&logits)?.to_scalar::<u32>()?;
540
541 let logprob = probs[next_token as usize].log(10.0);
542
543 let top_logprobs = if return_logprobs {
544 Some(self.get_top_logprobs(&probs, &argsort_indices)?)
545 } else {
546 None
547 };
548
549 let bytes = if let Some(tokenizer) = &self.tokenizer {
550 Some(
551 tokenizer
552 .decode(&[next_token], false)
553 .map_err(|x| Error::Msg(x.to_string()))?,
554 )
555 } else {
556 None
557 };
558
559 Ok(Logprobs {
560 token: next_token,
561 logprob,
562 top_logprobs,
563 bytes,
564 })
565 }
566
567 fn sample_multinomial(
568 &self,
569 probs: &mut Vec<f32>,
570 argsort_indices: Vec<u32>,
571 return_logprobs: bool,
572 rng: Arc<Mutex<Isaac64Rng>>,
573 ) -> Result<Logprobs> {
574 let distr = WeightedIndex::new(&*probs).map_err(Error::wrap)?;
575
576 let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
577 let next_token = distr.sample(&mut mut_ref_rng); let logprob = probs[next_token].log(10.0);
579
580 let top_logprobs = if return_logprobs {
581 Some(self.get_top_logprobs(probs, &argsort_indices)?)
582 } else {
583 None
584 };
585
586 let bytes = if let Some(tokenizer) = &self.tokenizer {
587 Some(
588 tokenizer
589 .decode(&[next_token.try_into().unwrap()], false)
590 .map_err(|x| Error::Msg(x.to_string()))?,
591 )
592 } else {
593 None
594 };
595
596 Ok(Logprobs {
597 token: next_token as u32,
598 logprob,
599 top_logprobs,
600 bytes,
601 })
602 }
603
604 #[allow(clippy::too_many_arguments)]
605 fn sample_top_kp_min_p(
606 &self,
607 probs: &mut Vec<f32>,
608 logits: &Tensor,
609 top_k: i64,
610 top_p: f32,
611 min_p: f32,
612 return_logprobs: bool,
613 rng: Arc<Mutex<Isaac64Rng>>,
614 ) -> Result<Logprobs> {
615 let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
616
617 if top_k > 0 {
618 for (index, val) in argsort_indices.iter().enumerate() {
620 if index >= top_k as usize {
621 probs[*val as usize] = 0.0;
622 }
623 }
624 }
625
626 if top_p <= 0.0 || top_p >= 1.0 {
627 return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
628 }
629
630 let mut cumsum = 0.;
638 for index in &argsort_indices {
639 if cumsum >= top_p {
640 probs[*index as usize] = 0.0;
641 } else {
642 cumsum += probs[*index as usize];
643 }
644 }
645
646 if min_p <= 0.0 || min_p >= 1.0 {
647 return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
648 }
649
650 let max_p = probs[argsort_indices[0] as usize];
651
652 for index in &argsort_indices {
659 if max_p * min_p >= probs[*index as usize] {
660 probs[*index as usize] = 0.0;
661 }
662 }
663
664 self.sample_multinomial(probs, argsort_indices, return_logprobs, rng)
666 }
667
668 fn apply_penalties(&self, mut logits: Vec<f32>, context: &[u32]) -> Result<Tensor> {
669 if context.is_empty() {
670 candle_core::bail!("Penalty context is empty, this should not happen.");
671 }
672
673 self.apply_dry_penalty(&mut logits, context)?;
675
676 self.apply_freq_presc_penalty(&mut logits, context)?;
678
679 let vocab_size = logits.len();
680 Tensor::from_vec(logits, vocab_size, &Device::Cpu)
681 }
682
683 fn apply_freq_presc_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
684 if self.frequency_penalty.is_some() || self.presence_penalty.is_some() {
685 let frequency_penalty = self.frequency_penalty.unwrap_or(0.);
686 let presence_penalty = self.presence_penalty.unwrap_or(0.);
687
688 let mut counts = vec![0.0f32; logits.len()];
691 for ctx in context.iter() {
692 if *ctx as usize >= logits.len() {
694 continue;
695 }
696 counts[*ctx as usize] += 1.0;
697 }
698
699 for (token_id, logit) in logits.iter_mut().enumerate() {
700 let count = counts[token_id];
701 *logit = *logit
702 - count * frequency_penalty
703 - if count > 0.0 { 1. } else { 0. } * presence_penalty;
704 }
705 }
706 Ok(())
707 }
708
709 fn apply_dry_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
710 if let Some(ref params) = self.dry_params {
711 if params.multiplier == 0. {
712 return Ok(());
713 }
714
715 let match_indices = context
716 .par_iter()
717 .enumerate()
718 .take(context.len() - 1)
719 .filter(|(_i, x)| *context.last().unwrap() == **x)
720 .map(|(i, _)| i)
721 .collect::<Vec<_>>();
722
723 let mut match_lengths = HashMap::new();
724
725 for i in match_indices {
726 let next_token = context[i + 1];
727
728 if params.sequence_breakers.contains(&next_token) {
729 continue;
730 }
731
732 let mut match_length = 1;
733
734 while match_length < 50 {
736 if match_length > i {
737 break;
739 }
740
741 let j = i - match_length;
742
743 let prev_tok = context[context.len() - (match_length + 1)];
744 if context[j] != prev_tok {
745 break;
747 }
748
749 if params.sequence_breakers.contains(&prev_tok) {
750 break;
752 }
753
754 match_length += 1;
755 }
756
757 #[allow(clippy::map_entry)]
758 if match_lengths.contains_key(&next_token) {
759 match_lengths.insert(next_token, match_length.max(match_lengths[&next_token]));
760 } else {
761 match_lengths.insert(next_token, match_length);
762 }
763 }
764
765 for (tok, match_len) in match_lengths {
767 if match_len >= params.allowed_length {
768 if tok as usize >= logits.len() {
770 continue;
771 }
772 let penalty = params.multiplier
773 * params.base.powf((match_len - params.allowed_length) as f32);
774 logits[tok as usize] -= penalty;
775 }
776 }
777 }
778 Ok(())
779 }
780
781 #[allow(unused)]
782 pub fn sample(
787 &self,
788 logits: Tensor,
789 context: &[u32],
790 return_logprobs: bool,
791 rng: Arc<Mutex<Isaac64Rng>>,
792 sample_speculative: bool,
793 multiple_sequences: bool,
794 ) -> Result<Logprobs> {
795 if cfg!(feature = "metal") && !multiple_sequences {
796 return self.sample_fast(
797 logits,
798 context,
799 return_logprobs,
800 self.top_k,
801 self.top_p,
802 self.min_p,
803 );
804 }
805
806 let logits = logits.to_vec1()?;
807 let mut logits = self.apply_penalties(logits, context)?;
808 for processor in &self.logits_processors {
809 logits = processor.apply(&logits, context)?;
810 }
811 let next_token = if sample_speculative {
812 match self.temperature {
813 None => self.sample_speculative_top_kp_min_p(
814 logits,
815 return_logprobs,
816 self.top_k,
817 self.top_p as f32,
818 self.min_p as f32,
819 )?,
820 Some(temperature) => {
821 let logits = (&logits / temperature)?;
822 let probs = candle_nn::ops::softmax_last_dim(&logits)?;
823
824 self.sample_speculative_top_kp_min_p(
825 probs,
826 return_logprobs,
827 self.top_k,
828 self.top_p as f32,
829 self.min_p as f32,
830 )?
831 }
832 }
833 } else {
834 match self.temperature {
835 None => self.sample_argmax(logits, return_logprobs)?,
836 Some(temperature) => {
837 let logits = (&logits / temperature)?;
838 let logits = candle_nn::ops::softmax_last_dim(&logits)?;
839 let mut probs: Vec<f32> = logits.to_vec1()?;
840
841 self.sample_top_kp_min_p(
842 &mut probs,
843 &logits,
844 self.top_k,
845 self.top_p as f32,
846 self.min_p as f32,
847 return_logprobs,
848 rng,
849 )?
850 }
851 }
852 };
853 Ok(next_token)
854 }
855}
856
// Compiled only for test builds so the module and its imports do not exist in normal builds.
#[cfg(test)]
mod tests {
    use super::{Logprobs, Sampler};
    use candle_core::{Device, Tensor};
    use rand::SeedableRng;
    use rand_isaac::Isaac64Rng;
    use std::sync::{Arc, Mutex};

    /// Samples once from the logits `0, 1, ..., 1023` with no temperature, using the full id
    /// range as context and a fixed seed.
    fn sample_ascending_logits(sample_speculative: bool) -> Logprobs {
        let sampler =
            Sampler::new(None, 10, None, None, None, None, 32, 0.1, 0.05, vec![]).unwrap();
        let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
        let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
        sampler
            .sample(
                logits,
                &(0..1024).collect::<Vec<_>>(),
                false,
                rng,
                sample_speculative,
                false,
            )
            .unwrap()
    }

    #[test]
    fn test_argmax() {
        let res = sample_ascending_logits(false);
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }

    #[test]
    fn test_gumbel_speculative() {
        let res = sample_ascending_logits(true);
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }
}