#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{attention::backends::cpu, pipeline::text_models_inputs_processor::FlashParams};

use candle_core::{DType, Device, Result, Tensor};

mod backends;

#[allow(unused)]
pub(crate) use backends::{flash_attn, maybe_synchronize, naive_sdpa};

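/// Maximum query sequence length processed per call when attention is run in chunks.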
pub(crate) const ATTENTION_CHUNK_SIZE: usize = 1024;

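/// Run `attention_fn` over `q` in chunks of at most [`ATTENTION_CHUNK_SIZE`] along the
/// sequence dimension (dim 2), narrowing any mask to each chunk, then concatenate the
/// per-chunk outputs. If the sequence fits in a single chunk, `attention_fn` is called
/// once on the full inputs.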
pub(crate) fn chunked_attention<F>(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    mask: Option<&Tensor>,
    attention_fn: F,
) -> Result<Tensor>
where
    F: Fn(&Tensor, &Tensor, &Tensor, Option<&Tensor>) -> Result<Tensor>,
{
    let seq_len = q.dim(2)?;

    if seq_len <= ATTENTION_CHUNK_SIZE {
        return attention_fn(q, k, v, mask);
    }

    let num_chunks = seq_len.div_ceil(ATTENTION_CHUNK_SIZE);
    let mut attn_chunks = Vec::with_capacity(num_chunks);

    for chunk_idx in 0..num_chunks {
        let offset = chunk_idx * ATTENTION_CHUNK_SIZE;
        let chunk_len = ATTENTION_CHUNK_SIZE.min(seq_len - offset);

        let q_chunk = q.narrow(2, offset, chunk_len)?;

        // Narrow the mask along its query dimension, which depends on the mask's rank.
        let mask_chunk = mask
            .map(|m| match m.rank() {
                2 => m.narrow(0, offset, chunk_len),
                3 => m.narrow(1, offset, chunk_len),
                4 => m.narrow(2, offset, chunk_len),
                _ => m.narrow(2, offset, chunk_len),
            })
            .transpose()?;

        let att_chunk = attention_fn(&q_chunk, k, v, mask_chunk.as_ref())?;

        attn_chunks.push(att_chunk);
    }

    Tensor::cat(&attn_chunks, 2)
}

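/// Repeat each KV head `n_rep` times so a grouped-query K/V tensor of shape
/// `(b_sz, n_kv_heads, seq_len, head_dim)` becomes
/// `(b_sz, n_kv_heads * n_rep, seq_len, head_dim)`.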
fn repeat_kv(x: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        Ok(x)
    } else {
        let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
        Tensor::cat(&vec![&x; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
    }
}

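/// Parameters for scaled dot-product attention: the number of query heads per KV head
/// (GQA group size), an optional logit softcap, the softmax scale, and an optional
/// sliding-window size.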
pub struct SdpaParams {
    pub n_kv_groups: usize,
    pub softcap: Option<f32>,
    pub softmax_scale: f32,
    pub sliding_window: Option<usize>,
}

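/// Scaled dot-product attention dispatcher that picks the most suitable backend
/// (flash attention, the Metal SDPA kernel, cuBLASLt, or a naive fallback) for the
/// current device and inputs.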
pub struct Sdpa;

impl Sdpa {
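    /// Compute attention for `q`/`k`/`v` of shape `(b_sz, n_heads, seq_len, head_dim)`.
    ///
    /// Prefers the flash-attention paths: the CPU kernel for F32/F16/BF16 inputs, or the
    /// CUDA flash kernel when `crate::using_flash_attn()` returns true. Otherwise falls
    /// back to [`Self::run_attention_noflash`].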
    #[allow(unused_variables, clippy::too_many_arguments)]
    pub fn run_attention(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        mask: Option<&Tensor>,
        flash_params: Option<&FlashParams>,
        sdpa_params: &SdpaParams,
    ) -> Result<Tensor> {
        let (b_sz, n_attn_heads, seq_len, head_dim) = q.dims4()?;
        let (_, _, _, k_head_dim) = k.dims4()?;
        let (_, _, _, v_head_dim) = v.dims4()?;

        let can_use_flash =
            q.device().is_cpu() || (q.device().is_cuda() && crate::using_flash_attn());

        if can_use_flash {
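            // Flash-attention backends take q/k/v as (b_sz, seq_len, n_heads, head_dim).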
            let q = q.transpose(1, 2)?;
            let k = k.transpose(1, 2)?;
            let v = v.transpose(1, 2)?;

            if q.device().is_cpu() {
                match q.dtype() {
                    DType::F32 => {
                        return cpu::run_flash_attn_cpu::<f32>(&q, &k, &v, mask, sdpa_params);
                    }
                    DType::F16 => {
                        return cpu::run_flash_attn_cpu::<half::f16>(&q, &k, &v, mask, sdpa_params);
                    }
                    DType::BF16 => {
                        return cpu::run_flash_attn_cpu::<half::bf16>(
                            &q,
                            &k,
                            &v,
                            mask,
                            sdpa_params,
                        );
                    }
                    _ => {
                        return Err(candle_core::Error::Msg("Unsupported data type".into()));
                    }
                }
            } else {
                return flash_attn(&q, &k, &v, flash_params, sdpa_params)?.transpose(1, 2);
            }
        }

        self.run_attention_noflash(q, k, v, mask, sdpa_params)
    }

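    /// Compute attention without the flash-attention kernels, choosing between the Metal
    /// SDPA kernel, a cuBLASLt batched-matmul path on CUDA, and the naive SDPA fallback.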
    #[allow(unused_variables, clippy::too_many_arguments)]
    pub fn run_attention_noflash(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        mask: Option<&Tensor>,
        sdpa_params: &SdpaParams,
    ) -> Result<Tensor> {
        let (b_sz, n_attn_heads, seq_len, head_dim) = q.dims4()?;
        let (_, _, _, k_head_dim) = k.dims4()?;
        let (_, _, _, v_head_dim) = v.dims4()?;

        let all_head_dims_match = head_dim == k_head_dim && k_head_dim == v_head_dim;
        let tgt_mask_shape = vec![b_sz, n_attn_heads, seq_len, k.dim(2)?];
        let can_use_mask = mask.is_none_or(|mask| {
            mask.layout().broadcast_as(tgt_mask_shape.clone()).is_ok()
                && sdpa_params.softcap.is_none_or(|x| x == 1.0)
        });
        let valid_head_dims: &[usize] = if seq_len == 1 {
            &[32, 64, 72, 80, 96, 128, 256]
        } else {
            &[32, 64, 72, 80, 96, 128]
        };
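        // Metal fused SDPA (`candle_nn::ops::sdpa`): requires matching q/k/v head dims,
        // a supported head_dim, no effective softcap, and a mask broadcastable to
        // (b_sz, n_heads, seq_len, kv_len).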
        if [q, k, v].into_iter().all(|x| x.device().is_metal())
            && all_head_dims_match
            && valid_head_dims.contains(&head_dim)
            && can_use_mask
        {
            let mask = match mask {
                Some(mask) => Some(mask.broadcast_as(tgt_mask_shape)?),
                None => None,
            };
            return candle_nn::ops::sdpa(
                q,
                k,
                v,
                mask.as_ref(),
                false,
                sdpa_params.softmax_scale,
                sdpa_params.softcap.unwrap_or(1.0),
            );
        }

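        // Expand KV heads for grouped-query attention before the matmul-based paths.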
        let k = repeat_kv(k.clone(), sdpa_params.n_kv_groups)?;
        let v = repeat_kv(v.clone(), sdpa_params.n_kv_groups)?;

        if mask.is_some_and(|x| x.rank() == 2) || mistralrs_quant::distributed::use_nccl() {
            return naive_sdpa(
                &q.contiguous()?,
                &k.contiguous()?,
                &v.contiguous()?,
                mask,
                sdpa_params,
            );
        }

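        // cuBLASLt path: batched matmuls compute the scores (with the softmax scale and
        // any mask bias folded in) and the context layer, chunking the query via
        // `chunked_attention`.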
        #[allow(unused)]
        if let (Device::Cuda(_), Some(cublaslt)) = (
            q.device(),
            mistralrs_quant::cublaslt::CUBLASLT_CONTROLLER.get(),
        ) {
            #[cfg(feature = "cuda")]
            {
                maybe_synchronize(q.device())?;

                let k_flat = k.flatten(0, 1)?;
                let v_flat = v.flatten(0, 1)?;

                chunked_attention(q, &k, &v, mask, |q_chunk, _k, _v, mask_chunk| {
                    let (chunk_b_sz, chunk_n_heads, chunk_seq_len, chunk_head_dim) =
                        q_chunk.dims4()?;
                    let q_flat = q_chunk.flatten(0, 1)?;

                    let attention_bias = match mask_chunk {
                        Some(mask) if mask.rank() == 3 && mask.dims()[0] == 1 => {
                            Some(mask.repeat((chunk_n_heads, 1, 1))?)
                        }
                        Some(mask) if mask.rank() == 3 => Some(mask.clone()),
                        Some(mask) if mask.rank() == 4 => {
                            let tgt_shape =
                                vec![chunk_b_sz, chunk_n_heads, chunk_seq_len, k.dim(2)?];
                            Some(mask.broadcast_as(tgt_shape)?.flatten(0, 1)?)
                        }
                        Some(mask) => {
                            candle_core::bail!("cublaslt attn mask: rank must be 3 or 4")
                        }
                        None => None,
                    };

                    // Apply the additive mask bias (beta = 1.0) only when a bias tensor exists.
                    let beta = match attention_bias.is_some() {
                        true => Some(1.0),
                        false => None,
                    };

                    let mut attention_scores = cublaslt.batch_matmul(
                        &k_flat,
                        &q_flat,
                        attention_bias.as_ref(),
                        Some(sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0)),
                        beta,
                        None,
                        None,
                    )?;
                    if let Some(softcap) = sdpa_params.softcap {
                        attention_scores = (attention_scores.tanh()? * softcap as f64)?;
                    }
                    attention_scores = candle_nn::ops::softmax_last_dim(&attention_scores)?;

                    let context_layer = cublaslt.batch_matmul(
                        &v_flat.t()?.contiguous()?,
                        &attention_scores,
                        Some(&q_flat),
                        None,
                        None,
                        None,
                        None,
                    )?;

                    context_layer.reshape((chunk_b_sz, chunk_n_heads, chunk_seq_len, v_head_dim))
                })
            }
            #[cfg(not(feature = "cuda"))]
            {
                candle_core::bail!("`cuda` feature is not enabled")
            }
        } else {
            naive_sdpa(q, &k, &v, mask, sdpa_params)
        }
    }
}