#[cfg(feature = "cuda")]
mod ffi;

#[cfg(feature = "cuda")]
use candle_core::{
    cuda::cudarc::driver::DevicePtr, CudaDevice, CudaStorage, DType, Result, Shape, Storage, Tensor,
};

#[cfg(feature = "cuda")]
use crate::utils::{get_cuda_device, slice_ptr};

#[cfg(feature = "cuda")]
use half::{bf16, f16};

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::LazyLock;

/// Maximum number of input rows (product of all but the last dimension of `x`)
/// for which the fused GEMV path is considered.
pub const MAX_GEMV_BATCH_SIZE: usize = 8;

/// Global runtime switch for the fused GEMV path.
pub struct GemvController {
    enabled: AtomicBool,
}

impl GemvController {
    pub fn set_enabled(&self, value: bool) {
        self.enabled.store(value, Ordering::SeqCst);
    }

    pub fn is_enabled(&self) -> bool {
        self.enabled.load(Ordering::SeqCst)
    }
}

/// Process-wide controller instance; the GEMV path is enabled by default.
pub static GEMV_CONTROLLER: LazyLock<GemvController> = LazyLock::new(|| GemvController {
    enabled: AtomicBool::new(true),
});

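// A minimal sketch of the controller's behaviour (the test module name is arbitrary):
// the switch is process-wide, so flipping it through the static is visible to every
// caller that consults `should_use_gemv`.
#[cfg(test)]
mod gemv_controller_tests {
    use super::*;

    #[test]
    fn toggling_the_controller_is_globally_visible() {
        // Disable the fused GEMV path, observe the change, then restore the default.
        GEMV_CONTROLLER.set_enabled(false);
        assert!(!GEMV_CONTROLLER.is_enabled());
        GEMV_CONTROLLER.set_enabled(true);
        assert!(GEMV_CONTROLLER.is_enabled());
    }
}
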
#[cfg(feature = "cuda")]
/// Returns `true` when the fused GEMV kernel should be used for this `x` / `w` pair:
/// the global controller must be enabled, `x` must live on a CUDA device, the
/// effective batch size must not exceed [`MAX_GEMV_BATCH_SIZE`], both tensors must
/// share a supported dtype (bf16, f16, or f32), and the shared `K` dimension must be
/// even and equal on both sides.
pub fn should_use_gemv(x: &Tensor, w: &Tensor) -> bool {
    if !GEMV_CONTROLLER.is_enabled() {
        return false;
    }

    if !x.device().is_cuda() {
        return false;
    }

    // Effective batch size: product of all but the last dimension of `x`.
    let x_dims = x.dims();
    let batch_size: usize = x_dims[..x_dims.len().saturating_sub(1)]
        .iter()
        .product::<usize>()
        .max(1);
    if batch_size > MAX_GEMV_BATCH_SIZE {
        return false;
    }

    let supported = matches!(x.dtype(), DType::BF16 | DType::F16 | DType::F32);
    if !supported {
        return false;
    }

    if x.dtype() != w.dtype() {
        return false;
    }

    // Only an even K dimension is supported by this path.
    let k = x.dim(x.rank() - 1).unwrap_or(0);
    if k % 2 != 0 {
        return false;
    }

    // `x` and `w` must agree on K.
    let w_k = w.dim(w.rank() - 1).unwrap_or(0);
    if k != w_k {
        return false;
    }

    true
}

#[cfg(not(feature = "cuda"))]
pub fn should_use_gemv(_x: &candle_core::Tensor, _w: &candle_core::Tensor) -> bool {
    false
}

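// A minimal sketch of the selection logic from the caller's side (test module name
// and shapes are arbitrary): CPU tensors must never take the fused GEMV path,
// whether or not the `cuda` feature is enabled.
#[cfg(test)]
mod gemv_selection_tests {
    use super::*;
    use candle_core::{DType, Device, Tensor};

    #[test]
    fn cpu_tensors_never_take_the_gemv_path() -> candle_core::Result<()> {
        // x: (batch = 2, K = 16), w: (M = 32, K = 16) -- valid shapes, wrong device.
        let x = Tensor::zeros((2, 16), DType::F32, &Device::Cpu)?;
        let w = Tensor::zeros((32, 16), DType::F32, &Device::Cpu)?;
        assert!(!should_use_gemv(&x, &w));
        Ok(())
    }
}
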
#[cfg(feature = "cuda")]
/// Computes a batched matrix-vector product of `x` against the `(M, K)` weight
/// matrix `w` (i.e. `y = x @ W^T`), optionally adding `bias`, using the fused
/// GEMV kernel.
///
/// `w` must be a `(M, K)` matrix, `x` a tensor whose last dimension is `K`, and
/// `bias` (when present) must hold `M` elements. The product of all but the last
/// dimension of `x` is the batch size and must not exceed
/// [`MAX_GEMV_BATCH_SIZE`]. The result has the shape of `x` with its last
/// dimension replaced by `M`.
pub fn gemv(x: &Tensor, w: &Tensor, bias: Option<&Tensor>) -> Result<Tensor> {
    let dev = get_cuda_device(x)?;

    let (m, k) = w.dims2()?;

    // Effective batch size: product of all but the last dimension of `x`.
    let x_dims = x.dims();
    let batch_size: usize = x_dims[..x_dims.len().saturating_sub(1)]
        .iter()
        .product::<usize>()
        .max(1);

    if batch_size > MAX_GEMV_BATCH_SIZE {
        candle_core::bail!(
            "GEMV batch size {} exceeds maximum {}",
            batch_size,
            MAX_GEMV_BATCH_SIZE
        );
    }

    let x_k = x.dim(x.rank() - 1)?;
    if x_k != k {
        candle_core::bail!("GEMV dimension mismatch: x has K={} but W has K={}", x_k, k);
    }

    if let Some(b) = bias {
        let b_len = b.elem_count();
        if b_len != m {
            candle_core::bail!(
                "GEMV bias dimension mismatch: bias has {} elements but M={}",
                b_len,
                m
            );
        }
    }

    // Output keeps the leading dimensions of `x` and replaces the last one with M.
    let output_shape = {
        let mut shape = x.dims().to_vec();
        *shape.last_mut().unwrap() = m;
        shape
    };

    match x.dtype() {
        DType::BF16 => gemv_bf16(dev, x, w, bias, batch_size, m, k, &output_shape),
        DType::F16 => gemv_f16(dev, x, w, bias, batch_size, m, k, &output_shape),
        DType::F32 => gemv_f32(dev, x, w, bias, batch_size, m, k, &output_shape),
        dt => candle_core::bail!("GEMV unsupported dtype: {:?}", dt),
    }
}

#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn gemv_bf16(
    dev: &CudaDevice,
    x: &Tensor,
    w: &Tensor,
    bias: Option<&Tensor>,
    batch_size: usize,
    m: usize,
    k: usize,
    output_shape: &[usize],
) -> Result<Tensor> {
    // Allocate the output buffer: one M-sized row per batch element.
    let y_buf = unsafe { dev.alloc::<bf16>(batch_size * m)? };

    // Raw device pointer to the (M, K) weight matrix.
    let (w_s, w_l) = w.storage_and_layout();
    let Storage::Cuda(w_s) = &*w_s else {
        candle_core::bail!("Expected CUDA storage for weights");
    };
    let (w_ptr, _w_guard) = slice_ptr(w_s.as_cuda_slice::<bf16>()?, w_l.start_offset());

    // The kernel expects a contiguous input.
    let x_contig = x.contiguous()?;
    let (x_s, x_l) = x_contig.storage_and_layout();
    let Storage::Cuda(x_s) = &*x_s else {
        candle_core::bail!("Expected CUDA storage for input");
    };
    let (x_ptr, _x_guard) = slice_ptr(x_s.as_cuda_slice::<bf16>()?, x_l.start_offset());

    let (y_ptr, y_guard) = y_buf.device_ptr(y_buf.stream());

    // Optional bias pointer; a zero sentinel is passed when no bias is given.
    let bias_storage = bias.map(|b| b.storage_and_layout());
    let (bias_ptr, has_bias, _bias_guard) = if let Some((ref b_arc, b_l)) = bias_storage {
        let Storage::Cuda(b_s) = &**b_arc else {
            candle_core::bail!("Expected CUDA storage for bias");
        };
        let (b_ptr, b_guard) = slice_ptr(b_s.as_cuda_slice::<bf16>()?, b_l.start_offset());
        (b_ptr, true, Some(b_guard))
    } else {
        (0u64, false, None)
    };

    let stream = dev.cuda_stream();

    unsafe {
        ffi::launch_gemv_bf16(
            w_ptr as *const bf16,
            x_ptr as *const bf16,
            bias_ptr as *const bf16,
            y_ptr as *mut bf16,
            m as i32,
            k as i32,
            batch_size as i32,
            has_bias,
            stream.cu_stream() as *mut std::ffi::c_void,
        );
    }

    drop(y_guard);

    // Wrap the raw buffer back into a candle tensor with the requested shape.
    let y_storage = CudaStorage::wrap_cuda_slice(y_buf, dev.clone());
    let y = Tensor::from((Storage::Cuda(y_storage), Shape::from(output_shape)));

    Ok(y)
}

#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn gemv_f16(
    dev: &CudaDevice,
    x: &Tensor,
    w: &Tensor,
    bias: Option<&Tensor>,
    batch_size: usize,
    m: usize,
    k: usize,
    output_shape: &[usize],
) -> Result<Tensor> {
    let y_buf = unsafe { dev.alloc::<f16>(batch_size * m)? };

    let (w_s, w_l) = w.storage_and_layout();
    let Storage::Cuda(w_s) = &*w_s else {
        candle_core::bail!("Expected CUDA storage for weights");
    };
    let (w_ptr, _w_guard) = slice_ptr(w_s.as_cuda_slice::<f16>()?, w_l.start_offset());

    let x_contig = x.contiguous()?;
    let (x_s, x_l) = x_contig.storage_and_layout();
    let Storage::Cuda(x_s) = &*x_s else {
        candle_core::bail!("Expected CUDA storage for input");
    };
    let (x_ptr, _x_guard) = slice_ptr(x_s.as_cuda_slice::<f16>()?, x_l.start_offset());

    let (y_ptr, y_guard) = y_buf.device_ptr(y_buf.stream());

    let bias_storage = bias.map(|b| b.storage_and_layout());
    let (bias_ptr, has_bias, _bias_guard) = if let Some((ref b_arc, b_l)) = bias_storage {
        let Storage::Cuda(b_s) = &**b_arc else {
            candle_core::bail!("Expected CUDA storage for bias");
        };
        let (b_ptr, b_guard) = slice_ptr(b_s.as_cuda_slice::<f16>()?, b_l.start_offset());
        (b_ptr, true, Some(b_guard))
    } else {
        (0u64, false, None)
    };

    let stream = dev.cuda_stream();

    unsafe {
        ffi::launch_gemv_f16(
            w_ptr as *const f16,
            x_ptr as *const f16,
            bias_ptr as *const f16,
            y_ptr as *mut f16,
            m as i32,
            k as i32,
            batch_size as i32,
            has_bias,
            stream.cu_stream() as *mut std::ffi::c_void,
        );
    }

    drop(y_guard);

    let y_storage = CudaStorage::wrap_cuda_slice(y_buf, dev.clone());
    let y = Tensor::from((Storage::Cuda(y_storage), Shape::from(output_shape)));

    Ok(y)
}

#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn gemv_f32(
    dev: &CudaDevice,
    x: &Tensor,
    w: &Tensor,
    bias: Option<&Tensor>,
    batch_size: usize,
    m: usize,
    k: usize,
    output_shape: &[usize],
) -> Result<Tensor> {
    let y_buf = unsafe { dev.alloc::<f32>(batch_size * m)? };

    let (w_s, w_l) = w.storage_and_layout();
    let Storage::Cuda(w_s) = &*w_s else {
        candle_core::bail!("Expected CUDA storage for weights");
    };
    let (w_ptr, _w_guard) = slice_ptr(w_s.as_cuda_slice::<f32>()?, w_l.start_offset());

    let x_contig = x.contiguous()?;
    let (x_s, x_l) = x_contig.storage_and_layout();
    let Storage::Cuda(x_s) = &*x_s else {
        candle_core::bail!("Expected CUDA storage for input");
    };
    let (x_ptr, _x_guard) = slice_ptr(x_s.as_cuda_slice::<f32>()?, x_l.start_offset());

    let (y_ptr, y_guard) = y_buf.device_ptr(y_buf.stream());

    let bias_storage = bias.map(|b| b.storage_and_layout());
    let (bias_ptr, has_bias, _bias_guard) = if let Some((ref b_arc, b_l)) = bias_storage {
        let Storage::Cuda(b_s) = &**b_arc else {
            candle_core::bail!("Expected CUDA storage for bias");
        };
        let (b_ptr, b_guard) = slice_ptr(b_s.as_cuda_slice::<f32>()?, b_l.start_offset());
        (b_ptr, true, Some(b_guard))
    } else {
        (0u64, false, None)
    };

    let stream = dev.cuda_stream();

    unsafe {
        ffi::launch_gemv_f32(
            w_ptr as *const f32,
            x_ptr as *const f32,
            bias_ptr as *const f32,
            y_ptr as *mut f32,
            m as i32,
            k as i32,
            batch_size as i32,
            has_bias,
            stream.cu_stream() as *mut std::ffi::c_void,
        );
    }

    drop(y_guard);

    let y_storage = CudaStorage::wrap_cuda_slice(y_buf, dev.clone());
    let y = Tensor::from((Storage::Cuda(y_storage), Shape::from(output_shape)));

    Ok(y)
}

#[cfg(not(feature = "cuda"))]
pub fn gemv(
    _x: &candle_core::Tensor,
    _w: &candle_core::Tensor,
    _bias: Option<&candle_core::Tensor>,
) -> candle_core::Result<candle_core::Tensor> {
    candle_core::bail!("GEMV requires CUDA feature");
}
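
// A minimal sketch of how a caller might dispatch between the fused GEMV path and a
// plain matmul fallback. The helper name and the fallback itself are illustrative
// assumptions, not something this module provides: `w` is taken in the (M, K) layout
// that `gemv` expects, so the fallback transposes it before the matmul.
#[allow(dead_code)]
fn linear_forward_sketch(
    x: &candle_core::Tensor,
    w: &candle_core::Tensor,
    bias: Option<&candle_core::Tensor>,
) -> candle_core::Result<candle_core::Tensor> {
    if should_use_gemv(x, w) {
        return gemv(x, w, bias);
    }
    // Fallback: y = x @ W^T (+ bias) using candle's broadcasting matmul.
    let y = x.broadcast_matmul(&w.t()?)?;
    match bias {
        Some(b) => y.broadcast_add(b),
        None => Ok(y),
    }
}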