mistralrs_quant/cublaslt/mod.rs

#![allow(unused_variables, unused_imports, dead_code)]

use candle_core::{Device, Result, Tensor};
use candle_nn::Activation as CandleActivation;
use once_cell::sync::Lazy;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Mutex, Once};

/// Controls access to the process-wide cuBLASLt handle.
pub struct CublasLtController {
    handle: Mutex<Option<&'static CublasLtWrapper>>,
    inhibit: AtomicBool,
}

impl CublasLtController {
    /// When set to `true`, `get` returns `None`, forcing callers off the cuBLASLt path.
    pub fn set_inhibit(&self, value: bool) {
        self.inhibit.store(value, Ordering::SeqCst);
    }

    /// Returns the global wrapper, or `None` if it is uninitialized or inhibited.
    pub fn get(&self) -> Option<&'static CublasLtWrapper> {
        let handle_opt = self.handle.lock().unwrap();
        if self.inhibit.load(Ordering::SeqCst) {
            None
        } else {
            *handle_opt
        }
    }
}

pub static CUBLASLT_CONTROLLER: Lazy<CublasLtController> = Lazy::new(|| CublasLtController {
    handle: Mutex::new(None),
    inhibit: AtomicBool::new(false),
});

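// Sketch of the intended call pattern (illustrative, not lifted from the rest of
// the crate): initialize once, then route every lookup through the controller so
// that `set_inhibit(true)` can push callers back onto the non-cuBLASLt path.
//
//     maybe_init_cublas_lt_wrapper(device.clone());
//     CUBLASLT_CONTROLLER.set_inhibit(true);        // temporarily disable cuBLASLt
//     assert!(CUBLASLT_CONTROLLER.get().is_none()); // inhibited => None
//     CUBLASLT_CONTROLLER.set_inhibit(false);
//     let handle = CUBLASLT_CONTROLLER.get();       // Some(_) once initialized on CUDA
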
#[cfg(feature = "cuda")]
mod api;
#[cfg(feature = "cuda")]
mod matmul;
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests;

#[cfg(feature = "cuda")]
pub use api::{fused_batch_matmul, fused_batch_matmul_f8, CublasLt};

/// Output dtype for the FP8 fused matmul.
pub enum F8MatmulOutType {
    /// Produce FP8 output.
    F8,
    /// Produce BF16 output.
    BF16,
}

static INIT: Once = Once::new();
static mut CUBLASLT: Option<CublasLtWrapper> = None;

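/// One-time initialization of the global cuBLASLt wrapper.
///
/// On `cuda` builds this creates a `CublasLt` handle when `device` is a CUDA
/// device and publishes it through `CUBLASLT_CONTROLLER`; otherwise the global
/// stays `None`. Safe to call repeatedly: the body runs at most once via
/// `Once::call_once`.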
pub fn maybe_init_cublas_lt_wrapper(device: Device) {
    unsafe {
        INIT.call_once(|| {
            #[cfg(not(feature = "cuda"))]
            {
                CUBLASLT = None;
            }

            #[cfg(feature = "cuda")]
            {
                use candle_core::cuda_backend::cudarc::driver;
                CUBLASLT = match device {
                    Device::Cuda(_) => Some(CublasLtWrapper {
                        cublaslt: CublasLt::new(&device).unwrap(),
                    }),
                    _ => None,
                };
            }

            // The static is written only here, inside `call_once`, and never mutated
            // again, so handing out a `&'static` reference is sound.
            #[allow(static_mut_refs)]
            let cublaslt: Option<&'static CublasLtWrapper> = CUBLASLT.as_ref();

            let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
            *handle_lock = cublaslt;
        });
    }
}

#[derive(Debug, Clone)]
pub struct CublasLtWrapper {
    #[cfg(feature = "cuda")]
    pub cublaslt: CublasLt,
}

impl CublasLtWrapper {
    #[allow(clippy::too_many_arguments)]
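    /// Fused batched matmul over FP8 inputs via cuBLASLt.
    ///
    /// `dequant_a_scale`/`dequant_b_scale` are the dequantization scales for `a`
    /// and `b`, and `quantize_scale` the scale used when quantizing the output;
    /// `out`, `alpha`, `beta` and `bias` carry the usual GEMM-epilogue meaning.
    /// `Relu`/`Gelu` are fused into the cuBLASLt epilogue, while `Swiglu` is
    /// applied to the result afterwards via `candle_nn::ops::swiglu`.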
    pub fn batch_matmul_f8(
        &self,
        a: &Tensor,
        b: &Tensor,
        dequant_a_scale: &Tensor,
        dequant_b_scale: &Tensor,
        quantize_scale: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<CandleActivation>,
        out_dtype: F8MatmulOutType,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            // Only Relu/Gelu have cuBLASLt epilogues; Swiglu cannot be fused and is
            // handled after the matmul, so it maps to no inner activation here.
            let inner_act = match act {
                Some(CandleActivation::Relu) => Some(matmul::Activation::Relu),
                Some(CandleActivation::Gelu) => Some(matmul::Activation::Gelu),
                Some(CandleActivation::Swiglu) | None => None,
                _ => unreachable!("Unsupported activation in cublaslt matmul"),
            };
            let mut result = fused_batch_matmul_f8(
                a,
                b,
                dequant_a_scale,
                dequant_b_scale,
                quantize_scale,
                out,
                alpha,
                beta,
                bias,
                inner_act,
                out_dtype,
                self.cublaslt.clone(),
            )?;

            if Some(CandleActivation::Swiglu) == act {
                result = candle_nn::ops::swiglu(&result)?;
            }
            Ok(result)
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle_core::bail!("`cuda` feature is not enabled")
        }
    }

    #[allow(clippy::too_many_arguments)]
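    /// Fused batched matmul via cuBLASLt: GEMM with optional `out` accumulator,
    /// `alpha`/`beta` scaling and fused `bias`. `Relu`/`Gelu` are fused into the
    /// cuBLASLt epilogue, while `Swiglu` is applied to the result afterwards via
    /// `candle_nn::ops::swiglu`.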
    pub fn batch_matmul(
        &self,
        a: &Tensor,
        b: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<CandleActivation>,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            // As in `batch_matmul_f8`: only Relu/Gelu can be fused; Swiglu is applied
            // after the matmul and therefore maps to no inner activation.
            let inner_act = match act {
                Some(CandleActivation::Relu) => Some(matmul::Activation::Relu),
                Some(CandleActivation::Gelu) => Some(matmul::Activation::Gelu),
                Some(CandleActivation::Swiglu) | None => None,
                _ => unreachable!("Unsupported activation in cublaslt matmul"),
            };
            let mut result = fused_batch_matmul(
                a,
                b,
                out,
                alpha,
                beta,
                bias,
                inner_act,
                self.cublaslt.clone(),
            )?;

            if Some(CandleActivation::Swiglu) == act {
                result = candle_nn::ops::swiglu(&result)?;
            }
            Ok(result)
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle_core::bail!("`cuda` feature is not enabled")
        }
    }
}
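
// Illustrative dispatch sketch (hypothetical names, not code from this crate):
// callers fetch the handle through the controller and fall back to a plain candle
// matmul when it is unavailable or inhibited.
//
//     let out = match CUBLASLT_CONTROLLER.get() {
//         Some(lt) => lt.batch_matmul(&a, &b, None, None, None, None, None)?,
//         None => a.matmul(&b)?, // placeholder fallback; operand layout may differ
//     };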