mistralrs_quant/cublaslt/mod.rs

#![allow(unused_variables, unused_imports, dead_code)]

use candle_core::{Device, Result, Tensor};
use candle_nn::Activation as CandleActivation;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{LazyLock, Mutex, Once};

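/// Controls access to the global cuBLASLt handle. `handle` stores a leaked
/// `&'static CublasLtWrapper` once initialized; `inhibit` lets callers
/// temporarily force `get` to return `None`.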
pub struct CublasLtController {
    handle: Mutex<Option<&'static CublasLtWrapper>>,
    inhibit: AtomicBool,
}

impl CublasLtController {
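    /// Set or clear the inhibit flag; while set, `get` returns `None`.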
    pub fn set_inhibit(&self, value: bool) {
        self.inhibit.store(value, Ordering::SeqCst);
    }

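    /// Return the global cuBLASLt wrapper, or `None` if it has not been
    /// initialized or the inhibit flag is currently set.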
    pub fn get(&self) -> Option<&'static CublasLtWrapper> {
        if self.inhibit.load(Ordering::SeqCst) {
            return None;
        }
        let handle_opt = self.handle.lock().unwrap();
        *handle_opt
    }
}

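/// Process-wide controller. The handle starts out empty and is populated by
/// `maybe_init_cublas_lt_wrapper` when the `cuda` feature is enabled and a
/// CUDA device is provided.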
pub static CUBLASLT_CONTROLLER: LazyLock<CublasLtController> =
    LazyLock::new(|| CublasLtController {
        handle: Mutex::new(None),
        inhibit: AtomicBool::new(false),
    });

#[cfg(feature = "cuda")]
mod api;
#[cfg(feature = "cuda")]
mod matmul;
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests;

#[cfg(feature = "cuda")]
pub use api::{fused_batch_matmul, fused_batch_matmul_f8, CublasLt};

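/// Initialize the global cuBLASLt wrapper exactly once. On a CUDA device
/// (with the `cuda` feature enabled) this creates a `CublasLt` handle, leaks
/// it to obtain a `'static` reference, and stores it in `CUBLASLT_CONTROLLER`;
/// otherwise the handle is left as `None`.
///
/// A rough usage sketch (illustrative only; it assumes the `cuda` feature is
/// built, a CUDA device is available, and the tensor shapes are arbitrary):
///
/// ```ignore
/// use candle_core::{DType, Device, Tensor};
///
/// let device = Device::new_cuda(0)?;
/// // Call once, early in startup; later calls are no-ops.
/// maybe_init_cublas_lt_wrapper(device.clone());
///
/// // `get` returns `None` on non-CUDA builds or while inhibited.
/// if let Some(wrapper) = CUBLASLT_CONTROLLER.get() {
///     let a = Tensor::zeros((1, 4, 8), DType::F32, &device)?;
///     let b = Tensor::zeros((1, 4, 8), DType::F32, &device)?;
///     let c = wrapper.batch_matmul(&a, &b, None, None, None, None, None)?;
/// }
/// ```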
pub fn maybe_init_cublas_lt_wrapper(device: Device) {
    static INIT: Once = Once::new();

    INIT.call_once(|| {
        #[cfg(feature = "cuda")]
        {
            match device {
                Device::Cuda(_) => {
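                    // Leak the wrapper so it can be stored as a `'static`
                    // reference for the lifetime of the process.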
                    let wrapper = Box::new(CublasLtWrapper {
                        cublaslt: CublasLt::new(&device).unwrap(),
                    });
                    let wrapper_ptr = Box::leak(wrapper) as &'static CublasLtWrapper;

                    let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
                    *handle_lock = Some(wrapper_ptr);
                }
                _ => {
                    let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
                    *handle_lock = None;
                }
            }
        }

        #[cfg(not(feature = "cuda"))]
        {
            let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
            *handle_lock = None;
        }
    });
}

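/// Thin wrapper around the `CublasLt` handle; the inner handle only exists
/// when the `cuda` feature is enabled.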
#[derive(Debug, Clone)]
pub struct CublasLtWrapper {
    #[cfg(feature = "cuda")]
    pub cublaslt: CublasLt,
}

impl CublasLtWrapper {
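    /// FP8 batched matmul via cuBLASLt with optional output accumulation,
    /// bias, and activation. `Relu`/`Gelu` are passed to the fused kernel,
    /// `Swiglu` is applied to the result afterwards, and any other activation
    /// panics.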
    #[allow(clippy::too_many_arguments)]
    pub fn batch_matmul_f8(
        &self,
        a: &Tensor,
        b: &Tensor,
        dequant_a_scale: &Tensor,
        dequant_b_scale: &Tensor,
        quantize_scale: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<CandleActivation>,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            let inner_act = act.map(|a| match a {
                CandleActivation::Relu => matmul::Activation::Relu,
                CandleActivation::Gelu => matmul::Activation::Gelu,
                _ => unreachable!("Unsupported activation in cublaslt matmul"),
            });
            let mut result = fused_batch_matmul_f8(
                a,
                b,
                dequant_a_scale,
                dequant_b_scale,
                quantize_scale,
                out,
                alpha,
                beta,
                bias,
                inner_act,
                self.cublaslt.clone(),
            )?;

            if Some(CandleActivation::Swiglu) == act {
                result = candle_nn::ops::swiglu(&result)?;
            }
            Ok(result)
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle_core::bail!("`cuda` feature is not enabled")
        }
    }

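    /// Batched matmul via cuBLASLt with optional output accumulation, bias,
    /// and activation; activation handling matches `batch_matmul_f8`.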
    #[allow(clippy::too_many_arguments)]
    pub fn batch_matmul(
        &self,
        a: &Tensor,
        b: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<CandleActivation>,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            let inner_act = act.map(|a| match a {
                CandleActivation::Relu => matmul::Activation::Relu,
                CandleActivation::Gelu => matmul::Activation::Gelu,
                _ => unreachable!("Unsupported activation in cublaslt matmul"),
            });
            let mut result = fused_batch_matmul(
                a,
                b,
                out,
                alpha,
                beta,
                bias,
                inner_act,
                self.cublaslt.clone(),
            )?;

            if Some(CandleActivation::Swiglu) == act {
                result = candle_nn::ops::swiglu(&result)?;
            }
            Ok(result)
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle_core::bail!("`cuda` feature is not enabled")
        }
    }
}