mistralrs_quant/cublaslt/mod.rs

// https://github.com/huggingface/text-embeddings-inference/blob/cc1c510e8d8af8447c01e6b14c417473cf2dfda9/backends/candle/src/layers/cublaslt.rs

#![allow(unused_variables, unused_imports, dead_code)]

use candle_core::{Device, Result, Tensor};
use candle_nn::Activation as CandleActivation;
use once_cell::sync::Lazy;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Mutex, Once};

/// Controller for the CUBLASLT handle and inhibition flag.
pub struct CublasLtController {
    handle: Mutex<Option<&'static CublasLtWrapper>>,
    inhibit: AtomicBool,
}

impl CublasLtController {
    /// Set whether to inhibit CUBLASLT usage.
    pub fn set_inhibit(&self, value: bool) {
        self.inhibit.store(value, Ordering::SeqCst);
    }

    /// Get the handle if not inhibited.
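    ///
    /// Minimal sketch of the inhibit interaction (not compiled here); assumes the
    /// wrapper has already been installed via `maybe_init_cublas_lt_wrapper`:
    ///
    /// ```ignore
    /// CUBLASLT_CONTROLLER.set_inhibit(true);
    /// assert!(CUBLASLT_CONTROLLER.get().is_none()); // inhibited: no handle returned
    /// CUBLASLT_CONTROLLER.set_inhibit(false);       // handle is visible again
    /// ```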
    pub fn get(&self) -> Option<&'static CublasLtWrapper> {
        let handle_opt = self.handle.lock().unwrap();
        if self.inhibit.load(Ordering::SeqCst) {
            None
        } else {
            *handle_opt
        }
    }
}

pub static CUBLASLT_CONTROLLER: Lazy<CublasLtController> = Lazy::new(|| CublasLtController {
    handle: Mutex::new(None),
    inhibit: AtomicBool::new(false),
});

#[cfg(feature = "cuda")]
mod api;
#[cfg(feature = "cuda")]
mod matmul;
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests;

#[cfg(feature = "cuda")]
pub use api::{fused_batch_matmul, fused_batch_matmul_f8, CublasLt};

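/// Output dtype selector for the F8 fused matmul: return the result as F8 or as BF16.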
pub enum F8MatmulOutType {
    F8,
    BF16,
}

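/// Initialize the global cuBLASLt wrapper for `device` exactly once and store it in
/// `CUBLASLT_CONTROLLER`. On non-CUDA devices, or when the `cuda` feature is disabled,
/// the handle is left unset.
///
/// Minimal usage sketch (not compiled here; assumes a CUDA build and GPU 0):
///
/// ```ignore
/// let device = Device::new_cuda(0)?;
/// maybe_init_cublas_lt_wrapper(device);
/// if let Some(wrapper) = CUBLASLT_CONTROLLER.get() {
///     // cuBLASLt fused matmuls are available on this device.
/// }
/// ```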
pub fn maybe_init_cublas_lt_wrapper(device: Device) {
    static INIT: Once = Once::new();

    INIT.call_once(|| {
        #[cfg(feature = "cuda")]
        {
            match device {
                Device::Cuda(_) => {
                    let wrapper = Box::new(CublasLtWrapper {
                        cublaslt: CublasLt::new(&device).unwrap(),
                    });
                    let wrapper_ptr = Box::leak(wrapper) as &'static CublasLtWrapper;

                    // Set the controller handle
                    let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
                    *handle_lock = Some(wrapper_ptr);
                }
                _ => {
                    let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
                    *handle_lock = None;
                }
            }
        }

        #[cfg(not(feature = "cuda"))]
        {
            let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
            *handle_lock = None;
        }
    });
}

#[derive(Debug, Clone)]
pub struct CublasLtWrapper {
    #[cfg(feature = "cuda")]
    pub cublaslt: CublasLt,
}

impl CublasLtWrapper {
    /// Fused batch matmul + add + Relu/Gelu activation using CublasLt for F8 dtypes.
    ///
    /// # Arguments
    ///
    /// * `a` - Input tensor of size BxMxK
    /// * `b` - Input tensor of size BxNxK
    /// * `dequant_a_scale` - F32 scalar tensor, used to dequantize `a`.
    /// * `dequant_b_scale` - F32 scalar tensor, used to dequantize `b`.
    /// * `quantize_scale` - F32 scalar tensor, used to requantize the output.
    /// * `out` - Optional output tensor of size BxNxM. If set and beta != 0, it is added to the result of A*B before `act`
    /// * `alpha` - Optional scaling factor for A*B
    /// * `beta` - Optional scaling factor for C
    /// * `bias` - Optional bias tensor of size M
    /// * `act` - Optional Gelu or Relu activation. If set, it is applied to the result
    /// * `out_dtype` - Whether the result is returned as F8 or BF16
    ///
    /// The resulting tensor has shape BxNxM
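    ///
    /// Minimal sketch of a call (not compiled here); assumes `wrapper` is a
    /// `&CublasLtWrapper` obtained from `CUBLASLT_CONTROLLER.get()`, `a` and `b` are
    /// F8 tensors on the same CUDA device, and the scales are scalar F32 tensors:
    ///
    /// ```ignore
    /// let out = wrapper.batch_matmul_f8(
    ///     &a,
    ///     &b,
    ///     &dequant_a_scale,
    ///     &dequant_b_scale,
    ///     &quantize_scale,
    ///     None,                         // no accumulator tensor
    ///     None,                         // default alpha
    ///     None,                         // default beta
    ///     Some(&bias),
    ///     Some(CandleActivation::Gelu), // Relu and Gelu fuse into the matmul epilogue
    ///     F8MatmulOutType::BF16,
    /// )?;
    /// ```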
    #[allow(clippy::too_many_arguments)]
    pub fn batch_matmul_f8(
        &self,
        a: &Tensor,
        b: &Tensor,
        dequant_a_scale: &Tensor,
        dequant_b_scale: &Tensor,
        quantize_scale: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<CandleActivation>,
        out_dtype: F8MatmulOutType,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            let inner_act = act.and_then(|a| match a {
                CandleActivation::Relu => Some(matmul::Activation::Relu),
                CandleActivation::Gelu => Some(matmul::Activation::Gelu),
                // Swiglu has no cuBLASLt epilogue; it is applied after the matmul below.
                CandleActivation::Swiglu => None,
                _ => unreachable!("Unsupported activation in cublaslt matmul"),
            });
            let mut result = fused_batch_matmul_f8(
                a,
                b,
                dequant_a_scale,
                dequant_b_scale,
                quantize_scale,
                out,
                alpha,
                beta,
                bias,
                inner_act,
                out_dtype,
                self.cublaslt.clone(),
            )?;

            if Some(CandleActivation::Swiglu) == act {
                result = candle_nn::ops::swiglu(&result)?;
            }
            Ok(result)
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle_core::bail!("`cuda` feature is not enabled")
        }
    }

    /// Fused batch matmul + add + Relu/Gelu activation using CublasLt.
    ///
    /// # Arguments
    ///
    /// * `a` - Input tensor of size BxMxK
    /// * `b` - Input tensor of size BxNxK
    /// * `out` - Optional output tensor of size BxNxM. If set and beta != 0, it is added to the result of A*B before `act`
    /// * `alpha` - Optional scaling factor for A*B
    /// * `beta` - Optional scaling factor for C
    /// * `bias` - Optional bias tensor of size M
    /// * `act` - Optional Gelu or Relu activation. If set, it is applied to the result
    ///
    /// The resulting tensor has shape BxNxM
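    ///
    /// Minimal sketch of a call (not compiled here); assumes `wrapper` is a
    /// `&CublasLtWrapper` obtained from `CUBLASLT_CONTROLLER.get()` and `a`, `b`,
    /// and `bias` live on the same CUDA device:
    ///
    /// ```ignore
    /// let out = wrapper.batch_matmul(
    ///     &a,
    ///     &b,
    ///     None,                         // no accumulator tensor
    ///     None,                         // default alpha
    ///     None,                         // default beta
    ///     Some(&bias),
    ///     Some(CandleActivation::Relu), // Relu and Gelu fuse into the matmul epilogue
    /// )?;
    /// ```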
    #[allow(clippy::too_many_arguments)]
    pub fn batch_matmul(
        &self,
        a: &Tensor,
        b: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<CandleActivation>,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            let inner_act = act.and_then(|a| match a {
                CandleActivation::Relu => Some(matmul::Activation::Relu),
                CandleActivation::Gelu => Some(matmul::Activation::Gelu),
                // Swiglu has no cuBLASLt epilogue; it is applied after the matmul below.
                CandleActivation::Swiglu => None,
                _ => unreachable!("Unsupported activation in cublaslt matmul"),
            });
            let mut result = fused_batch_matmul(
                a,
                b,
                out,
                alpha,
                beta,
                bias,
                inner_act,
                self.cublaslt.clone(),
            )?;

            if Some(CandleActivation::Swiglu) == act {
                result = candle_nn::ops::swiglu(&result)?;
            }
            Ok(result)
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle_core::bail!("`cuda` feature is not enabled")
        }
    }
}