mistralrs_core/attention/mod.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::pipeline::text_models_inputs_processor::FlashParams;

use candle_core::{Device, Result, Tensor};

mod backends;

#[allow(unused)]
pub(crate) use backends::{flash_attn, maybe_synchronize, naive_sdpa};

pub(crate) const ATTENTION_CHUNK_SIZE: usize = 1024;

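/// Run `attention_fn` over query chunks of at most `ATTENTION_CHUNK_SIZE` tokens
/// along the sequence dimension (dim 2) and concatenate the results. Keys and
/// values are passed through unchanged, so each chunk still attends over the full
/// KV sequence, and each call sees at most `ATTENTION_CHUNK_SIZE` query rows.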
pub(crate) fn chunked_attention<F>(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    mask: Option<&Tensor>,
    attention_fn: F,
) -> Result<Tensor>
where
    F: Fn(&Tensor, &Tensor, &Tensor, Option<&Tensor>) -> Result<Tensor>,
{
    let seq_len = q.dim(2)?;

    if seq_len <= ATTENTION_CHUNK_SIZE {
        return attention_fn(q, k, v, mask);
    }

    let num_chunks = seq_len.div_ceil(ATTENTION_CHUNK_SIZE);
    let mut attn_chunks = Vec::with_capacity(num_chunks);

    for chunk_idx in 0..num_chunks {
        let offset = chunk_idx * ATTENTION_CHUNK_SIZE;
        let chunk_len = ATTENTION_CHUNK_SIZE.min(seq_len - offset);

        let q_chunk = q.narrow(2, offset, chunk_len)?;

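        // Narrow the mask along its query-length dimension, whose position depends
        // on the mask rank: dim 0 for rank 2, dim 1 for rank 3, and dim 2 for rank 4
        // (higher ranks also fall back to dim 2).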
        let mask_chunk = mask
            .map(|m| match m.rank() {
                2 => m.narrow(0, offset, chunk_len),
                3 => m.narrow(1, offset, chunk_len),
                4 => m.narrow(2, offset, chunk_len),
                _ => m.narrow(2, offset, chunk_len),
            })
            .transpose()?;

        let att_chunk = attention_fn(&q_chunk, k, v, mask_chunk.as_ref())?;

        attn_chunks.push(att_chunk);
    }

    Tensor::cat(&attn_chunks, 2)
}

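/// Repeat each KV head `n_rep` times so grouped-query attention inputs of shape
/// `(b, n_kv_heads, seq, head_dim)` become `(b, n_kv_heads * n_rep, seq, head_dim)`
/// and line up with the query heads.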
fn repeat_kv(x: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        Ok(x)
    } else {
        let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
        Tensor::cat(&vec![&x; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
    }
}

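/// Parameters for scaled dot-product attention: the number of query heads per KV
/// head (`n_kv_groups`), an optional logit softcap, the softmax scale, and an
/// optional sliding-window size.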
pub struct SdpaParams {
    pub n_kv_groups: usize,
    pub softcap: Option<f32>,
    pub softmax_scale: f32,
    pub sliding_window: Option<usize>,
}

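/// Zero-sized dispatcher for scaled dot-product attention; see [`Sdpa::run_attention`].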
pub struct Sdpa;

impl Sdpa {
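    /// Computes scaled dot-product attention for `(b, n_heads, seq, head_dim)`
    /// inputs, dispatching to flash attention (CUDA, when enabled),
    /// `candle_nn::ops::sdpa` on Metal for supported head dims, a cuBLASLt
    /// batched-matmul path on CUDA, or `naive_sdpa` as the fallback.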
    #[allow(unused_variables, clippy::too_many_arguments)]
    pub fn run_attention(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        mask: Option<&Tensor>,
        flash_params: Option<&FlashParams>,
        sdpa_params: &SdpaParams,
    ) -> Result<Tensor> {
        let (b_sz, n_attn_heads, seq_len, head_dim) = q.dims4()?;
        let (_, _, _, k_head_dim) = k.dims4()?;
        let (_, _, _, v_head_dim) = v.dims4()?;
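        // Flash attention path (CUDA only): the backend takes tensors laid out as
        // (b, seq_len, n_heads, head_dim), so transpose into that layout and back
        // out around the call.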
        if crate::using_flash_attn() && q.device().is_cuda() {
            let q = q.transpose(1, 2)?;
            let k = k.transpose(1, 2)?;
            let v = v.transpose(1, 2)?;
            return flash_attn(&q, &k, &v, flash_params, sdpa_params)?.transpose(1, 2);
        }

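        // Metal fast path: candle's fused SDPA kernel is used only when all inputs
        // live on Metal, q/k/v share a supported head_dim, and the mask (if present)
        // broadcasts to the full (b, n_heads, seq, kv_seq) shape without a softcap.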
        let all_head_dims_match = head_dim == k_head_dim && k_head_dim == v_head_dim;
        let tgt_mask_shape = vec![b_sz, n_attn_heads, seq_len, k.dim(2)?];
        let can_use_mask = mask.is_none_or(|mask| {
            mask.layout().broadcast_as(tgt_mask_shape.clone()).is_ok()
                && sdpa_params.softcap.is_none_or(|x| x == 1.0)
        });
        let valid_head_dims: &[usize] = if seq_len == 1 {
            &[32, 64, 72, 80, 96, 128, 256]
        } else {
            &[32, 64, 72, 80, 96, 128]
        };
        if [q, k, v].into_iter().all(|x| x.device().is_metal())
            && all_head_dims_match
            && valid_head_dims.contains(&head_dim)
            && can_use_mask
        {
            let mask = match mask {
                Some(mask) => Some(mask.broadcast_as(tgt_mask_shape)?),
                None => None,
            };
            return candle_nn::ops::sdpa(
                q,
                k,
                v,
                mask.as_ref(),
                false,
                sdpa_params.softmax_scale,
                sdpa_params.softcap.unwrap_or(1.0),
            );
        }

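        // Expand the KV heads to match the query heads (grouped-query attention)
        // before any of the fallback backends below.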
        let k = repeat_kv(k.clone(), sdpa_params.n_kv_groups)?;
        let v = repeat_kv(v.clone(), sdpa_params.n_kv_groups)?;

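        // Rank-2 masks and NCCL-distributed runs go through the naive SDPA
        // implementation on contiguous tensors.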
        if mask.is_some_and(|x| x.rank() == 2) || mistralrs_quant::distributed::use_nccl() {
            return naive_sdpa(
                &q.contiguous()?,
                &k.contiguous()?,
                &v.contiguous()?,
                mask,
                sdpa_params,
            );
        }

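        // On CUDA with an available cuBLASLt controller, run attention as two
        // batched matmuls over query chunks; otherwise fall back to naive SDPA.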
        #[allow(unused)]
        if let (Device::Cuda(_), Some(cublaslt)) = (
            q.device(),
            mistralrs_quant::cublaslt::CUBLASLT_CONTROLLER.get(),
        ) {
            #[cfg(feature = "cuda")]
            {
                maybe_synchronize(q.device())?;

                let k_flat = k.flatten(0, 1)?;
                let v_flat = v.flatten(0, 1)?;

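                // k_flat/v_flat are captured once; only the queries (and mask) are
                // chunked along the sequence dimension inside the closure.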
                chunked_attention(q, &k, &v, mask, |q_chunk, _k, _v, mask_chunk| {
                    let (chunk_b_sz, chunk_n_heads, chunk_seq_len, chunk_head_dim) =
                        q_chunk.dims4()?;
                    let q_flat = q_chunk.flatten(0, 1)?;

                    let attention_bias = match mask_chunk {
                        Some(mask) if mask.rank() == 3 && mask.dims()[0] == 1 => {
                            Some(mask.repeat((chunk_n_heads, 1, 1))?)
                        }
                        Some(mask) if mask.rank() == 3 => Some(mask.clone()),
                        Some(mask) if mask.rank() == 4 => {
                            let tgt_shape =
                                vec![chunk_b_sz, chunk_n_heads, chunk_seq_len, k.dim(2)?];
                            Some(mask.broadcast_as(tgt_shape)?.flatten(0, 1)?)
                        }
                        Some(mask) => {
                            candle_core::bail!("cublaslt attn mask: rank must be 3 or 4")
                        }
                        None => None,
                    };

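                    // beta = 1.0 so the bias (the mask chunk) is accumulated into the
                    // score matmul; without a bias there is nothing to accumulate.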
                    let beta = match attention_bias.is_some() {
                        true => Some(1.0),
                        false => None,
                    };

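                    // First batched matmul: the scaled attention scores. The scale is
                    // pre-divided by the softcap so the tanh rescaling below restores it.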
                    let mut attention_scores = cublaslt.batch_matmul(
                        &k_flat,
                        &q_flat,
                        attention_bias.as_ref(),
                        Some(sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0)),
                        beta,
                        None,
                        None,
                    )?;
                    if let Some(softcap) = sdpa_params.softcap {
                        attention_scores = (attention_scores.tanh()? * softcap as f64)?;
                    }
                    attention_scores = candle_nn::ops::softmax_last_dim(&attention_scores)?;

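                    // Second batched matmul: attention probabilities times V, then the
                    // flattened (b * h, seq, head_dim) result is reshaped back to 4D.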
                    let context_layer = cublaslt.batch_matmul(
                        &v_flat.t()?.contiguous()?,
                        &attention_scores,
                        Some(&q_flat),
                        None,
                        None,
                        None,
                        None,
                    )?;

                    context_layer.reshape((chunk_b_sz, chunk_n_heads, chunk_seq_len, v_head_dim))
                })
            }
            #[cfg(not(feature = "cuda"))]
            {
                candle_core::bail!("`cuda` feature is not enabled")
            }
        } else {
            naive_sdpa(q, &k, &v, mask, sdpa_params)
        }
    }
}