mistralrs_quant/cublaslt/mod.rs

// https://github.com/huggingface/text-embeddings-inference/blob/cc1c510e8d8af8447c01e6b14c417473cf2dfda9/backends/candle/src/layers/cublaslt.rs

#![allow(unused_variables, unused_imports, dead_code)]

use candle_core::{Device, Result, Tensor};
use candle_nn::Activation as CandleActivation;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{LazyLock, Mutex, Once};

/// Controller for the CUBLASLT handle and inhibition flag.
pub struct CublasLtController {
    handle: Mutex<Option<&'static CublasLtWrapper>>,
    inhibit: AtomicBool,
}

impl CublasLtController {
    /// Set whether to inhibit CUBLASLT usage.
    pub fn set_inhibit(&self, value: bool) {
        self.inhibit.store(value, Ordering::SeqCst);
    }

    /// Get the handle if not inhibited.
    ///
    /// Note: We check inhibit BEFORE reading handle_opt to avoid a TOCTOU race.
    /// If inhibit is checked after reading handle_opt, another thread could
    /// call set_inhibit(true) between reading handle_opt and checking inhibit,
    /// causing us to return a handle that should have been inhibited.
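    ///
    /// # Example
    ///
    /// A minimal usage sketch (not compiled as a doc-test):
    ///
    /// ```ignore
    /// // Temporarily inhibit cuBLASLt around a code path that must fall back
    /// // to the regular matmul kernels.
    /// CUBLASLT_CONTROLLER.set_inhibit(true);
    /// assert!(CUBLASLT_CONTROLLER.get().is_none());
    /// CUBLASLT_CONTROLLER.set_inhibit(false);
    /// // `get()` returns the handle again, provided `maybe_init_cublas_lt_wrapper`
    /// // was called with a CUDA device.
    /// ```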
    pub fn get(&self) -> Option<&'static CublasLtWrapper> {
        // Check inhibit first to prevent a TOCTOU race condition
        if self.inhibit.load(Ordering::SeqCst) {
            return None;
        }
        let handle_opt = self.handle.lock().unwrap();
        *handle_opt
    }
}

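/// Global, lazily-constructed [`CublasLtController`]. The handle is populated by
/// `maybe_init_cublas_lt_wrapper` on CUDA devices and stays `None` otherwise.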
pub static CUBLASLT_CONTROLLER: LazyLock<CublasLtController> =
    LazyLock::new(|| CublasLtController {
        handle: Mutex::new(None),
        inhibit: AtomicBool::new(false),
    });

#[cfg(feature = "cuda")]
mod api;
#[cfg(feature = "cuda")]
mod matmul;
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests;

#[cfg(feature = "cuda")]
pub use api::{fused_batch_matmul, fused_batch_matmul_f8, CublasLt};

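/// Initialize the global cuBLASLt wrapper for `device`.
///
/// Runs at most once per process: on a CUDA device it creates a `CublasLt`
/// handle and stores it in `CUBLASLT_CONTROLLER`; on other devices (or without
/// the `cuda` feature) the handle stays `None`. Later calls are no-ops.
///
/// # Example
///
/// A minimal sketch (not compiled as a doc-test):
///
/// ```ignore
/// let device = Device::new_cuda(0)?;
/// maybe_init_cublas_lt_wrapper(device);
/// let cublaslt_available = CUBLASLT_CONTROLLER.get().is_some();
/// ```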
pub fn maybe_init_cublas_lt_wrapper(device: Device) {
    static INIT: Once = Once::new();

    INIT.call_once(|| {
        #[cfg(feature = "cuda")]
        {
            match device {
                Device::Cuda(_) => {
                    let wrapper = Box::new(CublasLtWrapper {
                        cublaslt: CublasLt::new(&device).unwrap(),
                    });
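                    // Intentionally leak the wrapper so it can be handed out as a
                    // `&'static` reference for the lifetime of the process.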
                    let wrapper_ptr = Box::leak(wrapper) as &'static CublasLtWrapper;

                    // Set the controller handle
                    let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
                    *handle_lock = Some(wrapper_ptr);
                }
                _ => {
                    let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
                    *handle_lock = None;
                }
            }
        }

        #[cfg(not(feature = "cuda"))]
        {
            let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
            *handle_lock = None;
        }
    });
}

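/// Wrapper around the shared cuBLASLt handle, handed out by
/// [`CublasLtController::get`] once `maybe_init_cublas_lt_wrapper` has run.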
#[derive(Debug, Clone)]
pub struct CublasLtWrapper {
    #[cfg(feature = "cuda")]
    pub cublaslt: CublasLt,
}

impl CublasLtWrapper {
    /// Fused batch matmul + add + Relu/Gelu activation using CublasLt for F8 dtypes.
    ///
    /// # Arguments
    ///
    /// * `a` - Input tensor of size BxMxK
    /// * `b` - Input tensor of size BxNxK
    /// * `dequant_a_scale` - F32 scalar tensor, used to dequantize `a`.
    /// * `dequant_b_scale` - F32 scalar tensor, used to dequantize `b`.
    /// * `quantize_scale` - F32 scalar tensor, used to requantize the output.
    /// * `out` - Optional Output tensor of size BxNxK. If set and beta != 0, will be added to the end result of A*B before `act`
    /// * `alpha` - Optional scaling factor for A*B
    /// * `beta` - Optional scaling factor for C
    /// * `bias` - Optional bias tensor of size M
    /// * `act` - Optional Gelu or Relu activation. If set, will be applied to the end result
    ///
    /// The resulting tensor is of shape NxM
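    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doc-test), assuming `a` and `b` are
    /// F8E4M3 tensors of shapes (B, M, K) and (B, N, K) and the scale tensors
    /// are F32 scalars produced by the caller's quantization step:
    ///
    /// ```ignore
    /// let wrapper = CUBLASLT_CONTROLLER.get().expect("cuBLASLt not initialized");
    /// let result = wrapper.batch_matmul_f8(
    ///     &a,
    ///     &b,
    ///     &dequant_a_scale,
    ///     &dequant_b_scale,
    ///     &quantize_scale,
    ///     None,
    ///     None,
    ///     None,
    ///     None,
    ///     Some(CandleActivation::Gelu),
    /// )?;
    /// ```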
    #[allow(clippy::too_many_arguments)]
    pub fn batch_matmul_f8(
        &self,
        a: &Tensor,
        b: &Tensor,
        dequant_a_scale: &Tensor,
        dequant_b_scale: &Tensor,
        quantize_scale: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<CandleActivation>,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            // The cuBLASLt epilogue only supports Relu/Gelu; Swiglu is applied
            // after the matmul below, and other activations are rejected.
            let inner_act = match act {
                Some(CandleActivation::Relu) => Some(matmul::Activation::Relu),
                Some(CandleActivation::Gelu) => Some(matmul::Activation::Gelu),
                Some(CandleActivation::Swiglu) | None => None,
                Some(other) => candle_core::bail!("Unsupported activation {other:?} in cublaslt matmul"),
            };
            let mut result = fused_batch_matmul_f8(
                a,
                b,
                dequant_a_scale,
                dequant_b_scale,
                quantize_scale,
                out,
                alpha,
                beta,
                bias,
                inner_act,
                self.cublaslt.clone(),
            )?;

            if Some(CandleActivation::Swiglu) == act {
                result = candle_nn::ops::swiglu(&result)?;
            }
            Ok(result)
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle_core::bail!("`cuda` feature is not enabled")
        }
    }

    /// Fused batch matmul + add + Relu/Gelu activation using CublasLt.
    ///
    /// # Arguments
    ///
    /// * `a` - Input tensor of size BxMxK
    /// * `b` - Input tensor of size BxNxK
    /// * `out` - Optional Output tensor of size BxNxK. If set and beta != 0, will be added to the end result of A*B before `act`
    /// * `alpha` - Optional scaling factor for A*B
    /// * `beta` - Optional scaling factor for C
    /// * `bias` - Optional bias tensor of size M
    /// * `act` - Optional Gelu or Relu activation. If set, will be applied to the end result
    ///
    /// The resulting tensor is of shape NxM
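    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doc-test), assuming `a` and `b` are
    /// tensors of shapes (B, M, K) and (B, N, K) on the same CUDA device and
    /// `bias` is a bias tensor of size M:
    ///
    /// ```ignore
    /// let wrapper = CUBLASLT_CONTROLLER.get().expect("cuBLASLt not initialized");
    /// let result = wrapper.batch_matmul(&a, &b, None, None, None, Some(&bias), None)?;
    /// ```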
    #[allow(clippy::too_many_arguments)]
    pub fn batch_matmul(
        &self,
        a: &Tensor,
        b: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<CandleActivation>,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            // The cuBLASLt epilogue only supports Relu/Gelu; Swiglu is applied
            // after the matmul below, and other activations are rejected.
            let inner_act = match act {
                Some(CandleActivation::Relu) => Some(matmul::Activation::Relu),
                Some(CandleActivation::Gelu) => Some(matmul::Activation::Gelu),
                Some(CandleActivation::Swiglu) | None => None,
                Some(other) => candle_core::bail!("Unsupported activation {other:?} in cublaslt matmul"),
            };
            let mut result = fused_batch_matmul(
                a,
                b,
                out,
                alpha,
                beta,
                bias,
                inner_act,
                self.cublaslt.clone(),
            )?;

            if Some(CandleActivation::Swiglu) == act {
                result = candle_nn::ops::swiglu(&result)?;
            }
            Ok(result)
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle_core::bail!("`cuda` feature is not enabled")
        }
    }
}