mistralrs_core

Trait CustomLogitsProcessor

source
/// Customizable logits processor.
///
/// Any type implementing `apply` can be used. Closures of type
/// `Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync` also implement this trait.
// The `Send + Sync` supertraits let implementors be shared across threads
// (e.g. behind `Arc<dyn CustomLogitsProcessor>`).
pub trait CustomLogitsProcessor: Send + Sync {
    // Required method
    /// Receives the logits and the sequence context (prompt and generated token ids)
    /// and returns the modified logits tensor.
    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor>;
}
Expand description

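Customizable logits processor. Implement this trait on your own type, or pass any closure with the signature `Fn(&Tensor, &[u32]) -> Result<Tensor>` that is also `Send + Sync`, to transform the logits tensor.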

§Example

use std::{sync::Arc, ops::Mul};
use mistralrs_core::CustomLogitsProcessor;
use candle_core::{Result, Tensor};

struct ThresholdLogitsProcessor;
impl CustomLogitsProcessor for ThresholdLogitsProcessor {
    fn apply(&self, logits: &Tensor, _context: &[u32]) -> Result<Tensor> {
        // Mask is 1 for true, 0 for false.
        let mask = logits.ge(0.5)?;
        logits.broadcast_mul(&mask.to_dtype(logits.dtype())?)
    }
}
// A closure with the signature `Fn(&Tensor, &[u32]) -> Result<Tensor>` (and `Send + Sync`)
// implements the trait through the blanket impl, so it can be stored as a trait object
// directly. This one scales every logit by 1.23.
let processor1: Arc<dyn CustomLogitsProcessor> = Arc::new(|logits: &Tensor, _context: &[u32]| logits * 1.23);
// A named type that implements the trait is used the same way.
let processor2: Arc<dyn CustomLogitsProcessor> = Arc::new(ThresholdLogitsProcessor);

Required Methods§

source

fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor>

Receives the logits and the sequence context (prompt and generated tokens) and returns the modified logits.

Implementors§

source§

impl<T: Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync> CustomLogitsProcessor for T