// mistralrs_core/dummy_paged_attention/layers/paged_attention.rs
use candle_core::{Device, Result, Tensor};
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
const _PARTITION_SIZE: usize = 512;
/// Configuration holder for a paged-attention layer.
///
/// NOTE(review): this lives under `dummy_paged_attention`, so it appears to be
/// the stub variant used when the real paged-attention backend is unavailable;
/// `forward` is never expected to run — confirm against the real module.
#[allow(dead_code)]
pub struct PagedAttention {
    // Number of query (attention) heads.
    num_attention_heads: usize,
    // Dimension of a single attention head.
    head_dim: usize,
    // Number of key/value heads; equals `num_attention_heads` unless GQA/MQA is used.
    num_key_value_heads: usize,
    // Multiplier applied to attention scores (typically 1/sqrt(head_dim) — not verified here).
    scale: f32,
    // Optional sliding-window size limiting attention span; `None` means full attention.
    sliding_window: Option<usize>,
    // Query heads served per KV head: num_attention_heads / num_key_value_heads.
    num_queries_per_kv: usize,
    // Optional per-head ALiBi slopes, materialized as a tensor on the target device.
    alibi_slopes: Option<Tensor>,
}
impl PagedAttention {
    /// Builds a new `PagedAttention` configuration.
    ///
    /// `num_key_value_heads` defaults to `num_attention_heads` when `None`
    /// (plain multi-head attention rather than GQA/MQA). When `alibi_slopes`
    /// is provided, the slopes are uploaded to `device` as a tensor.
    ///
    /// # Panics
    /// Panics if `alibi_slopes` is `Some` and its length differs from
    /// `num_attention_heads` — ALiBi defines exactly one slope per attention
    /// head. (The previous check compared against `head_dim`, which is the
    /// wrong dimension.)
    ///
    /// # Errors
    /// Returns an error if the ALiBi slope tensor cannot be created on
    /// `device`.
    pub fn new(
        num_attention_heads: usize,
        head_dim: usize,
        scale: f32,
        num_key_value_heads: Option<usize>,
        sliding_window: Option<usize>,
        device: &Device,
        alibi_slopes: Option<Vec<f32>>,
    ) -> Result<Self> {
        let num_key_value_heads = num_key_value_heads.unwrap_or(num_attention_heads);
        // For GQA/MQA the query-head count must be a multiple of the KV-head
        // count; otherwise this integer division silently truncates.
        debug_assert_eq!(num_attention_heads % num_key_value_heads, 0);
        let num_queries_per_kv = num_attention_heads / num_key_value_heads;
        let alibi_slopes = match alibi_slopes {
            Some(slopes) => {
                // One ALiBi slope per attention head.
                assert_eq!(slopes.len(), num_attention_heads);
                Some(Tensor::new(slopes, device)?)
            }
            None => None,
        };
        Ok(Self {
            num_attention_heads,
            head_dim,
            num_key_value_heads,
            scale,
            sliding_window,
            num_queries_per_kv,
            alibi_slopes,
        })
    }

    /// Stub forward pass for the dummy paged-attention backend.
    ///
    /// This module is the placeholder compiled when the real paged-attention
    /// implementation is unavailable, so this method must never be invoked;
    /// reaching it indicates a dispatch bug elsewhere, hence `unreachable!()`.
    #[allow(clippy::too_many_arguments)]
    #[allow(unused_variables)]
    pub fn forward(
        &self,
        _query: &Tensor,
        _key: &Tensor,
        _value: &Tensor,
        _attention_mask: Option<&Tensor>,
        _key_cache: Option<Tensor>,
        _value_cache: Option<Tensor>,
        _input_metadata: &mut PagedAttentionInputMetadata,
        _softcapping: Option<f64>,
    ) -> Result<Tensor> {
        unreachable!();
    }
}