mistralrs_quant/cublaslt/mod.rs

#![allow(unused_variables, unused_imports, dead_code)]

use candle_core::{Device, Result, Tensor};
use candle_nn::Activation as CandleActivation;
use once_cell::sync::Lazy;
use std::sync::{Mutex, Once};

#[cfg(feature = "cuda")]
mod api;
#[cfg(feature = "cuda")]
mod matmul;
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests;

#[cfg(feature = "cuda")]
pub use api::{fused_batch_matmul, fused_batch_matmul_f8, CublasLt};

pub enum F8MatmulOutType {
    F8,
    BF16,
}

static INIT: Once = Once::new();
static mut CUBLASLT: Option<CublasLtWrapper> = None;
pub static CUBLASLT_HANDLE: Lazy<Mutex<Option<&'static CublasLtWrapper>>> =
    Lazy::new(|| Mutex::new(None));

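/// Initializes the global cuBLASLt wrapper exactly once (guarded by `INIT`).
/// With the `cuda` feature enabled and a CUDA `device`, a `CublasLtWrapper` is
/// created and a `'static` reference to it is published through
/// `CUBLASLT_HANDLE`; in every other case the handle stays `None`.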
pub fn maybe_init_cublas_lt_wrapper(device: Device) {
    unsafe {
        INIT.call_once(|| {
            #[cfg(not(feature = "cuda"))]
            {
                CUBLASLT = None;
            }

            #[cfg(feature = "cuda")]
            {
                use candle_core::cuda_backend::cudarc::driver;
                CUBLASLT = match device {
                    Device::Cuda(_) => Some(CublasLtWrapper {
                        cublaslt: CublasLt::new(&device).unwrap(),
                    }),
                    _ => None,
                }
            }
            #[allow(static_mut_refs)]
            let cublaslt: Option<&'static CublasLtWrapper> = CUBLASLT.as_ref();
            *CUBLASLT_HANDLE.lock().unwrap() = cublaslt;
        });
    }
}

#[derive(Debug, Clone)]
pub struct CublasLtWrapper {
    #[cfg(feature = "cuda")]
    pub cublaslt: CublasLt,
}

impl CublasLtWrapper {
    #[allow(clippy::too_many_arguments)]
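    /// FP8 batched matmul fused with optional bias and a Relu/Gelu epilogue,
    /// executed through cuBLASLt. The dequantization scales for `a` and `b` and
    /// the output quantization scale are forwarded to `fused_batch_matmul_f8`,
    /// and `out_dtype` selects an FP8 or BF16 result. A `Swiglu` activation is
    /// applied after the matmul via `candle_nn::ops::swiglu` rather than fused
    /// into the kernel. Returns an error when the `cuda` feature is disabled.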
    pub fn batch_matmul_f8(
        &self,
        a: &Tensor,
        b: &Tensor,
        dequant_a_scale: &Tensor,
        dequant_b_scale: &Tensor,
        quantize_scale: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<CandleActivation>,
        out_dtype: F8MatmulOutType,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            // Relu and Gelu are fused into the cuBLASLt epilogue; Swiglu is not
            // fused and is applied after the matmul below, so it must map to
            // `None` here instead of hitting the unsupported-activation branch.
            let inner_act = match act {
                Some(CandleActivation::Relu) => Some(matmul::Activation::Relu),
                Some(CandleActivation::Gelu) => Some(matmul::Activation::Gelu),
                Some(CandleActivation::Swiglu) | None => None,
                _ => unreachable!("Unsupported activation in cublaslt matmul"),
            };
            let mut result = fused_batch_matmul_f8(
                a,
                b,
                dequant_a_scale,
                dequant_b_scale,
                quantize_scale,
                out,
                alpha,
                beta,
                bias,
                inner_act,
                out_dtype,
                self.cublaslt.clone(),
            )?;

            if Some(CandleActivation::Swiglu) == act {
                result = candle_nn::ops::swiglu(&result)?;
            }
            Ok(result)
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle_core::bail!("`cuda` feature is not enabled")
        }
    }

    #[allow(clippy::too_many_arguments)]
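    /// Batched matmul fused with optional bias and a Relu/Gelu epilogue,
    /// executed through cuBLASLt, with an optional `out` buffer and
    /// `alpha`/`beta` scaling. A `Swiglu` activation is applied after the
    /// matmul via `candle_nn::ops::swiglu` rather than fused into the kernel.
    /// Returns an error when the `cuda` feature is disabled.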
    pub fn batch_matmul(
        &self,
        a: &Tensor,
        b: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<CandleActivation>,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            // Same activation handling as `batch_matmul_f8`: Swiglu is applied
            // after the matmul below rather than fused into the kernel.
            let inner_act = match act {
                Some(CandleActivation::Relu) => Some(matmul::Activation::Relu),
                Some(CandleActivation::Gelu) => Some(matmul::Activation::Gelu),
                Some(CandleActivation::Swiglu) | None => None,
                _ => unreachable!("Unsupported activation in cublaslt matmul"),
            };
            let mut result = fused_batch_matmul(
                a,
                b,
                out,
                alpha,
                beta,
                bias,
                inner_act,
                self.cublaslt.clone(),
            )?;

            if Some(CandleActivation::Swiglu) == act {
                result = candle_nn::ops::swiglu(&result)?;
            }
            Ok(result)
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle_core::bail!("`cuda` feature is not enabled")
        }
    }
}
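
// Illustrative usage sketch (not part of the upstream module): how calling code
// might obtain the lazily initialized handle and run a fused batched matmul with
// a Gelu epilogue. The function name and argument names are hypothetical, and
// the shape/transpose conventions are those of the underlying
// `fused_batch_matmul` kernel in `api.rs`.
fn example_fused_matmul(a: &Tensor, b: &Tensor, device: &Device) -> Result<Tensor> {
    // Safe to call repeatedly: the `Once` guard makes subsequent calls no-ops.
    maybe_init_cublas_lt_wrapper(device.clone());

    match *CUBLASLT_HANDLE.lock().unwrap() {
        // Fused matmul; `None` for the optional output buffer, alpha/beta
        // scaling factors, and bias.
        Some(lt) => lt.batch_matmul(a, b, None, None, None, None, Some(CandleActivation::Gelu)),
        None => candle_core::bail!("cuBLASLt handle was not initialized for this device"),
    }
}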