/// Customizable logits processor.
///
/// Implementors transform a logits tensor through [`apply`](Self::apply).
/// The trait is object safe and requires `Send + Sync`, so processors can be
/// stored and shared as `Arc<dyn CustomLogitsProcessor>`.
pub trait CustomLogitsProcessor: Sync + Send {
    /// Applies this processor to `logits` and returns the resulting tensor.
    ///
    /// `context` is a slice of `u32` values supplied by the caller
    /// (presumably the token ids seen so far — TODO confirm against the caller).
    ///
    /// # Errors
    ///
    /// Returns an `Error` when the implementation's tensor computation fails.
    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor, Error>;
}
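Customizable logits processor: implement `apply` to transform a logits tensor given a `&[u32]` context, or use a closure with that same argument list (see the example below).
Example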
use std::{sync::Arc, ops::Mul};
use mistralrs_core::CustomLogitsProcessor;
use candle_core::{Result, Tensor};
/// Example processor: multiplies every logit below `0.5` by zero and leaves
/// the logits that are `>= 0.5` unchanged.
struct ThresholdLogitsProcessor;

impl CustomLogitsProcessor for ThresholdLogitsProcessor {
    fn apply(&self, logits: &Tensor, _context: &[u32]) -> Result<Tensor> {
        // `ge` produces a mask holding 1 where the comparison is true and 0 where it is false.
        let keep = logits.ge(0.5)?;
        // Cast the mask to the logits' dtype so the two tensors can be multiplied.
        let keep = keep.to_dtype(logits.dtype())?;
        logits.broadcast_mul(&keep)
    }
}
// A closure taking the same arguments as `apply` can serve as a processor;
// this one scales every logit by 1.23.
let scale_by_constant = |logits: &Tensor, _context: &[u32]| logits * 1.23;
let processor1: Arc<dyn CustomLogitsProcessor> = Arc::new(scale_by_constant);
// A named type implementing the trait is wrapped the same way.
let processor2: Arc<dyn CustomLogitsProcessor> = Arc::new(ThresholdLogitsProcessor);