mistralrs_quant/cublaslt/mod.rs

// https://github.com/huggingface/text-embeddings-inference/blob/cc1c510e8d8af8447c01e6b14c417473cf2dfda9/backends/candle/src/layers/cublaslt.rs

#![allow(unused_variables, unused_imports, dead_code)]

use candle_core::{Device, Result, Tensor};
use candle_nn::Activation as CandleActivation;
use once_cell::sync::Lazy;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Mutex, Once};

/// Controller for the CUBLASLT handle and inhibition flag.
pub struct CublasLtController {
    handle: Mutex<Option<&'static CublasLtWrapper>>,
    inhibit: AtomicBool,
}

impl CublasLtController {
    /// Set whether to inhibit CUBLASLT usage.
    pub fn set_inhibit(&self, value: bool) {
        self.inhibit.store(value, Ordering::SeqCst);
    }

    /// Get the handle if not inhibited.
    pub fn get(&self) -> Option<&'static CublasLtWrapper> {
        let handle_opt = self.handle.lock().unwrap();
        if self.inhibit.load(Ordering::SeqCst) {
            None
        } else {
            *handle_opt
        }
    }
}

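/// Global controller instance. `maybe_init_cublas_lt_wrapper` fills in the handle; callers may
/// temporarily disable cuBLASLt usage via `set_inhibit`.
///
/// A minimal usage sketch (illustrative only; it assumes the wrapper was already initialized on a
/// CUDA device):
///
/// ```ignore
/// // Force callers to fall back to the non-cublaslt path.
/// CUBLASLT_CONTROLLER.set_inhibit(true);
/// assert!(CUBLASLT_CONTROLLER.get().is_none());
///
/// // Re-enable and fetch the handle (`None` if it was never initialized).
/// CUBLASLT_CONTROLLER.set_inhibit(false);
/// let _maybe_handle = CUBLASLT_CONTROLLER.get();
/// ```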
pub static CUBLASLT_CONTROLLER: Lazy<CublasLtController> = Lazy::new(|| CublasLtController {
    handle: Mutex::new(None),
    inhibit: AtomicBool::new(false),
});

#[cfg(feature = "cuda")]
mod api;
#[cfg(feature = "cuda")]
mod matmul;
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests;

#[cfg(feature = "cuda")]
pub use api::{fused_batch_matmul, fused_batch_matmul_f8, CublasLt};

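/// Initialize the global cuBLASLt wrapper for `device` exactly once and register it with
/// `CUBLASLT_CONTROLLER`. On non-CUDA devices, or when the `cuda` feature is disabled, the handle
/// is left as `None`.
///
/// A usage sketch (the device construction below is illustrative):
///
/// ```ignore
/// let device = Device::new_cuda(0)?;
/// maybe_init_cublas_lt_wrapper(device.clone());
/// if let Some(wrapper) = CUBLASLT_CONTROLLER.get() {
///     // Fused matmuls are available through `wrapper`.
/// }
/// ```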
pub fn maybe_init_cublas_lt_wrapper(device: Device) {
    static INIT: Once = Once::new();

    INIT.call_once(|| {
        #[cfg(feature = "cuda")]
        {
            match device {
                Device::Cuda(_) => {
                    let wrapper = Box::new(CublasLtWrapper {
                        cublaslt: CublasLt::new(&device).unwrap(),
                    });
                    let wrapper_ptr = Box::leak(wrapper) as &'static CublasLtWrapper;

                    // Set the controller handle
                    let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
                    *handle_lock = Some(wrapper_ptr);
                }
                _ => {
                    let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
                    *handle_lock = None;
                }
            }
        }

        #[cfg(not(feature = "cuda"))]
        {
            let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
            *handle_lock = None;
        }
    });
}

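/// Wrapper around an initialized cuBLASLt handle. It is constructed once by
/// `maybe_init_cublas_lt_wrapper`, leaked to obtain a `'static` reference, and exposed through
/// `CUBLASLT_CONTROLLER`.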
#[derive(Debug, Clone)]
pub struct CublasLtWrapper {
    #[cfg(feature = "cuda")]
    pub cublaslt: CublasLt,
}

impl CublasLtWrapper {
    /// Fused batch matmul + add + Relu/Gelu activation using CublasLt for F8 dtypes.
    ///
    /// # Arguments
    ///
    /// * `a` - Input tensor of size BxMxK
    /// * `b` - Input tensor of size BxNxK
    /// * `dequant_a_scale` - F32 scalar tensor, used to dequantize `a`.
    /// * `dequant_b_scale` - F32 scalar tensor, used to dequantize `b`.
    /// * `quantize_scale` - F32 scalar tensor, used to requantize the result.
    /// * `out` - Optional output tensor of size BxNxM. If set and beta != 0, it is added to the result of A*B before `act`.
    /// * `alpha` - Optional scaling factor for A*B.
    /// * `beta` - Optional scaling factor for C.
    /// * `bias` - Optional bias tensor of size M.
    /// * `act` - Optional activation. Relu/Gelu are fused; Swiglu is applied after the matmul.
    ///
    /// The resulting tensor has shape BxNxM.
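    ///
    /// A minimal call sketch (illustrative; the tensors and their shapes are assumptions, not
    /// taken from an actual caller):
    ///
    /// ```ignore
    /// // a: (B, M, K), b: (B, N, K); the scale tensors are F32 scalars.
    /// let out = wrapper.batch_matmul_f8(
    ///     &a,
    ///     &b,
    ///     &dequant_a_scale,
    ///     &dequant_b_scale,
    ///     &quantize_scale,
    ///     None,                         // out
    ///     None,                         // alpha
    ///     None,                         // beta
    ///     Some(&bias),                  // bias of size M
    ///     Some(CandleActivation::Gelu), // fused activation
    /// )?;
    /// ```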
    #[allow(clippy::too_many_arguments)]
    pub fn batch_matmul_f8(
        &self,
        a: &Tensor,
        b: &Tensor,
        dequant_a_scale: &Tensor,
        dequant_b_scale: &Tensor,
        quantize_scale: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<CandleActivation>,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            // cublaslt only fuses Relu/Gelu; Swiglu is handled separately after the matmul below.
            let inner_act = act.and_then(|a| match a {
                CandleActivation::Relu => Some(matmul::Activation::Relu),
                CandleActivation::Gelu => Some(matmul::Activation::Gelu),
                CandleActivation::Swiglu => None,
                _ => unreachable!("Unsupported activation in cublaslt matmul"),
            });
            let mut result = fused_batch_matmul_f8(
                a,
                b,
                dequant_a_scale,
                dequant_b_scale,
                quantize_scale,
                out,
                alpha,
                beta,
                bias,
                inner_act,
                self.cublaslt.clone(),
            )?;

            if Some(CandleActivation::Swiglu) == act {
                result = candle_nn::ops::swiglu(&result)?;
            }
            Ok(result)
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle_core::bail!("`cuda` feature is not enabled")
        }
    }

    /// Fused batch matmul + add + Relu/Gelu activation using CublasLt.
    ///
    /// # Arguments
    ///
    /// * `a` - Input tensor of size BxMxK
    /// * `b` - Input tensor of size BxNxK
    /// * `out` - Optional output tensor of size BxNxM. If set and beta != 0, it is added to the result of A*B before `act`.
    /// * `alpha` - Optional scaling factor for A*B.
    /// * `beta` - Optional scaling factor for C.
    /// * `bias` - Optional bias tensor of size M.
    /// * `act` - Optional activation. Relu/Gelu are fused; Swiglu is applied after the matmul.
    ///
    /// The resulting tensor has shape BxNxM.
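    ///
    /// A minimal call sketch (illustrative; the tensors and their shapes are assumptions):
    ///
    /// ```ignore
    /// // a: (B, M, K), b: (B, N, K); result: (B, N, M).
    /// let out = wrapper.batch_matmul(
    ///     &a,
    ///     &b,
    ///     None,                         // out
    ///     None,                         // alpha
    ///     None,                         // beta
    ///     Some(&bias),                  // bias of size M
    ///     Some(CandleActivation::Relu), // fused activation
    /// )?;
    /// ```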
    #[allow(clippy::too_many_arguments)]
    pub fn batch_matmul(
        &self,
        a: &Tensor,
        b: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<CandleActivation>,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            // cublaslt only fuses Relu/Gelu; Swiglu is handled separately after the matmul below.
            let inner_act = act.and_then(|a| match a {
                CandleActivation::Relu => Some(matmul::Activation::Relu),
                CandleActivation::Gelu => Some(matmul::Activation::Gelu),
                CandleActivation::Swiglu => None,
                _ => unreachable!("Unsupported activation in cublaslt matmul"),
            });
            let mut result = fused_batch_matmul(
                a,
                b,
                out,
                alpha,
                beta,
                bias,
                inner_act,
                self.cublaslt.clone(),
            )?;

            if Some(CandleActivation::Swiglu) == act {
                result = candle_nn::ops::swiglu(&result)?;
            }
            Ok(result)
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle_core::bail!("`cuda` feature is not enabled")
        }
    }
}