mistralrs_quant/cublaslt/mod.rs

#![allow(unused_variables, unused_imports, dead_code)]

use candle_core::{Device, Result, Tensor};
use candle_nn::Activation as CandleActivation;
use once_cell::sync::Lazy;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Mutex, Once};

/// Gates access to the global cuBLASLt handle: `handle` holds the wrapper once
/// initialized, and `inhibit` lets callers temporarily disable the fused path.
pub struct CublasLtController {
    handle: Mutex<Option<&'static CublasLtWrapper>>,
    inhibit: AtomicBool,
}

impl CublasLtController {
    /// Enable or disable the cuBLASLt path. While inhibited, `get` returns `None`.
    pub fn set_inhibit(&self, value: bool) {
        self.inhibit.store(value, Ordering::SeqCst);
    }

    /// Returns the wrapper, or `None` if it is uninitialized or currently inhibited.
    pub fn get(&self) -> Option<&'static CublasLtWrapper> {
        let handle_opt = self.handle.lock().unwrap();
        if self.inhibit.load(Ordering::SeqCst) {
            None
        } else {
            *handle_opt
        }
    }
}

/// Global, lazily constructed controller; populated by `maybe_init_cublas_lt_wrapper`.
pub static CUBLASLT_CONTROLLER: Lazy<CublasLtController> = Lazy::new(|| CublasLtController {
    handle: Mutex::new(None),
    inhibit: AtomicBool::new(false),
});
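
// A minimal usage sketch (hypothetical caller, not part of the public API):
// `set_inhibit` forces `get` to return `None`, e.g. to fall back to unfused
// kernels while debugging the fused path.
fn _controller_usage_sketch() {
    CUBLASLT_CONTROLLER.set_inhibit(true);
    debug_assert!(CUBLASLT_CONTROLLER.get().is_none());
    CUBLASLT_CONTROLLER.set_inhibit(false);
    // `get` now yields the wrapper again, provided one was installed.
}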

#[cfg(feature = "cuda")]
mod api;
#[cfg(feature = "cuda")]
mod matmul;
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests;

#[cfg(feature = "cuda")]
pub use api::{fused_batch_matmul, fused_batch_matmul_f8, CublasLt};

/// Initializes the global cuBLASLt wrapper for `device`, exactly once per process.
/// On non-CUDA devices (or non-CUDA builds) the handle is left as `None`.
pub fn maybe_init_cublas_lt_wrapper(device: Device) {
    static INIT: Once = Once::new();

    INIT.call_once(|| {
        #[cfg(feature = "cuda")]
        {
            match device {
                Device::Cuda(_) => {
                    // Leak the wrapper to obtain a `&'static` reference; the
                    // handle lives for the remainder of the process.
                    let wrapper = Box::new(CublasLtWrapper {
                        cublaslt: CublasLt::new(&device).unwrap(),
                    });
                    let wrapper_ptr = Box::leak(wrapper) as &'static CublasLtWrapper;

                    let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
                    *handle_lock = Some(wrapper_ptr);
                }
                _ => {
                    let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
                    *handle_lock = None;
                }
            }
        }

        #[cfg(not(feature = "cuda"))]
        {
            let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
            *handle_lock = None;
        }
    });
}
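
// A minimal initialization sketch (hypothetical caller; assumes a CUDA build
// and that device ordinal 0 exists):
fn _init_usage_sketch() -> Result<()> {
    let device = Device::new_cuda(0)?;
    maybe_init_cublas_lt_wrapper(device);
    // After initialization, the wrapper is fetched through the controller;
    // `None` here means no CUDA device was seen (or the path is inhibited).
    let _wrapper = CUBLASLT_CONTROLLER.get();
    Ok(())
}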

/// Thin wrapper owning the cuBLASLt handle used by the fused matmul kernels.
#[derive(Debug, Clone)]
pub struct CublasLtWrapper {
    #[cfg(feature = "cuda")]
    pub cublaslt: CublasLt,
}

impl CublasLtWrapper {
    /// Fused batched matmul + bias + activation for FP8 inputs, dispatched
    /// through cuBLASLt.
    ///
    /// * `a`, `b` - input tensors for the batched matmul.
    /// * `dequant_a_scale`, `dequant_b_scale` - F32 scales used to dequantize `a` and `b`.
    /// * `quantize_scale` - F32 scale used to requantize the output.
    /// * `out` - optional output tensor; with a nonzero `beta` it is accumulated into the result.
    /// * `alpha`, `beta` - optional GEMM scaling factors.
    /// * `bias` - optional bias applied by the epilogue.
    /// * `act` - optional activation: `Relu` or `Gelu` run fused; `Swiglu` runs after the matmul.
    #[allow(clippy::too_many_arguments)]
    pub fn batch_matmul_f8(
        &self,
        a: &Tensor,
        b: &Tensor,
        dequant_a_scale: &Tensor,
        dequant_b_scale: &Tensor,
        quantize_scale: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<CandleActivation>,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            // Map to the cuBLASLt epilogue where one exists; Swiglu has no
            // epilogue, so it is applied after the matmul instead of
            // panicking in the catch-all arm.
            let inner_act = act.and_then(|a| match a {
                CandleActivation::Relu => Some(matmul::Activation::Relu),
                CandleActivation::Gelu => Some(matmul::Activation::Gelu),
                CandleActivation::Swiglu => None,
                _ => unreachable!("Unsupported activation in cublaslt matmul"),
            });
            let mut result = fused_batch_matmul_f8(
                a,
                b,
                dequant_a_scale,
                dequant_b_scale,
                quantize_scale,
                out,
                alpha,
                beta,
                bias,
                inner_act,
                self.cublaslt.clone(),
            )?;

            if Some(CandleActivation::Swiglu) == act {
                result = candle_nn::ops::swiglu(&result)?;
            }
            Ok(result)
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle_core::bail!("`cuda` feature is not enabled")
        }
    }

    /// Fused batched matmul + bias + activation using cuBLASLt.
    ///
    /// * `a`, `b` - input tensors for the batched matmul.
    /// * `out` - optional output tensor; with a nonzero `beta` it is accumulated into the result.
    /// * `alpha`, `beta` - optional GEMM scaling factors.
    /// * `bias` - optional bias applied by the epilogue.
    /// * `act` - optional activation: `Relu` or `Gelu` run fused; `Swiglu` runs after the matmul.
    #[allow(clippy::too_many_arguments)]
    pub fn batch_matmul(
        &self,
        a: &Tensor,
        b: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<CandleActivation>,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            // Same epilogue mapping as the FP8 path: Swiglu is applied after
            // the matmul since cuBLASLt has no fused Swiglu epilogue.
            let inner_act = act.and_then(|a| match a {
                CandleActivation::Relu => Some(matmul::Activation::Relu),
                CandleActivation::Gelu => Some(matmul::Activation::Gelu),
                CandleActivation::Swiglu => None,
                _ => unreachable!("Unsupported activation in cublaslt matmul"),
            });
            let mut result = fused_batch_matmul(
                a,
                b,
                out,
                alpha,
                beta,
                bias,
                inner_act,
                self.cublaslt.clone(),
            )?;

            if Some(CandleActivation::Swiglu) == act {
                result = candle_nn::ops::swiglu(&result)?;
            }
            Ok(result)
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle_core::bail!("`cuda` feature is not enabled")
        }
    }
}
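
// A minimal end-to-end sketch (hypothetical; assumes the `cuda` feature, an
// available CUDA device, and that both operands share the same trailing
// contraction dimension, here K = 8):
fn _batch_matmul_usage_sketch() -> Result<()> {
    let device = Device::new_cuda(0)?;
    maybe_init_cublas_lt_wrapper(device.clone());
    if let Some(wrapper) = CUBLASLT_CONTROLLER.get() {
        let a = Tensor::randn(0f32, 1f32, (2, 4, 8), &device)?;
        let b = Tensor::randn(0f32, 1f32, (2, 4, 8), &device)?;
        // No pre-allocated output, scaling factors, or bias; fuse a ReLU epilogue.
        let _c = wrapper.batch_matmul(&a, &b, None, None, None, None, Some(CandleActivation::Relu))?;
    }
    Ok(())
}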