mistralrs_core/
sampler.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{
4    collections::{HashMap, HashSet},
5    iter::zip,
6    sync::{Arc, Mutex},
7};
8
9use candle_core::{Device, Error, Result, Tensor, D};
10#[cfg(feature = "pyo3_macros")]
11use pyo3::pyclass;
12
13use once_cell::sync::Lazy;
14use rand::distr::{weighted::WeightedIndex, Distribution};
15use rand_isaac::Isaac64Rng;
16use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
17use serde::{Deserialize, Serialize};
18use tokenizers::Tokenizer;
19
/// Default DRY sequence breakers: newline, colon, double quote and asterisk.
///
/// Cloned into [`DrySamplingParams`] whenever the caller does not provide its own list.
static DRY_SEQUENCE_BREAKERS: Lazy<Vec<String>> =
    Lazy::new(|| ["\n", ":", "\"", "*"].map(String::from).to_vec());
22
#[derive(Clone, Debug, Serialize, Deserialize)]
/// Stop sequences or ids.
pub enum StopTokens {
    /// Stop conditions expressed as strings.
    Seqs(Vec<String>),
    /// Stop conditions expressed as token ids.
    Ids(Vec<u32>),
}
29
#[derive(Clone, Debug, Serialize, Deserialize)]
/// Sampling params are used to control sampling.
///
/// See [`SamplingParams::deterministic`] for a ready-made preset.
pub struct SamplingParams {
    /// Sampling temperature, if any.
    pub temperature: Option<f64>,
    /// Top-k cutoff, if any.
    pub top_k: Option<usize>,
    /// Top-p (nucleus) cutoff, if any.
    pub top_p: Option<f64>,
    /// Min-p cutoff, if any.
    pub min_p: Option<f64>,
    /// How many top log-probabilities to report.
    pub top_n_logprobs: usize,
    /// Frequency penalty, if any.
    pub frequency_penalty: Option<f32>,
    /// Presence penalty, if any.
    pub presence_penalty: Option<f32>,
    /// Stop sequences or stop token ids, if any.
    pub stop_toks: Option<StopTokens>,
    /// Maximum length, if any (presumably counted in tokens; the consumer is not in this file).
    pub max_len: Option<usize>,
    /// Bias keyed by token id (presumably added to the logits; applied outside this file).
    pub logits_bias: Option<HashMap<u32, f32>>,
    /// Number of choices requested (the consumer is not in this file).
    pub n_choices: usize,
    /// DRY sampling settings, if any.
    pub dry_params: Option<DrySamplingParams>,
}
46
impl SamplingParams {
    /// This sets up the parameters so that there is:
    /// - No temperature, topp, minp
    /// - A topk of 1
    /// - No penalties, stop tokens, logit bias, or DRY parameters
    /// - No maximum length
    /// - No top logprobs and a single choice
    pub fn deterministic() -> Self {
        Self {
            temperature: None,
            top_k: Some(1),
            top_p: None,
            min_p: None,
            top_n_logprobs: 0,
            frequency_penalty: None,
            presence_penalty: None,
            stop_toks: None,
            max_len: None,
            logits_bias: None,
            n_choices: 1,
            dry_params: None,
        }
    }
}
69
#[derive(Clone, Debug, Serialize, Deserialize)]
/// User-facing DRY sampling parameters.
///
/// The sampler converts these into token-id form with the tokenizer before use.
pub struct DrySamplingParams {
    /// Strings whose tokens interrupt a repetition match.
    pub sequence_breakers: Vec<String>,
    /// Scale of the penalty; the sampler skips the DRY penalty entirely when this is `0`.
    pub multiplier: f32,
    /// Base of the exponential growth of the penalty with match length.
    pub base: f32,
    /// Matches shorter than this are not penalized.
    pub allowed_length: usize,
}
77
78impl DrySamplingParams {
79    pub fn new_with_defaults(
80        multiplier: f32,
81        sequence_breakers: Option<Vec<String>>,
82        base: Option<f32>,
83        allowed_length: Option<usize>,
84    ) -> anyhow::Result<Self> {
85        Ok(Self {
86            base: base.unwrap_or(1.75),
87            allowed_length: allowed_length.unwrap_or(2),
88            sequence_breakers: sequence_breakers.unwrap_or(DRY_SEQUENCE_BREAKERS.clone()),
89            multiplier,
90        })
91    }
92}
93
impl Default for DrySamplingParams {
    /// Default DRY parameters: multiplier `0.0`, base `1.75`, allowed length `2`, and a copy of
    /// `DRY_SEQUENCE_BREAKERS` as the sequence breakers.
    fn default() -> Self {
        Self {
            multiplier: 0.0,
            base: 1.75,
            allowed_length: 2,
            sequence_breakers: DRY_SEQUENCE_BREAKERS.clone(),
        }
    }
}
104
/// Tokenized form of [`DrySamplingParams`] used internally by the sampler.
#[derive(Clone, Debug)]
struct DrySamplingParamsInner {
    /// Token ids that interrupt a repetition match.
    pub sequence_breakers: HashSet<u32>,
    /// Scale of the penalty; `0` disables the DRY penalty.
    pub multiplier: f32,
    /// Base of the exponential growth of the penalty with match length.
    pub base: f32,
    /// Matches shorter than this are not penalized.
    pub allowed_length: usize,
}
112
113impl DrySamplingParamsInner {
114    pub fn from(other: DrySamplingParams, tokenizer: &Tokenizer) -> anyhow::Result<Self> {
115        Ok(Self {
116            base: other.base,
117            allowed_length: other.allowed_length,
118            sequence_breakers: HashSet::from_iter(
119                other
120                    .sequence_breakers
121                    .into_iter()
122                    .map(|breaker| {
123                        tokenizer
124                            // Prefix with 'a' to get the correct encoding of the token at the end of a text.
125                            //
126                            // FIXME: This is a hack. See https://github.com/LostRuins/koboldcpp/pull/982
127                            //        for the correct solution which covers multi-token sequence breakers
128                            //        and ambiguous encodings.
129                            .encode_fast(["a", &breaker].concat(), true)
130                            .map_err(anyhow::Error::msg)
131                            .map(|enc| {
132                                let ids = enc.get_ids();
133                                if !ids.is_empty() {
134                                    None
135                                } else {
136                                    Some(ids[ids.len() - 1])
137                                }
138                            })
139                    })
140                    .collect::<anyhow::Result<Vec<_>>>()?
141                    .into_iter()
142                    .flatten()
143                    .collect::<Vec<_>>(),
144            ),
145            multiplier: other.multiplier,
146        })
147    }
148}
149
/// Customizable logits processor.
///
/// Any `Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync` closure implements this trait.
///
/// # Example
/// ```rust
/// use std::{sync::Arc, ops::Mul};
/// use mistralrs_core::CustomLogitsProcessor;
/// use candle_core::{Result, Tensor};
///
/// struct ThresholdLogitsProcessor;
/// impl CustomLogitsProcessor for ThresholdLogitsProcessor {
///     fn apply(&self, logits: &Tensor, _context: &[u32]) -> Result<Tensor> {
///         // Mask is 1 for true, 0 for false.
///         let mask = logits.ge(0.5)?;
///         logits.broadcast_mul(&mask.to_dtype(logits.dtype())?)
///     }
/// }
/// let processor1: Arc<dyn CustomLogitsProcessor> = Arc::new(|logits: &Tensor, _context: &[u32]| logits * 1.23);
/// let processor2: Arc<dyn CustomLogitsProcessor> = Arc::new(ThresholdLogitsProcessor);
/// ```
pub trait CustomLogitsProcessor: Send + Sync {
    /// Takes the logits and the sequence context (prompt and generated tokens) and returns the
    /// modified logits.
    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor>;
}
173
/// Lets any thread-safe closure or function with the matching signature act as a
/// [`CustomLogitsProcessor`].
impl<T: Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync> CustomLogitsProcessor for T {
    /// Forwards to the wrapped callable.
    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor> {
        self(logits, context)
    }
}
179
/// Sampler for sampling.
///
/// Built with [`Sampler::new`] and driven through [`Sampler::sample`].
#[derive(Clone)]
pub struct Sampler {
    /// Sampling temperature; `None` selects the temperature-free paths in `sample`.
    temperature: Option<f64>,
    /// Number of entries returned when top log-probabilities are requested.
    top_n_logprobs: usize,
    /// Used to decode sampled tokens into text and to tokenize DRY sequence breakers.
    tokenizer: Option<Arc<Tokenizer>>,
    /// Subtracted from a logit once per occurrence of its token in the context.
    frequency_penalty: Option<f32>,
    /// Subtracted from a logit once if its token occurs in the context at all.
    presence_penalty: Option<f32>,
    /// Tokenized DRY parameters; `None` disables the DRY penalty.
    dry_params: Option<DrySamplingParamsInner>,
    /// Top-k cutoff; values `<= 0` disable it.
    top_k: i64,
    /// Top-p (nucleus) cutoff.
    top_p: f64,
    /// Min-p cutoff, relative to the largest remaining probability.
    min_p: f64,
    /// Custom processors applied, in order, after the built-in penalties.
    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
}
194
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
/// Top-n logprobs element
pub struct TopLogprob {
    /// Token id.
    pub token: u32,
    /// Base-10 logarithm of the value associated with the token.
    pub logprob: f32,
    /// Decoded text of the token; `None` when the sampler has no tokenizer.
    pub bytes: Option<String>,
}
204
/// Result of sampling one token.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Logprobs {
    /// Sampled token id.
    pub token: u32,
    /// Base-10 logarithm of the value associated with the sampled token.
    pub logprob: f32,
    /// Decoded text of the sampled token; `None` when the sampler has no tokenizer.
    pub bytes: Option<String>,
    /// Top-n entries; `None` unless they were requested.
    pub top_logprobs: Option<Vec<TopLogprob>>,
}
212
/// Returns the argmax of `logits` along its last dimension.
fn argmax_sample_last_dim(logits: &Tensor) -> Result<Tensor> {
    logits.argmax(D::Minus1)
}
216
217impl Sampler {
218    #[allow(clippy::too_many_arguments)]
219    pub fn new(
220        temperature: Option<f64>,
221        top_n_logprobs: usize,
222        tokenizer: Option<Arc<Tokenizer>>,
223        frequency_penalty: Option<f32>,
224        presence_penalty: Option<f32>,
225        dry_params: Option<DrySamplingParams>,
226        top_k: i64,
227        top_p: f64,
228        min_p: f64,
229        logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
230    ) -> anyhow::Result<Self> {
231        let temperature = if temperature.is_none_or(|v| v < 1e-7) {
232            None
233        } else {
234            temperature
235        };
236        let dry_params = if let Some(ref tokenizer) = tokenizer {
237            dry_params.map(|params| DrySamplingParamsInner::from(params, tokenizer))
238        } else {
239            None
240        };
241        let dry_params = match dry_params {
242            Some(fallible) => Some(fallible?),
243            None => None,
244        };
245        Ok(Self {
246            temperature,
247            top_n_logprobs,
248            tokenizer,
249            frequency_penalty,
250            presence_penalty,
251            dry_params,
252            top_k,
253            top_p,
254            min_p,
255            logits_processors,
256        })
257    }
258
259    fn get_top_logprobs(&self, probs: &[f32], argsort_indices: &[u32]) -> Result<Vec<TopLogprob>> {
260        let mut argsort_indices_sorted = argsort_indices.to_vec();
261        // Sort by descending prob
262        argsort_indices_sorted.sort_by(|a, b| {
263            probs[*b as usize]
264                .partial_cmp(&probs[*a as usize])
265                .expect("No ordering.")
266        });
267        // These are where the top n are
268        let top_n_toks_range = 0..self.top_n_logprobs;
269        // The top n's values
270        let top_n_logprobs = argsort_indices_sorted[top_n_toks_range.clone()]
271            .iter()
272            .map(|x| probs[*x as usize].log(10.0))
273            .collect::<Vec<_>>();
274        // Find where they actually are in the logits
275        let mut top_n_toks = Vec::new();
276        for val in top_n_toks_range {
277            top_n_toks.push(argsort_indices[val]);
278        }
279
280        if let Some(tokenizer) = &self.tokenizer {
281            let mut bytes = Vec::new();
282            for tok in &top_n_toks {
283                bytes.push(
284                    tokenizer
285                        .decode(&[{ *tok }], false)
286                        .map_err(|x| Error::Msg(x.to_string()))?,
287                );
288            }
289
290            Ok(zip(bytes, zip(top_n_toks, top_n_logprobs))
291                .map(|(bytes, (token, logprob))| TopLogprob {
292                    token,
293                    logprob,
294                    bytes: Some(bytes),
295                })
296                .collect::<Vec<_>>())
297        } else {
298            Ok(zip(top_n_toks, top_n_logprobs)
299                .map(|(token, logprob)| TopLogprob {
300                    token,
301                    logprob,
302                    bytes: None,
303                })
304                .collect::<Vec<_>>())
305        }
306    }
307
308    fn sample_argmax(&self, logits: Tensor, return_logprobs: bool) -> Result<Logprobs> {
309        let next_token = logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
310
311        let probs: Vec<f32> = logits.to_vec1()?;
312
313        let argsort_indices = (0..probs.len() as u32).collect::<Vec<_>>();
314        let logprob = probs[next_token as usize].log(10.0);
315
316        let top_logprobs = if return_logprobs {
317            Some(self.get_top_logprobs(&probs, &argsort_indices)?)
318        } else {
319            None
320        };
321
322        let bytes = if let Some(tokenizer) = &self.tokenizer {
323            Some(
324                tokenizer
325                    .decode(&[next_token], false)
326                    .map_err(|x| Error::Msg(x.to_string()))?,
327            )
328        } else {
329            None
330        };
331
332        Ok(Logprobs {
333            token: next_token,
334            logprob,
335            top_logprobs,
336            bytes,
337        })
338    }
339
340    fn sample_speculative_top_kp_min_p(
341        &self,
342        logits: Tensor,
343        return_logprobs: bool,
344        top_k: i64,
345        top_p: f32,
346        min_p: f32,
347    ) -> Result<Logprobs> {
348        let mut probs: Vec<f32> = logits.to_vec1()?;
349        let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
350
351        if top_k > 0 {
352            // Clamp smaller probabilities to zero.
353            for (index, val) in argsort_indices.iter().enumerate() {
354                if index >= top_k as usize {
355                    probs[*val as usize] = 0.0;
356                }
357            }
358        }
359
360        // TOP P
361
362        // top-p sampling (or "nucleus sampling") samples from the smallest set of
363        // tokens that exceed probability top_p. This way we never sample tokens that
364        // have very low probabilities and are less likely to go "off the rails".
365
366        // Clamp smaller probabilities to zero.
367        let mut cumsum = 0.;
368        for index in &argsort_indices {
369            if cumsum >= top_p {
370                probs[*index as usize] = 0.0;
371            } else {
372                cumsum += probs[*index as usize];
373            }
374        }
375
376        let max_p = probs[argsort_indices[0] as usize];
377
378        // MIN P
379
380        // min-p sampling samples from the tokens whose prob are greater than
381        // (max prob of token in dist) * min_p
382
383        // Clamp smaller probabilities to zero.
384        for index in &argsort_indices {
385            if max_p * min_p >= probs[*index as usize] {
386                probs[*index as usize] = 0.0;
387            }
388        }
389
390        let logits = Tensor::from_slice(&probs, logits.shape(), &Device::Cpu)?;
391
392        let next_token = argmax_sample_last_dim(&logits)?.to_scalar::<u32>()?;
393
394        let logprob = probs[next_token as usize].log(10.0);
395
396        let top_logprobs = if return_logprobs {
397            Some(self.get_top_logprobs(&probs, &argsort_indices)?)
398        } else {
399            None
400        };
401
402        let bytes = if let Some(tokenizer) = &self.tokenizer {
403            Some(
404                tokenizer
405                    .decode(&[next_token], false)
406                    .map_err(|x| Error::Msg(x.to_string()))?,
407            )
408        } else {
409            None
410        };
411
412        Ok(Logprobs {
413            token: next_token,
414            logprob,
415            top_logprobs,
416            bytes,
417        })
418    }
419
420    fn sample_multinomial(
421        &self,
422        probs: &mut Vec<f32>,
423        argsort_indices: Vec<u32>,
424        return_logprobs: bool,
425        rng: Arc<Mutex<Isaac64Rng>>,
426    ) -> Result<Logprobs> {
427        let distr = WeightedIndex::new(&*probs).map_err(Error::wrap)?;
428
429        let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
430        let next_token = distr.sample(&mut mut_ref_rng); // "Find the first item which has a weight *higher* than the chosen weight."
431        let logprob = probs[next_token].log(10.0);
432
433        let top_logprobs = if return_logprobs {
434            Some(self.get_top_logprobs(probs, &argsort_indices)?)
435        } else {
436            None
437        };
438
439        let bytes = if let Some(tokenizer) = &self.tokenizer {
440            Some(
441                tokenizer
442                    .decode(&[next_token.try_into().unwrap()], false)
443                    .map_err(|x| Error::Msg(x.to_string()))?,
444            )
445        } else {
446            None
447        };
448
449        Ok(Logprobs {
450            token: next_token as u32,
451            logprob,
452            top_logprobs,
453            bytes,
454        })
455    }
456
457    #[allow(clippy::too_many_arguments)]
458    fn sample_top_kp_min_p(
459        &self,
460        probs: &mut Vec<f32>,
461        logits: &Tensor,
462        top_k: i64,
463        top_p: f32,
464        min_p: f32,
465        return_logprobs: bool,
466        rng: Arc<Mutex<Isaac64Rng>>,
467    ) -> Result<Logprobs> {
468        let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
469
470        if top_k > 0 {
471            // Clamp smaller probabilities to zero.
472            for (index, val) in argsort_indices.iter().enumerate() {
473                if index >= top_k as usize {
474                    probs[*val as usize] = 0.0;
475                }
476            }
477        }
478
479        if top_p <= 0.0 || top_p >= 1.0 {
480            return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
481        }
482
483        // TOP P
484
485        // top-p sampling (or "nucleus sampling") samples from the smallest set of
486        // tokens that exceed probability top_p. This way we never sample tokens that
487        // have very low probabilities and are less likely to go "off the rails".
488
489        // Clamp smaller probabilities to zero.
490        let mut cumsum = 0.;
491        for index in &argsort_indices {
492            if cumsum >= top_p {
493                probs[*index as usize] = 0.0;
494            } else {
495                cumsum += probs[*index as usize];
496            }
497        }
498
499        if min_p <= 0.0 || min_p >= 1.0 {
500            return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
501        }
502
503        let max_p = probs[argsort_indices[0] as usize];
504
505        // MIN P
506
507        // min-p sampling samples from the tokens whose prob are greater than
508        // (max prob of token in dist) * min_p
509
510        // Clamp smaller probabilities to zero.
511        for index in &argsort_indices {
512            if max_p * min_p >= probs[*index as usize] {
513                probs[*index as usize] = 0.0;
514            }
515        }
516
517        // Sample with clamped probabilities.
518        self.sample_multinomial(probs, argsort_indices, return_logprobs, rng)
519    }
520
    /// Applies the DRY penalty and then the frequency/presence penalties to `logits`, returning
    /// the result as a 1-D CPU tensor with one entry per logit.
    ///
    /// # Errors
    /// Fails if `context` is empty, or if building the tensor fails.
    fn apply_penalties(&self, mut logits: Vec<f32>, context: &[u32]) -> Result<Tensor> {
        // The DRY penalty reads the last context token, so an empty context is rejected up front.
        if context.is_empty() {
            candle_core::bail!("Penalty context is empty, this should not happen.");
        }

        // Dry penalty
        self.apply_dry_penalty(&mut logits, context)?;

        // Frequency and Presence penalty
        self.apply_freq_presc_penalty(&mut logits, context)?;

        let vocab_size = logits.len();
        Tensor::from_vec(logits, vocab_size, &Device::Cpu)
    }
535
536    fn apply_freq_presc_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
537        if self.frequency_penalty.is_some() || self.presence_penalty.is_some() {
538            let frequency_penalty = self.frequency_penalty.unwrap_or(0.);
539            let presence_penalty = self.presence_penalty.unwrap_or(0.);
540
541            //mu[j] -> mu[j] - c[j] * alpha_frequency - float(c[j] > 0) * alpha_presence
542
543            let mut counts = vec![0.0f32; logits.len()];
544            for ctx in context.iter() {
545                // Llama 3.2 uses a hack triggering this error... we wouldn't want a weight on it anyway
546                if *ctx as usize >= logits.len() {
547                    continue;
548                }
549                counts[*ctx as usize] += 1.0;
550            }
551
552            for (token_id, logit) in logits.iter_mut().enumerate() {
553                let count = counts[token_id];
554                *logit = *logit
555                    - count * frequency_penalty
556                    - if count > 0.0 { 1. } else { 0. } * presence_penalty;
557            }
558        }
559        Ok(())
560    }
561
    /// Applies the DRY repetition penalty to `logits` in place.
    ///
    /// For every earlier occurrence of the last context token, the token that followed it is a
    /// candidate for penalization. The match length counts how many tokens ending at that
    /// occurrence agree with the tokens ending the context (capped at 50, and stopped by a
    /// sequence breaker). Each candidate whose longest match is at least `allowed_length` has
    /// `multiplier * base^(match_len - allowed_length)` subtracted from its logit.
    ///
    /// Does nothing when no DRY parameters are set or the multiplier is `0`. `context` must be
    /// non-empty; `apply_penalties` checks this before calling.
    fn apply_dry_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
        if let Some(ref params) = self.dry_params {
            if params.multiplier == 0. {
                return Ok(());
            }

            // Positions, excluding the final one, that hold the same token as the end of the
            // context.
            let match_indices = context
                .par_iter()
                .enumerate()
                .take(context.len() - 1)
                .filter(|(_i, x)| *context.last().unwrap() == **x)
                .map(|(i, _)| i)
                .collect::<Vec<_>>();

            // Longest match length seen for each candidate next token.
            let mut match_lengths = HashMap::new();

            for i in match_indices {
                // The token that followed this earlier occurrence.
                let next_token = context[i + 1];

                if params.sequence_breakers.contains(&next_token) {
                    continue;
                }

                // The occurrence itself already matches the last context token.
                let mut match_length = 1;

                // Limit match length to avoid quadratic runtime and potential DoS with adversarial inputs.
                while match_length < 50 {
                    if match_length > i {
                        // Start of input
                        break;
                    }

                    // Walk backwards in lockstep from the occurrence and from the end of the context.
                    let j = i - match_length;

                    let prev_tok = context[context.len() - (match_length + 1)];
                    if context[j] != prev_tok {
                        // Start of match reached
                        break;
                    }

                    if params.sequence_breakers.contains(&prev_tok) {
                        // Seq breaking tok reached
                        break;
                    }

                    match_length += 1;
                }

                // Keep the maximum match length per candidate token.
                #[allow(clippy::map_entry)]
                if match_lengths.contains_key(&next_token) {
                    match_lengths.insert(next_token, match_length.max(match_lengths[&next_token]));
                } else {
                    match_lengths.insert(next_token, match_length);
                }
            }

            // Actually apply penalties
            for (tok, match_len) in match_lengths {
                if match_len >= params.allowed_length {
                    // Llama 3.2 uses a hack triggering this error... we wouldn't want a weight on it anyway
                    if tok as usize >= logits.len() {
                        continue;
                    }
                    // The subtraction cannot underflow because of the `>=` check above.
                    let penalty = params.multiplier
                        * params.base.powf((match_len - params.allowed_length) as f32);
                    logits[tok as usize] -= penalty;
                }
            }
        }
        Ok(())
    }
