1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{
4 collections::{HashMap, HashSet},
5 iter::zip,
6 sync::{Arc, Mutex},
7};
8
9use candle_core::{Device, Error, Result, Tensor, D};
10#[cfg(feature = "pyo3_macros")]
11use pyo3::pyclass;
12
13use once_cell::sync::Lazy;
14use rand::distr::{weighted::WeightedIndex, Distribution};
15use rand_isaac::Isaac64Rng;
16use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
17use serde::{Deserialize, Serialize};
18use tokenizers::Tokenizer;
19
20static DRY_SEQUENCE_BREAKERS: Lazy<Vec<String>> =
21 Lazy::new(|| ["\n", ":", "\"", "*"].map(String::from).to_vec());
22
23#[derive(Clone, Debug, Serialize, Deserialize)]
24pub enum StopTokens {
26 Seqs(Vec<String>),
27 Ids(Vec<u32>),
28}
29
30#[derive(Clone, Debug, Serialize, Deserialize)]
31pub struct SamplingParams {
33 pub temperature: Option<f64>,
34 pub top_k: Option<usize>,
35 pub top_p: Option<f64>,
36 pub min_p: Option<f64>,
37 pub top_n_logprobs: usize,
38 pub frequency_penalty: Option<f32>,
39 pub presence_penalty: Option<f32>,
40 pub stop_toks: Option<StopTokens>,
41 pub max_len: Option<usize>,
42 pub logits_bias: Option<HashMap<u32, f32>>,
43 pub n_choices: usize,
44 pub dry_params: Option<DrySamplingParams>,
45}
46
47impl SamplingParams {
48 pub fn deterministic() -> Self {
53 Self {
54 temperature: None,
55 top_k: Some(1),
56 top_p: None,
57 min_p: None,
58 top_n_logprobs: 0,
59 frequency_penalty: None,
60 presence_penalty: None,
61 stop_toks: None,
62 max_len: None,
63 logits_bias: None,
64 n_choices: 1,
65 dry_params: None,
66 }
67 }
68}
69
70#[derive(Clone, Debug, Serialize, Deserialize)]
71pub struct DrySamplingParams {
72 pub sequence_breakers: Vec<String>,
73 pub multiplier: f32,
74 pub base: f32,
75 pub allowed_length: usize,
76}
77
78impl DrySamplingParams {
79 pub fn new_with_defaults(
80 multiplier: f32,
81 sequence_breakers: Option<Vec<String>>,
82 base: Option<f32>,
83 allowed_length: Option<usize>,
84 ) -> anyhow::Result<Self> {
85 Ok(Self {
86 base: base.unwrap_or(1.75),
87 allowed_length: allowed_length.unwrap_or(2),
88 sequence_breakers: sequence_breakers.unwrap_or(DRY_SEQUENCE_BREAKERS.clone()),
89 multiplier,
90 })
91 }
92}
93
94impl Default for DrySamplingParams {
95 fn default() -> Self {
96 Self {
97 multiplier: 0.0,
98 base: 1.75,
99 allowed_length: 2,
100 sequence_breakers: DRY_SEQUENCE_BREAKERS.clone(),
101 }
102 }
103}
104
/// DRY parameters resolved into token space, as used by [`Sampler`].
#[derive(Clone, Debug)]
struct DrySamplingParamsInner {
    /// Token ids (at most one per configured breaker string) that stop repetition matching.
    pub sequence_breakers: HashSet<u32>,
    /// Scale of the penalty; `0.0` disables DRY.
    pub multiplier: f32,
    /// Exponential growth base of the penalty.
    pub base: f32,
    /// Shortest repeated run that is penalized.
    pub allowed_length: usize,
}
112
113impl DrySamplingParamsInner {
114 pub fn from(other: DrySamplingParams, tokenizer: &Tokenizer) -> anyhow::Result<Self> {
115 Ok(Self {
116 base: other.base,
117 allowed_length: other.allowed_length,
118 sequence_breakers: HashSet::from_iter(
119 other
120 .sequence_breakers
121 .into_iter()
122 .map(|breaker| {
123 tokenizer
124 .encode_fast(["a", &breaker].concat(), true)
130 .map_err(anyhow::Error::msg)
131 .map(|enc| {
132 let ids = enc.get_ids();
133 if !ids.is_empty() {
134 None
135 } else {
136 Some(ids[ids.len() - 1])
137 }
138 })
139 })
140 .collect::<anyhow::Result<Vec<_>>>()?
141 .into_iter()
142 .flatten()
143 .collect::<Vec<_>>(),
144 ),
145 multiplier: other.multiplier,
146 })
147 }
148}
149
150pub trait CustomLogitsProcessor: Send + Sync {
170 fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor>;
172}
173
174impl<T: Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync> CustomLogitsProcessor for T {
175 fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor> {
176 self(logits, context)
177 }
178}
179
180#[derive(Clone)]
182pub struct Sampler {
183 temperature: Option<f64>,
184 top_n_logprobs: usize,
185 tokenizer: Option<Arc<Tokenizer>>,
186 frequency_penalty: Option<f32>,
187 presence_penalty: Option<f32>,
188 dry_params: Option<DrySamplingParamsInner>,
189 top_k: i64,
190 top_p: f64,
191 min_p: f64,
192 logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
193}
194
195#[cfg_attr(feature = "pyo3_macros", pyclass)]
196#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
197#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
198pub struct TopLogprob {
200 pub token: u32,
201 pub logprob: f32,
202 pub bytes: Option<String>,
203}
204
205#[derive(Debug, Clone, Serialize, Deserialize)]
206pub struct Logprobs {
207 pub token: u32,
208 pub logprob: f32,
209 pub bytes: Option<String>,
210 pub top_logprobs: Option<Vec<TopLogprob>>,
211}
212
213fn argmax_sample_last_dim(logits: &Tensor) -> Result<Tensor> {
214 logits.argmax(D::Minus1)
215}
216
217impl Sampler {
218 #[allow(clippy::too_many_arguments)]
219 pub fn new(
220 temperature: Option<f64>,
221 top_n_logprobs: usize,
222 tokenizer: Option<Arc<Tokenizer>>,
223 frequency_penalty: Option<f32>,
224 presence_penalty: Option<f32>,
225 dry_params: Option<DrySamplingParams>,
226 top_k: i64,
227 top_p: f64,
228 min_p: f64,
229 logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
230 ) -> anyhow::Result<Self> {
231 let temperature = if temperature.is_none_or(|v| v < 1e-7) {
232 None
233 } else {
234 temperature
235 };
236 let dry_params = if let Some(ref tokenizer) = tokenizer {
237 dry_params.map(|params| DrySamplingParamsInner::from(params, tokenizer))
238 } else {
239 None
240 };
241 let dry_params = match dry_params {
242 Some(fallible) => Some(fallible?),
243 None => None,
244 };
245 Ok(Self {
246 temperature,
247 top_n_logprobs,
248 tokenizer,
249 frequency_penalty,
250 presence_penalty,
251 dry_params,
252 top_k,
253 top_p,
254 min_p,
255 logits_processors,
256 })
257 }
258
259 fn get_top_logprobs(&self, probs: &[f32], argsort_indices: &[u32]) -> Result<Vec<TopLogprob>> {
260 let mut argsort_indices_sorted = argsort_indices.to_vec();
261 argsort_indices_sorted.sort_by(|a, b| {
263 probs[*b as usize]
264 .partial_cmp(&probs[*a as usize])
265 .expect("No ordering.")
266 });
267 let top_n_toks_range = 0..self.top_n_logprobs;
269 let top_n_logprobs = argsort_indices_sorted[top_n_toks_range.clone()]
271 .iter()
272 .map(|x| probs[*x as usize].log(10.0))
273 .collect::<Vec<_>>();
274 let mut top_n_toks = Vec::new();
276 for val in top_n_toks_range {
277 top_n_toks.push(argsort_indices[val]);
278 }
279
280 if let Some(tokenizer) = &self.tokenizer {
281 let mut bytes = Vec::new();
282 for tok in &top_n_toks {
283 bytes.push(
284 tokenizer
285 .decode(&[{ *tok }], false)
286 .map_err(|x| Error::Msg(x.to_string()))?,
287 );
288 }
289
290 Ok(zip(bytes, zip(top_n_toks, top_n_logprobs))
291 .map(|(bytes, (token, logprob))| TopLogprob {
292 token,
293 logprob,
294 bytes: Some(bytes),
295 })
296 .collect::<Vec<_>>())
297 } else {
298 Ok(zip(top_n_toks, top_n_logprobs)
299 .map(|(token, logprob)| TopLogprob {
300 token,
301 logprob,
302 bytes: None,
303 })
304 .collect::<Vec<_>>())
305 }
306 }
307
308 fn sample_argmax(&self, logits: Tensor, return_logprobs: bool) -> Result<Logprobs> {
309 let next_token = logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
310
311 let probs: Vec<f32> = logits.to_vec1()?;
312
313 let argsort_indices = (0..probs.len() as u32).collect::<Vec<_>>();
314 let logprob = probs[next_token as usize].log(10.0);
315
316 let top_logprobs = if return_logprobs {
317 Some(self.get_top_logprobs(&probs, &argsort_indices)?)
318 } else {
319 None
320 };
321
322 let bytes = if let Some(tokenizer) = &self.tokenizer {
323 Some(
324 tokenizer
325 .decode(&[next_token], false)
326 .map_err(|x| Error::Msg(x.to_string()))?,
327 )
328 } else {
329 None
330 };
331
332 Ok(Logprobs {
333 token: next_token,
334 logprob,
335 top_logprobs,
336 bytes,
337 })
338 }
339
340 fn sample_speculative_top_kp_min_p(
341 &self,
342 logits: Tensor,
343 return_logprobs: bool,
344 top_k: i64,
345 top_p: f32,
346 min_p: f32,
347 ) -> Result<Logprobs> {
348 let mut probs: Vec<f32> = logits.to_vec1()?;
349 let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
350
351 if top_k > 0 {
352 for (index, val) in argsort_indices.iter().enumerate() {
354 if index >= top_k as usize {
355 probs[*val as usize] = 0.0;
356 }
357 }
358 }
359
360 let mut cumsum = 0.;
368 for index in &argsort_indices {
369 if cumsum >= top_p {
370 probs[*index as usize] = 0.0;
371 } else {
372 cumsum += probs[*index as usize];
373 }
374 }
375
376 let max_p = probs[argsort_indices[0] as usize];
377
378 for index in &argsort_indices {
385 if max_p * min_p >= probs[*index as usize] {
386 probs[*index as usize] = 0.0;
387 }
388 }
389
390 let logits = Tensor::from_slice(&probs, logits.shape(), &Device::Cpu)?;
391
392 let next_token = argmax_sample_last_dim(&logits)?.to_scalar::<u32>()?;
393
394 let logprob = probs[next_token as usize].log(10.0);
395
396 let top_logprobs = if return_logprobs {
397 Some(self.get_top_logprobs(&probs, &argsort_indices)?)
398 } else {
399 None
400 };
401
402 let bytes = if let Some(tokenizer) = &self.tokenizer {
403 Some(
404 tokenizer
405 .decode(&[next_token], false)
406 .map_err(|x| Error::Msg(x.to_string()))?,
407 )
408 } else {
409 None
410 };
411
412 Ok(Logprobs {
413 token: next_token,
414 logprob,
415 top_logprobs,
416 bytes,
417 })
418 }
419
420 fn sample_multinomial(
421 &self,
422 probs: &mut Vec<f32>,
423 argsort_indices: Vec<u32>,
424 return_logprobs: bool,
425 rng: Arc<Mutex<Isaac64Rng>>,
426 ) -> Result<Logprobs> {
427 let distr = WeightedIndex::new(&*probs).map_err(Error::wrap)?;
428
429 let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
430 let next_token = distr.sample(&mut mut_ref_rng); let logprob = probs[next_token].log(10.0);
432
433 let top_logprobs = if return_logprobs {
434 Some(self.get_top_logprobs(probs, &argsort_indices)?)
435 } else {
436 None
437 };
438
439 let bytes = if let Some(tokenizer) = &self.tokenizer {
440 Some(
441 tokenizer
442 .decode(&[next_token.try_into().unwrap()], false)
443 .map_err(|x| Error::Msg(x.to_string()))?,
444 )
445 } else {
446 None
447 };
448
449 Ok(Logprobs {
450 token: next_token as u32,
451 logprob,
452 top_logprobs,
453 bytes,
454 })
455 }
456
457 #[allow(clippy::too_many_arguments)]
458 fn sample_top_kp_min_p(
459 &self,
460 probs: &mut Vec<f32>,
461 logits: &Tensor,
462 top_k: i64,
463 top_p: f32,
464 min_p: f32,
465 return_logprobs: bool,
466 rng: Arc<Mutex<Isaac64Rng>>,
467 ) -> Result<Logprobs> {
468 let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
469
470 if top_k > 0 {
471 for (index, val) in argsort_indices.iter().enumerate() {
473 if index >= top_k as usize {
474 probs[*val as usize] = 0.0;
475 }
476 }
477 }
478
479 if top_p <= 0.0 || top_p >= 1.0 {
480 return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
481 }
482
483 let mut cumsum = 0.;
491 for index in &argsort_indices {
492 if cumsum >= top_p {
493 probs[*index as usize] = 0.0;
494 } else {
495 cumsum += probs[*index as usize];
496 }
497 }
498
499 if min_p <= 0.0 || min_p >= 1.0 {
500 return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
501 }
502
503 let max_p = probs[argsort_indices[0] as usize];
504
505 for index in &argsort_indices {
512 if max_p * min_p >= probs[*index as usize] {
513 probs[*index as usize] = 0.0;
514 }
515 }
516
517 self.sample_multinomial(probs, argsort_indices, return_logprobs, rng)
519 }
520
521 fn apply_penalties(&self, mut logits: Vec<f32>, context: &[u32]) -> Result<Tensor> {
522 if context.is_empty() {
523 candle_core::bail!("Penalty context is empty, this should not happen.");
524 }
525
526 self.apply_dry_penalty(&mut logits, context)?;
528
529 self.apply_freq_presc_penalty(&mut logits, context)?;
531
532 let vocab_size = logits.len();
533 Tensor::from_vec(logits, vocab_size, &Device::Cpu)
534 }
535
536 fn apply_freq_presc_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
537 if self.frequency_penalty.is_some() || self.presence_penalty.is_some() {
538 let frequency_penalty = self.frequency_penalty.unwrap_or(0.);
539 let presence_penalty = self.presence_penalty.unwrap_or(0.);
540
541 let mut counts = vec![0.0f32; logits.len()];
544 for ctx in context.iter() {
545 if *ctx as usize >= logits.len() {
547 continue;
548 }
549 counts[*ctx as usize] += 1.0;
550 }
551
552 for (token_id, logit) in logits.iter_mut().enumerate() {
553 let count = counts[token_id];
554 *logit = *logit
555 - count * frequency_penalty
556 - if count > 0.0 { 1. } else { 0. } * presence_penalty;
557 }
558 }
559 Ok(())
560 }
561
562 fn apply_dry_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
563 if let Some(ref params) = self.dry_params {
564 if params.multiplier == 0. {
565 return Ok(());
566 }
567
568 let match_indices = context
569 .par_iter()
570 .enumerate()
571 .take(context.len() - 1)
572 .filter(|(_i, x)| *context.last().unwrap() == **x)
573 .map(|(i, _)| i)
574 .collect::<Vec<_>>();
575
576 let mut match_lengths = HashMap::new();
577
578 for i in match_indices {
579 let next_token = context[i + 1];
580
581 if params.sequence_breakers.contains(&next_token) {
582 continue;
583 }
584
585 let mut match_length = 1;
586
587 while match_length < 50 {
589 if match_length > i {
590 break;
592 }
593
594 let j = i - match_length;
595
596 let prev_tok = context[context.len() - (match_length + 1)];
597 if context[j] != prev_tok {
598 break;
600 }
601
602 if params.sequence_breakers.contains(&prev_tok) {
603 break;
605 }
606
607 match_length += 1;
608 }
609
610 #[allow(clippy::map_entry)]
611 if match_lengths.contains_key(&next_token) {
612 match_lengths.insert(next_token, match_length.max(match_lengths[&next_token]));
613 } else {
614 match_lengths.insert(next_token, match_length);
615 }
616 }
617
618 for (tok, match_len) in match_lengths {
620 if match_len >= params.allowed_length {
621 if tok as usize >= logits.len() {
623 continue;
624 }
625 let penalty = params.multiplier
626 * params.base.powf((match_len - params.allowed_length) as f32);
627 logits[tok as usize] -= penalty;
628 }
629 }
630 }
631 Ok(())
632 }
633
634 pub fn sample(
639 &self,
640 logits: Tensor,
641 context: &[u32],
642 return_logprobs: bool,
643 rng: Arc<Mutex<Isaac64Rng>>,
644 sample_speculative: bool,
645 ) -> Result<Logprobs> {
646 let logits = logits.to_vec1()?;
647 let mut logits = self.apply_penalties(logits, context)?;
648 for processor in &self.logits_processors {
649 logits = processor.apply(&logits, context)?;
650 }
651 let next_token = if sample_speculative {
652 match self.temperature {
653 None => self.sample_speculative_top_kp_min_p(
654 logits,
655 return_logprobs,
656 self.top_k,
657 self.top_p as f32,
658 self.min_p as f32,
659 )?,
660 Some(temperature) => {
661 let logits = (&logits / temperature)?;
662 let probs = candle_nn::ops::softmax_last_dim(&logits)?;
663
664 self.sample_speculative_top_kp_min_p(
665 probs,
666 return_logprobs,
667 self.top_k,
668 self.top_p as f32,
669 self.min_p as f32,
670 )?
671 }
672 }
673 } else {
674 match self.temperature {
675 None => self.sample_argmax(logits, return_logprobs)?,
676 Some(temperature) => {
677 let logits = (&logits / temperature)?;
678 let logits = candle_nn::ops::softmax_last_dim(&logits)?;
679 let mut probs: Vec<f32> = logits.to_vec1()?;
680
681 self.sample_top_kp_min_p(
682 &mut probs,
683 &logits,
684 self.top_k,
685 self.top_p as f32,
686 self.min_p as f32,
687 return_logprobs,
688 rng,
689 )?
690 }
691 }
692 };
693 Ok(next_token)
694 }
695}
696
#[cfg(test)]
mod tests {
    use super::{Logprobs, Sampler};
    use candle_core::{Device, Tensor};
    use rand::SeedableRng;
    use rand_isaac::Isaac64Rng;
    use std::sync::{Arc, Mutex};

    /// Runs a temperature-less sampler (top_k 32, top_p 0.1, min_p 0.05) over logits
    /// `0..1024`, with context `0..1024` and no logprobs requested.
    fn run_greedy(sample_speculative: bool) -> Logprobs {
        let sampler =
            Sampler::new(None, 10, None, None, None, None, 32, 0.1, 0.05, vec![]).unwrap();
        let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
        let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
        sampler
            .sample(
                logits,
                &(0..1024).collect::<Vec<_>>(),
                false,
                rng,
                sample_speculative,
            )
            .unwrap()
    }

    #[test]
    fn test_argmax() {
        let res = run_greedy(false);
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }

    #[test]
    fn test_gumbel_speculative() {
        let res = run_greedy(true);
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }
}