mistralrs_core/attention/mod.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::pipeline::text_models_inputs_processor::FlashParams;

use candle_core::{Device, Result, Tensor};

mod backends;

#[allow(unused)]
pub(crate) use backends::{flash_attn, maybe_synchronize, naive_sdpa};

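/// Repeat each KV head `n_rep` times so grouped-query K/V tensors match the
/// number of query heads: `(b_sz, n_kv_heads, seq_len, head_dim)` becomes
/// `(b_sz, n_kv_heads * n_rep, seq_len, head_dim)`.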
fn repeat_kv(x: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        Ok(x)
    } else {
        let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
        Tensor::cat(&vec![&x; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
    }
}

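/// Parameters controlling [`Sdpa::run_attention`]: the number of query heads per
/// KV head (`n_kv_groups`), an optional logit softcap, the softmax scale
/// (usually `1/sqrt(head_dim)`), and an optional sliding-window size.
///
/// A hypothetical configuration for a model with 32 query heads, 8 KV heads,
/// and `head_dim = 128` might look like:
///
/// ```ignore
/// let params = SdpaParams {
///     n_kv_groups: 32 / 8,                  // query heads per KV head
///     softcap: None,                        // no logit soft-capping
///     softmax_scale: 1.0 / (128f32).sqrt(), // 1/sqrt(head_dim)
///     sliding_window: None,                 // full attention
/// };
/// ```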
pub struct SdpaParams {
    pub n_kv_groups: usize,
    pub softcap: Option<f32>,
    pub softmax_scale: f32,
    pub sliding_window: Option<usize>,
}

pub struct Sdpa;

impl Sdpa {
    #[allow(unused_variables, clippy::too_many_arguments)]
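    /// Computes scaled dot-product attention, `softmax(Q·Kᵀ · softmax_scale)·V`,
    /// with optional masking, logit soft-capping, and grouped-query attention
    /// (K/V heads are expanded by `n_kv_groups` to match the query heads).
    ///
    /// Expected shapes:
    /// - `q`: `(b_sz, n_attn_heads, seq_len, head_dim)`
    /// - `k`/`v`: `(b_sz, n_kv_heads, kv_seq_len, head_dim)`
    ///
    /// The implementation is dispatched in order of preference: flash attention
    /// (CUDA, when enabled), the fused Metal SDPA kernel, cuBLASLt batched GEMMs
    /// (CUDA), and finally the naive SDPA fallback.
    ///
    /// A minimal usage sketch (tensor construction omitted):
    ///
    /// ```ignore
    /// let out = Sdpa.run_attention(&q, &k, &v, mask.as_ref(), None, &sdpa_params)?;
    /// ```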
    pub fn run_attention(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        mask: Option<&Tensor>,
        flash_params: Option<&FlashParams>,
        sdpa_params: &SdpaParams,
    ) -> Result<Tensor> {
        let (b_sz, n_attn_heads, seq_len, head_dim) = q.dims4()?;
        let (_, _, _, k_head_dim) = k.dims4()?;
        let (_, _, _, v_head_dim) = v.dims4()?;
        if crate::using_flash_attn() && q.device().is_cuda() {
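            // The flash attention backend takes (b_sz, seq_len, n_heads, head_dim),
            // so transpose into that layout and back again afterwards.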
            let q = q.transpose(1, 2)?;
            let k = k.transpose(1, 2)?;
            let v = v.transpose(1, 2)?;
            return flash_attn(&q, &k, &v, flash_params, sdpa_params)?.transpose(1, 2);
        }

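        // The fused Metal SDPA kernel needs matching Q/K/V head dims, a supported
        // head_dim, and either no mask or one that broadcasts to
        // (b_sz, n_attn_heads, seq_len, kv_seq_len) without an active softcap.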
        let all_head_dims_match = head_dim == k_head_dim && k_head_dim == v_head_dim;
        let tgt_mask_shape = vec![b_sz, n_attn_heads, seq_len, k.dim(2)?];
        let can_use_mask = mask.is_none_or(|mask| {
            mask.layout().broadcast_as(tgt_mask_shape.clone()).is_ok()
                && sdpa_params.softcap.is_none_or(|x| x == 1.0)
        });
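        // Head dims accepted by the fused kernel; 256 is only allowed on the
        // single-token decode path (seq_len == 1).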
        let valid_head_dims: &[usize] = if seq_len == 1 {
            &[32, 64, 72, 80, 96, 128, 256]
        } else {
            &[32, 64, 72, 80, 96, 128]
        };
        if [q, k, v].into_iter().all(|x| x.device().is_metal())
            && all_head_dims_match
            && valid_head_dims.contains(&head_dim)
            && can_use_mask
        {
            let mask = match mask {
                Some(mask) => Some(mask.broadcast_as(tgt_mask_shape)?),
                None => None,
            };
            return candle_nn::ops::sdpa(
                q,
                k,
                v,
                mask.as_ref(),
                false,
                sdpa_params.softmax_scale,
                sdpa_params.softcap.unwrap_or(1.0),
            );
        }

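        // The matmul-based paths below operate on full attention heads, so expand
        // the grouped K/V heads to match the number of query heads.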
        let k = repeat_kv(k.clone(), sdpa_params.n_kv_groups)?;
        let v = repeat_kv(v.clone(), sdpa_params.n_kv_groups)?;

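        // Rank-2 masks and NCCL-distributed runs go straight to the naive SDPA path.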
        if mask.is_some_and(|x| x.rank() == 2) || mistralrs_quant::distributed::use_nccl() {
            return naive_sdpa(q, &k, &v, mask, sdpa_params);
        }

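        // On CUDA with cuBLASLt available, compute attention as two batched GEMMs,
        // fusing the softmax scale and mask add into the first one.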
        #[allow(unused)]
        if let (Device::Cuda(_), Some(cublaslt)) = (
            q.device(),
            mistralrs_quant::cublaslt::CUBLASLT_CONTROLLER.get(),
        ) {
            #[cfg(feature = "cuda")]
            {
                maybe_synchronize(q.device())?;

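                // cuBLASLt batched matmul operates on rank-3 tensors, so merge the
                // batch and head dimensions.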
                let k = k.flatten(0, 1)?;
                let q = q.flatten(0, 1)?;
                let v = v.flatten(0, 1)?;
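                // Normalize the mask into a rank-3 bias that lines up with the
                // flattened attention scores.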
                let attention_bias = match mask {
                    Some(mask) if mask.rank() == 3 && mask.dims()[0] == 1 => {
                        Some(mask.repeat((n_attn_heads, 1, 1))?)
                    }
                    Some(mask) if mask.rank() == 3 => Some(mask.clone()),
                    Some(mask) if mask.rank() == 4 => {
                        Some(mask.broadcast_as(tgt_mask_shape)?.flatten(0, 1)?)
                    }
                    Some(mask) => {
                        candle_core::bail!("cublaslt attn mask: rank must be 3 or 4")
                    }
                    None => None,
                };

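                // When an attention bias is present, beta = 1.0 so the fused matmul
                // adds it to the scaled scores.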
                let beta = match attention_bias.is_some() {
                    true => Some(1.0),
                    false => None,
                };

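                // First batched GEMM: scores = Q·Kᵀ scaled by softmax_scale / softcap,
                // with the bias (if any) fused in. Dividing the scale by the softcap
                // here lets `softcap * tanh(...)` below reproduce logit soft-capping.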
                let mut attention_scores = cublaslt.batch_matmul(
                    &k,
                    &q,
                    attention_bias.as_ref(),
                    Some(sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0)),
                    beta,
                    None,
                    None,
                )?;
                if let Some(softcap) = sdpa_params.softcap {
                    attention_scores = (attention_scores.tanh()? * softcap as f64)?;
                }
                candle_nn::ops::inplace_softmax_last_dim(&mut attention_scores)?;

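                // Second batched GEMM: multiply the softmaxed scores by V to get the
                // context layer, still in the flattened (batch * heads) layout.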
                let context_layer = cublaslt.batch_matmul(
                    &v.t()?.contiguous()?,
                    &attention_scores,
                    Some(&q),
                    None,
                    None,
                    None,
                    None,
                )?;

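                // Restore the (b_sz, n_attn_heads, seq_len, v_head_dim) layout.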
                context_layer.reshape((b_sz, n_attn_heads, seq_len, v_head_dim))
            }
            #[cfg(not(feature = "cuda"))]
            {
                candle_core::bail!("`cuda` feature is not enabled")
            }
        } else {
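            // No cuBLASLt handle (or not a CUDA device): fall back to naive SDPA.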
            naive_sdpa(q, &k, &v, mask, sdpa_params)
        }
    }
}