633
634    /// Sample the provided tokens.
635    ///
636    /// If the temperature is `None`, argmax sampling is used. Otherwise, the selected sampling is used.
637    /// With `top-p` sampling, if the `top-p` value is `<= 0.0` or `>= 1.0`, multinomial sampling is used.
638    pub fn sample(
639        &self,
640        logits: Tensor,
641        context: &[u32],
642        return_logprobs: bool,
643        rng: Arc<Mutex<Isaac64Rng>>,
644        sample_speculative: bool,
645    ) -> Result<Logprobs> {
646        let logits = logits.to_vec1()?;
647        let mut logits = self.apply_penalties(logits, context)?;
648        for processor in &self.logits_processors {
649            logits = processor.apply(&logits, context)?;
650        }
651        let next_token = if sample_speculative {
652            match self.temperature {
653                None => self.sample_speculative_top_kp_min_p(
654                    logits,
655                    return_logprobs,
656                    self.top_k,
657                    self.top_p as f32,
658                    self.min_p as f32,
659                )?,
660                Some(temperature) => {
661                    let logits = (&logits / temperature)?;
662                    let probs = candle_nn::ops::softmax_last_dim(&logits)?;
663
664                    self.sample_speculative_top_kp_min_p(
665                        probs,
666                        return_logprobs,
667                        self.top_k,
668                        self.top_p as f32,
669                        self.min_p as f32,
670                    )?
671                }
672            }
673        } else {
674            match self.temperature {
675                None => self.sample_argmax(logits, return_logprobs)?,
676                Some(temperature) => {
677                    let logits = (&logits / temperature)?;
678                    let logits = candle_nn::ops::softmax_last_dim(&logits)?;
679                    let mut probs: Vec<f32> = logits.to_vec1()?;
680
681                    self.sample_top_kp_min_p(
682                        &mut probs,
683                        &logits,
684                        self.top_k,
685                        self.top_p as f32,
686                        self.min_p as f32,
687                        return_logprobs,
688                        rng,
689                    )?
690                }
691            }
692        };
693        Ok(next_token)
694    }
695}
696
#[cfg(test)]
mod tests {
    use super::{Logprobs, Sampler};
    use candle_core::{Device, Tensor};
    use rand::SeedableRng;
    use rand_isaac::Isaac64Rng;
    use std::sync::{Arc, Mutex};

    /// Samples once, without a temperature, from the logits `0, 1, ..., 1023` using a context
    /// that contains every token id once and a fixed rng seed.
    fn sample_ramp(sample_speculative: bool) -> Logprobs {
        let sampler =
            Sampler::new(None, 10, None, None, None, None, 32, 0.1, 0.05, vec![]).unwrap();
        let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
        let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
        sampler
            .sample(
                logits,
                &(0..1024).collect::<Vec<_>>(),
                false,
                rng,
                sample_speculative,
            )
            .unwrap()
    }

    /// Without a temperature, the regular path must pick the largest logit.
    #[test]
    fn test_argmax() {
        let res = sample_ramp(false);
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }

    /// Without a temperature, the speculative path must also pick the largest logit.
    #[test]
    fn test_gumbel_speculative() {
        let res = sample_ramp(true);
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }
}
739}