// mistralrs_core/dummy_paged_attention/layers/paged_attention.rs

1use candle_core::{Device, Result, Tensor};
2
3use crate::{
4    attention::SdpaParams,
5    pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
6};
7
/// Stub stand-in for the real `PagedAttention` layer.
///
/// NOTE(review): given the `dummy_paged_attention` module path, this
/// presumably exists so the crate type-checks when paged-attention
/// support is not compiled in — confirm against the crate's feature
/// gating. Every method panics via `unreachable!()`, so no instance of
/// this type is ever meant to exist at runtime.
#[allow(dead_code)]
pub struct PagedAttention;
10
11impl PagedAttention {
12    pub fn new(
13        _head_dim: usize,
14        _device: &Device,
15        _alibi_slopes: Option<Vec<f32>>,
16    ) -> Result<Self> {
17        unreachable!();
18    }
19
20    #[allow(clippy::too_many_arguments)]
21    #[allow(unused_variables)]
22    /// query: shape = [batch_size, seq_len, num_heads * head_size]
23    /// key: shape = [batch_size, seq_len, num_kv_heads * head_size]
24    /// value: shape = [batch_size, num_kv_heads * head_size]
25    /// key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
26    ///     block_size, x]
27    /// value_cache: shape = [num_blocks, num_kv_heads, head_size,
28    ///     block_size]
29    /// input_metadata: metadata for paged attention.
30    pub fn forward(
31        &self,
32        _query: &Tensor,
33        _key: &Tensor,
34        _value: &Tensor,
35        _attention_mask: Option<&Tensor>,
36        _key_cache: Option<Tensor>,
37        _value_cache: Option<Tensor>,
38        _input_metadata: &PagedAttentionInputMetadata,
39        _sdpa_params: &SdpaParams,
40        _flash_params: Option<&FlashParams>,
41    ) -> Result<Tensor> {
42        unreachable!();
43    }
44}