mistralrs_quant/cublaslt/mod.rs

//! Lazily initialized global cuBLASLt handle plus fused batch-matmul wrappers,
//! all gated on the `cuda` feature.

#![allow(unused_variables, unused_imports, dead_code)]

use candle_core::{Device, Result, Tensor};
use candle_nn::Activation as CandleActivation;
use once_cell::sync::Lazy;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Mutex, Once};

/// Controller that owns the global cuBLASLt handle.
///
/// `handle` stores a `'static` reference to the wrapper once it has been
/// initialized; `inhibit` lets callers temporarily disable the fused path.
pub struct CublasLtController {
    handle: Mutex<Option<&'static CublasLtWrapper>>,
    inhibit: AtomicBool,
}

impl CublasLtController {
    /// Enable or disable use of the handle. While inhibited, `get` returns `None`.
    pub fn set_inhibit(&self, value: bool) {
        self.inhibit.store(value, Ordering::SeqCst);
    }

    /// Returns the handle if it has been initialized and is not inhibited.
    pub fn get(&self) -> Option<&'static CublasLtWrapper> {
        let handle_opt = self.handle.lock().unwrap();
        if self.inhibit.load(Ordering::SeqCst) {
            None
        } else {
            *handle_opt
        }
    }
}

/// Process-wide controller; populated by [`maybe_init_cublas_lt_wrapper`].
pub static CUBLASLT_CONTROLLER: Lazy<CublasLtController> = Lazy::new(|| CublasLtController {
    handle: Mutex::new(None),
    inhibit: AtomicBool::new(false),
});
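
// Illustrative sketch (not part of the original module): `set_inhibit(true)`
// makes `get()` return `None` even after a handle has been installed, which
// lets a caller temporarily force the non-fused code path. The function name
// is hypothetical and exists only to show the intended call pattern.
fn _example_inhibit_scope() {
    CUBLASLT_CONTROLLER.set_inhibit(true);
    debug_assert!(CUBLASLT_CONTROLLER.get().is_none());
    CUBLASLT_CONTROLLER.set_inhibit(false);
}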

#[cfg(feature = "cuda")]
mod api;
#[cfg(feature = "cuda")]
mod matmul;
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests;

#[cfg(feature = "cuda")]
pub use api::{fused_batch_matmul, fused_batch_matmul_f8, CublasLt};

/// Output dtype selection for the FP8 fused matmul.
pub enum F8MatmulOutType {
    F8,
    BF16,
}

/// Initializes the global cuBLASLt handle exactly once for this process.
///
/// The handle is only created when `device` is a CUDA device and the `cuda`
/// feature is enabled; otherwise the controller is left empty.
pub fn maybe_init_cublas_lt_wrapper(device: Device) {
    static INIT: Once = Once::new();

    INIT.call_once(|| {
        #[cfg(feature = "cuda")]
        {
            match device {
                Device::Cuda(_) => {
                    // Leak the wrapper so it can be handed out as a `'static`
                    // reference for the lifetime of the process.
                    let wrapper = Box::new(CublasLtWrapper {
                        cublaslt: CublasLt::new(&device).unwrap(),
                    });
                    let wrapper_ptr = Box::leak(wrapper) as &'static CublasLtWrapper;

                    let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
                    *handle_lock = Some(wrapper_ptr);
                }
                _ => {
                    let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
                    *handle_lock = None;
                }
            }
        }

        #[cfg(not(feature = "cuda"))]
        {
            let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
            *handle_lock = None;
        }
    });
}
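
// Illustrative sketch (not part of the original module): typical start-up flow.
// Initialization is guarded by `Once`, so repeated calls are harmless; `get()`
// returns `Some(..)` only when the device was CUDA, the `cuda` feature is
// enabled, and the controller is not inhibited. The function name is hypothetical.
fn _example_init(device: &Device) {
    maybe_init_cublas_lt_wrapper(device.clone());
    let _maybe_handle = CUBLASLT_CONTROLLER.get();
}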

/// Wrapper around an initialized cuBLASLt handle (only present with `cuda`).
#[derive(Debug, Clone)]
pub struct CublasLtWrapper {
    #[cfg(feature = "cuda")]
    pub cublaslt: CublasLt,
}

impl CublasLtWrapper {
    /// FP8 fused batch matmul with dequantization scales, optional bias, and
    /// optional activation.
    ///
    /// Relu and Gelu are fused as cuBLASLt epilogues; Swiglu is applied to the
    /// result afterwards, since cuBLASLt has no such epilogue. Returns an error
    /// when the `cuda` feature is not enabled.
    #[allow(clippy::too_many_arguments)]
    pub fn batch_matmul_f8(
        &self,
        a: &Tensor,
        b: &Tensor,
        dequant_a_scale: &Tensor,
        dequant_b_scale: &Tensor,
        quantize_scale: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<CandleActivation>,
        out_dtype: F8MatmulOutType,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            // Swiglu cannot be fused into the matmul epilogue, so it maps to
            // `None` here and is applied to the result below.
            let inner_act = act.and_then(|a| match a {
                CandleActivation::Relu => Some(matmul::Activation::Relu),
                CandleActivation::Gelu => Some(matmul::Activation::Gelu),
                CandleActivation::Swiglu => None,
                _ => unreachable!("Unsupported activation in cublaslt matmul"),
            });
            let mut result = fused_batch_matmul_f8(
                a,
                b,
                dequant_a_scale,
                dequant_b_scale,
                quantize_scale,
                out,
                alpha,
                beta,
                bias,
                inner_act,
                out_dtype,
                self.cublaslt.clone(),
            )?;

            if Some(CandleActivation::Swiglu) == act {
                result = candle_nn::ops::swiglu(&result)?;
            }
            Ok(result)
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle_core::bail!("`cuda` feature is not enabled")
        }
    }

    /// Fused batch matmul with optional bias and optional activation.
    ///
    /// Relu and Gelu are fused as cuBLASLt epilogues; Swiglu is applied to the
    /// result afterwards. Returns an error when the `cuda` feature is not enabled.
    #[allow(clippy::too_many_arguments)]
    pub fn batch_matmul(
        &self,
        a: &Tensor,
        b: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<CandleActivation>,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            // Swiglu cannot be fused into the matmul epilogue, so it maps to
            // `None` here and is applied to the result below.
            let inner_act = act.and_then(|a| match a {
                CandleActivation::Relu => Some(matmul::Activation::Relu),
                CandleActivation::Gelu => Some(matmul::Activation::Gelu),
                CandleActivation::Swiglu => None,
                _ => unreachable!("Unsupported activation in cublaslt matmul"),
            });
            let mut result = fused_batch_matmul(
                a,
                b,
                out,
                alpha,
                beta,
                bias,
                inner_act,
                self.cublaslt.clone(),
            )?;

            if Some(CandleActivation::Swiglu) == act {
                result = candle_nn::ops::swiglu(&result)?;
            }
            Ok(result)
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle_core::bail!("`cuda` feature is not enabled")
        }
    }
}
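
// Illustrative sketch (not part of the original module): dispatch a matmul
// through the global handle when it is available. Only the call shape is shown
// (no bias, activation, or custom alpha/beta); the expected operand layout is
// whatever the `fused_batch_matmul` kernel in `api`/`matmul` requires. The
// function name is hypothetical.
fn _example_batch_matmul(a: &Tensor, b: &Tensor) -> Result<Tensor> {
    match CUBLASLT_CONTROLLER.get() {
        Some(wrapper) => wrapper.batch_matmul(a, b, None, None, None, None, None),
        None => candle_core::bail!("cuBLASLt handle is unavailable or inhibited"),
    }
}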