mistralrs_quant/gemv/mod.rs

//! Custom GEMV (General Matrix-Vector multiplication) for decode-phase inference.
//!
//! This module provides an optimized GEMV kernel that replaces cuBLAS for
//! small batch sizes (1-8) where cuBLAS GEMM overhead is significant.
//!
//! Key optimizations:
//! - Vectorized loads (half2, nv_bfloat162, float2)
//! - __ldg() for read-only cache path (L2 cache handles x reuse)
//! - Warp-level reduction using XOR shuffle
//! - Static shared memory for block-level reduction
//! - Supports batch sizes 1-8 efficiently
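//!
//! A minimal usage sketch (assumes a CUDA device, the `cuda` feature, and
//! decode-style shapes; the tensor setup is illustrative only):
//!
//! ```ignore
//! use candle_core::{DType, Device, Tensor};
//!
//! let dev = Device::new_cuda(0)?;
//! let x = Tensor::zeros((1, 4096), DType::BF16, &dev)?; // input   [B, K], B <= 8
//! let w = Tensor::zeros((11008, 4096), DType::BF16, &dev)?; // weights [M, K]
//! let y = gemv(&x, &w, None)?; // output [B, M]
//! ```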

#[cfg(feature = "cuda")]
mod ffi;

#[cfg(feature = "cuda")]
use candle_core::{
    cuda::cudarc::driver::DevicePtr, CudaDevice, CudaStorage, DType, Result, Shape, Storage, Tensor,
};

#[cfg(feature = "cuda")]
use crate::utils::{get_cuda_device, slice_ptr};

#[cfg(feature = "cuda")]
use half::{bf16, f16};

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::LazyLock;

/// Maximum batch size supported by the GEMV kernel
pub const MAX_GEMV_BATCH_SIZE: usize = 8;

/// Controller for enabling/disabling custom GEMV kernel.
pub struct GemvController {
    enabled: AtomicBool,
}

impl GemvController {
    /// Enable or disable the custom GEMV kernel.
    pub fn set_enabled(&self, value: bool) {
        self.enabled.store(value, Ordering::SeqCst);
    }

    /// Check if the custom GEMV kernel is enabled.
    pub fn is_enabled(&self) -> bool {
        self.enabled.load(Ordering::SeqCst)
    }
}

/// Global controller for the custom GEMV kernel.
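///
/// A minimal sketch of toggling the kernel at runtime (the default is enabled):
///
/// ```ignore
/// GEMV_CONTROLLER.set_enabled(false); // force the regular cuBLAS/matmul path
/// assert!(!GEMV_CONTROLLER.is_enabled());
/// GEMV_CONTROLLER.set_enabled(true); // re-enable the custom kernel
/// ```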
pub static GEMV_CONTROLLER: LazyLock<GemvController> = LazyLock::new(|| GemvController {
    enabled: AtomicBool::new(true),
});

/// Check if custom GEMV should be used instead of cuBLAS.
///
/// Returns `true` only if all of the following hold:
/// - GEMV is enabled via the controller
/// - Tensors are on a CUDA device
/// - Batch size is 1-8
/// - Data type is supported (BF16, F16, F32) and `x` and `w` share the same dtype
/// - K dimension is even (required for vectorized loads) and matches between `x` and `w`
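///
/// A sketch of the intended dispatch pattern (the `matmul` fallback shown here
/// stands in for whatever GEMM path the caller normally uses):
///
/// ```ignore
/// let y = if should_use_gemv(&x, &w) {
///     gemv(&x, &w, None)? // custom decode-phase kernel
/// } else {
///     x.matmul(&w.t()?)? // regular cuBLAS-backed path
/// };
/// ```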
#[cfg(feature = "cuda")]
pub fn should_use_gemv(x: &Tensor, w: &Tensor) -> bool {
    // Check if enabled
    if !GEMV_CONTROLLER.is_enabled() {
        return false;
    }

    // Only for CUDA tensors
    if !x.device().is_cuda() {
        return false;
    }

    // Check batch size (1-8 supported)
    let x_dims = x.dims();
    let batch_size: usize = x_dims[..x_dims.len().saturating_sub(1)]
        .iter()
        .product::<usize>()
        .max(1);
    if batch_size > MAX_GEMV_BATCH_SIZE {
        return false;
    }

    // Must be supported dtype
    let supported = matches!(x.dtype(), DType::BF16 | DType::F16 | DType::F32);
    if !supported {
        return false;
    }

    // Must match dtypes
    if x.dtype() != w.dtype() {
        return false;
    }

    // K must be even for vectorized loads
    let k = x.dim(x.rank() - 1).unwrap_or(0);
    if k % 2 != 0 {
        return false;
    }

    // Check that K dimensions match
    let w_k = w.dim(w.rank() - 1).unwrap_or(0);
    if k != w_k {
        return false;
    }

    true
}

/// Fallback for non-CUDA builds
#[cfg(not(feature = "cuda"))]
pub fn should_use_gemv(_x: &candle_core::Tensor, _w: &candle_core::Tensor) -> bool {
    false
}

/// Execute custom GEMV: Y = X @ W^T + bias
///
/// # Arguments
/// * `x` - Input tensor [B, K] where B is batch size (1-8)
/// * `w` - Weight matrix tensor [M, K]
/// * `bias` - Optional bias tensor [M]
///
/// # Returns
/// * Output tensor [B, M]
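///
/// Concretely, for each batch row `b` and output column `m`, the kernel
/// computes `y[b][m] = sum_k x[b][k] * w[m][k]`, plus `bias[m]` when a bias
/// is provided.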
#[cfg(feature = "cuda")]
pub fn gemv(x: &Tensor, w: &Tensor, bias: Option<&Tensor>) -> Result<Tensor> {
    let dev = get_cuda_device(x)?;

    // Get dimensions
    let (m, k) = w.dims2()?;

    // Calculate batch size from input shape
    let x_dims = x.dims();
    let batch_size: usize = x_dims[..x_dims.len().saturating_sub(1)]
        .iter()
        .product::<usize>()
        .max(1);

    if batch_size > MAX_GEMV_BATCH_SIZE {
        candle_core::bail!(
            "GEMV batch size {} exceeds maximum {}",
            batch_size,
            MAX_GEMV_BATCH_SIZE
        );
    }

    // Check K dimension
    let x_k = x.dim(x.rank() - 1)?;
    if x_k != k {
        candle_core::bail!("GEMV dimension mismatch: x has K={} but W has K={}", x_k, k);
    }

    // Validate bias if present
    if let Some(b) = bias {
        let b_len = b.elem_count();
        if b_len != m {
            candle_core::bail!(
                "GEMV bias dimension mismatch: bias has {} elements but M={}",
                b_len,
                m
            );
        }
    }

    // Output shape matches input batch dims with last dim = M
    let output_shape = {
        let mut shape = x.dims().to_vec();
        *shape.last_mut().unwrap() = m;
        shape
    };

    // Dispatch based on dtype
    match x.dtype() {
        DType::BF16 => gemv_bf16(dev, x, w, bias, batch_size, m, k, &output_shape),
        DType::F16 => gemv_f16(dev, x, w, bias, batch_size, m, k, &output_shape),
        DType::F32 => gemv_f32(dev, x, w, bias, batch_size, m, k, &output_shape),
        dt => candle_core::bail!("GEMV unsupported dtype: {:?}", dt),
    }
}

#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn gemv_bf16(
    dev: &CudaDevice,
    x: &Tensor,
    w: &Tensor,
    bias: Option<&Tensor>,
    batch_size: usize,
    m: usize,
    k: usize,
    output_shape: &[usize],
) -> Result<Tensor> {
    // Allocate output: [B, M]
    let y_buf = unsafe { dev.alloc::<bf16>(batch_size * m)? };

    // Get weight pointer
    let (w_s, w_l) = w.storage_and_layout();
    let Storage::Cuda(w_s) = &*w_s else {
        candle_core::bail!("Expected CUDA storage for weights");
    };
    let (w_ptr, _w_guard) = slice_ptr(w_s.as_cuda_slice::<bf16>()?, w_l.start_offset());

    // Get input pointer (contiguous)
    let x_contig = x.contiguous()?;
    let (x_s, x_l) = x_contig.storage_and_layout();
    let Storage::Cuda(x_s) = &*x_s else {
        candle_core::bail!("Expected CUDA storage for input");
    };
    let (x_ptr, _x_guard) = slice_ptr(x_s.as_cuda_slice::<bf16>()?, x_l.start_offset());

    let (y_ptr, y_guard) = y_buf.device_ptr(y_buf.stream());

    // Get bias storage
    let bias_storage = bias.map(|b| b.storage_and_layout());
    let (bias_ptr, has_bias, _bias_guard) = if let Some((ref b_arc, b_l)) = bias_storage {
        let Storage::Cuda(b_s) = &**b_arc else {
            candle_core::bail!("Expected CUDA storage for bias");
        };
        let (b_ptr, b_guard) = slice_ptr(b_s.as_cuda_slice::<bf16>()?, b_l.start_offset());
        (b_ptr, true, Some(b_guard))
    } else {
        (0u64, false, None)
    };

    let stream = dev.cuda_stream();

    unsafe {
        ffi::launch_gemv_bf16(
            w_ptr as *const bf16,
            x_ptr as *const bf16,
            bias_ptr as *const bf16,
            y_ptr as *mut bf16,
            m as i32,
            k as i32,
            batch_size as i32,
            has_bias,
            stream.cu_stream() as *mut std::ffi::c_void,
        );
    }

    drop(y_guard);

    let y_storage = CudaStorage::wrap_cuda_slice(y_buf, dev.clone());
    let y = Tensor::from((Storage::Cuda(y_storage), Shape::from(output_shape)));

    Ok(y)
}

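// f16 variant: identical pointer extraction and launch sequence to `gemv_bf16`,
// differing only in the element type passed to the FFI launcher.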
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn gemv_f16(
    dev: &CudaDevice,
    x: &Tensor,
    w: &Tensor,
    bias: Option<&Tensor>,
    batch_size: usize,
    m: usize,
    k: usize,
    output_shape: &[usize],
) -> Result<Tensor> {
    let y_buf = unsafe { dev.alloc::<f16>(batch_size * m)? };

    let (w_s, w_l) = w.storage_and_layout();
    let Storage::Cuda(w_s) = &*w_s else {
        candle_core::bail!("Expected CUDA storage for weights");
    };
    let (w_ptr, _w_guard) = slice_ptr(w_s.as_cuda_slice::<f16>()?, w_l.start_offset());

    let x_contig = x.contiguous()?;
    let (x_s, x_l) = x_contig.storage_and_layout();
    let Storage::Cuda(x_s) = &*x_s else {
        candle_core::bail!("Expected CUDA storage for input");
    };
    let (x_ptr, _x_guard) = slice_ptr(x_s.as_cuda_slice::<f16>()?, x_l.start_offset());

    let (y_ptr, y_guard) = y_buf.device_ptr(y_buf.stream());

    let bias_storage = bias.map(|b| b.storage_and_layout());
    let (bias_ptr, has_bias, _bias_guard) = if let Some((ref b_arc, b_l)) = bias_storage {
        let Storage::Cuda(b_s) = &**b_arc else {
            candle_core::bail!("Expected CUDA storage for bias");
        };
        let (b_ptr, b_guard) = slice_ptr(b_s.as_cuda_slice::<f16>()?, b_l.start_offset());
        (b_ptr, true, Some(b_guard))
    } else {
        (0u64, false, None)
    };

    let stream = dev.cuda_stream();

    unsafe {
        ffi::launch_gemv_f16(
            w_ptr as *const f16,
            x_ptr as *const f16,
            bias_ptr as *const f16,
            y_ptr as *mut f16,
            m as i32,
            k as i32,
            batch_size as i32,
            has_bias,
            stream.cu_stream() as *mut std::ffi::c_void,
        );
    }

    drop(y_guard);

    let y_storage = CudaStorage::wrap_cuda_slice(y_buf, dev.clone());
    let y = Tensor::from((Storage::Cuda(y_storage), Shape::from(output_shape)));

    Ok(y)
}

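// f32 variant: same structure as `gemv_bf16`/`gemv_f16`, with f32 elements.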
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn gemv_f32(
    dev: &CudaDevice,
    x: &Tensor,
    w: &Tensor,
    bias: Option<&Tensor>,
    batch_size: usize,
    m: usize,
    k: usize,
    output_shape: &[usize],
) -> Result<Tensor> {
    let y_buf = unsafe { dev.alloc::<f32>(batch_size * m)? };

    let (w_s, w_l) = w.storage_and_layout();
    let Storage::Cuda(w_s) = &*w_s else {
        candle_core::bail!("Expected CUDA storage for weights");
    };
    let (w_ptr, _w_guard) = slice_ptr(w_s.as_cuda_slice::<f32>()?, w_l.start_offset());

    let x_contig = x.contiguous()?;
    let (x_s, x_l) = x_contig.storage_and_layout();
    let Storage::Cuda(x_s) = &*x_s else {
        candle_core::bail!("Expected CUDA storage for input");
    };
    let (x_ptr, _x_guard) = slice_ptr(x_s.as_cuda_slice::<f32>()?, x_l.start_offset());

    let (y_ptr, y_guard) = y_buf.device_ptr(y_buf.stream());

    let bias_storage = bias.map(|b| b.storage_and_layout());
    let (bias_ptr, has_bias, _bias_guard) = if let Some((ref b_arc, b_l)) = bias_storage {
        let Storage::Cuda(b_s) = &**b_arc else {
            candle_core::bail!("Expected CUDA storage for bias");
        };
        let (b_ptr, b_guard) = slice_ptr(b_s.as_cuda_slice::<f32>()?, b_l.start_offset());
        (b_ptr, true, Some(b_guard))
    } else {
        (0u64, false, None)
    };

    let stream = dev.cuda_stream();

    unsafe {
        ffi::launch_gemv_f32(
            w_ptr as *const f32,
            x_ptr as *const f32,
            bias_ptr as *const f32,
            y_ptr as *mut f32,
            m as i32,
            k as i32,
            batch_size as i32,
            has_bias,
            stream.cu_stream() as *mut std::ffi::c_void,
        );
    }

    drop(y_guard);

    let y_storage = CudaStorage::wrap_cuda_slice(y_buf, dev.clone());
    let y = Tensor::from((Storage::Cuda(y_storage), Shape::from(output_shape)));

    Ok(y)
}

/// Fallback for non-CUDA builds
#[cfg(not(feature = "cuda"))]
pub fn gemv(
    _x: &candle_core::Tensor,
    _w: &candle_core::Tensor,
    _bias: Option<&candle_core::Tensor>,
) -> candle_core::Result<candle_core::Tensor> {
    candle_core::bail!("GEMV requires CUDA feature");
}