mistralrs_quant/cublaslt/mod.rs

// https://github.com/huggingface/text-embeddings-inference/blob/cc1c510e8d8af8447c01e6b14c417473cf2dfda9/backends/candle/src/layers/cublaslt.rs

#![allow(unused_variables, unused_imports, dead_code)]

use candle_core::{Device, Result, Tensor};
use candle_nn::Activation as CandleActivation;
use once_cell::sync::Lazy;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Mutex, Once};

/// Controller for the CUBLASLT handle and inhibition flag.
pub struct CublasLtController {
    handle: Mutex<Option<&'static CublasLtWrapper>>,
    inhibit: AtomicBool,
}

impl CublasLtController {
    /// Set whether to inhibit CUBLASLT usage.
    pub fn set_inhibit(&self, value: bool) {
        self.inhibit.store(value, Ordering::SeqCst);
    }

    /// Get the handle if not inhibited.
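    ///
    /// # Example
    ///
    /// A minimal usage sketch (not compiled as a doctest); it assumes
    /// [`maybe_init_cublas_lt_wrapper`] has already run for a CUDA device so a
    /// handle has been published to the controller.
    ///
    /// ```ignore
    /// CUBLASLT_CONTROLLER.set_inhibit(true);
    /// assert!(CUBLASLT_CONTROLLER.get().is_none()); // inhibited: always None
    /// CUBLASLT_CONTROLLER.set_inhibit(false);
    /// let handle = CUBLASLT_CONTROLLER.get(); // Some(_) on CUDA, None otherwise
    /// ```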
    pub fn get(&self) -> Option<&'static CublasLtWrapper> {
        let handle_opt = self.handle.lock().unwrap();
        if self.inhibit.load(Ordering::SeqCst) {
            None
        } else {
            *handle_opt
        }
    }
}

pub static CUBLASLT_CONTROLLER: Lazy<CublasLtController> = Lazy::new(|| CublasLtController {
    handle: Mutex::new(None),
    inhibit: AtomicBool::new(false),
});

#[cfg(feature = "cuda")]
mod api;
#[cfg(feature = "cuda")]
mod matmul;
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests;

#[cfg(feature = "cuda")]
pub use api::{fused_batch_matmul, fused_batch_matmul_f8, CublasLt};

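/// Output dtype selector for [`CublasLtWrapper::batch_matmul_f8`]: keep the result
/// in F8 or widen it to BF16.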
pub enum F8MatmulOutType {
    F8,
    BF16,
}

static INIT: Once = Once::new();
static mut CUBLASLT: Option<CublasLtWrapper> = None;

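/// Initialize the process-wide cuBLASLt wrapper at most once and publish it to
/// [`CUBLASLT_CONTROLLER`]. A handle is only created when `device` is a CUDA
/// device and the `cuda` feature is enabled; otherwise the controller keeps `None`.
///
/// # Example
///
/// A minimal sketch (not compiled as a doctest); it assumes the `cuda` feature is
/// enabled and CUDA device 0 exists.
///
/// ```ignore
/// let device = Device::new_cuda(0)?;
/// maybe_init_cublas_lt_wrapper(device.clone());
/// if let Some(wrapper) = CUBLASLT_CONTROLLER.get() {
///     // wrapper.batch_matmul(..) / wrapper.batch_matmul_f8(..) are now available.
/// }
/// ```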
pub fn maybe_init_cublas_lt_wrapper(device: Device) {
    unsafe {
        INIT.call_once(|| {
            #[cfg(not(feature = "cuda"))]
            {
                CUBLASLT = None;
            }

            #[cfg(feature = "cuda")]
            {
                // Only create a cuBLASLt handle when the provided device is a
                // CUDA device; any other device leaves the global handle unset.
                use candle_core::cuda_backend::cudarc::driver;
                CUBLASLT = match device {
                    Device::Cuda(_) => Some(CublasLtWrapper {
                        cublaslt: CublasLt::new(&device).unwrap(),
                    }),
                    _ => None,
                }
            }
            #[allow(static_mut_refs)]
            let cublaslt: Option<&'static CublasLtWrapper> = CUBLASLT.as_ref();

            // Set the controller handle
            let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
            *handle_lock = cublaslt;
        });
    }
}

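/// Wrapper owning the cuBLASLt handle used by the fused matmul entry points below.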
#[derive(Debug, Clone)]
pub struct CublasLtWrapper {
    #[cfg(feature = "cuda")]
    pub cublaslt: CublasLt,
}

impl CublasLtWrapper {
    /// Fused batch matmul + add + Relu/Gelu activation using CublasLt for F8 dtypes.
    ///
    /// # Arguments
    ///
    /// * `a` - Input tensor of size BxMxK
    /// * `b` - Input tensor of size BxNxK
    /// * `dequant_a_scale` - F32 scalar tensor, used to dequantize `a`
    /// * `dequant_b_scale` - F32 scalar tensor, used to dequantize `b`
    /// * `quantize_scale` - F32 scalar tensor, used to requantize the result
    /// * `out` - Optional output tensor of size BxNxM. If set and `beta` != 0, it is added to the result of A*B before `act`
    /// * `alpha` - Optional scaling factor for A*B
    /// * `beta` - Optional scaling factor for C
    /// * `bias` - Optional bias tensor of size M
    /// * `act` - Optional Gelu or Relu activation, applied to the result if set
    ///
    /// The resulting tensor is of shape BxNxM.
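    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest); shapes, dtypes, and the
    /// `dev`/`wrapper` bindings are illustrative, with `wrapper` assumed to come
    /// from [`CUBLASLT_CONTROLLER`].
    ///
    /// ```ignore
    /// let (b, m, n, k) = (1, 4, 8, 16);
    /// let a = Tensor::randn(0f32, 1f32, (b, m, k), &dev)?.to_dtype(DType::F8E4M3)?;
    /// let bt = Tensor::randn(0f32, 1f32, (b, n, k), &dev)?.to_dtype(DType::F8E4M3)?;
    /// let scale = Tensor::new(1f32, &dev)?;
    /// let out = wrapper.batch_matmul_f8(
    ///     &a, &bt, &scale, &scale, &scale,
    ///     None, None, None, None, None,
    ///     F8MatmulOutType::BF16,
    /// )?;
    /// // out has shape (b, n, m) in BF16.
    /// ```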
    #[allow(clippy::too_many_arguments)]
    pub fn batch_matmul_f8(
        &self,
        a: &Tensor,
        b: &Tensor,
        dequant_a_scale: &Tensor,
        dequant_b_scale: &Tensor,
        quantize_scale: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<CandleActivation>,
        out_dtype: F8MatmulOutType,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            // cuBLASLt only fuses Relu/Gelu; Swiglu is applied after the matmul below,
            // so it must not be forwarded to the fused kernel.
            let inner_act = act.and_then(|a| match a {
                CandleActivation::Relu => Some(matmul::Activation::Relu),
                CandleActivation::Gelu => Some(matmul::Activation::Gelu),
                CandleActivation::Swiglu => None,
                _ => unreachable!("Unsupported activation in cublaslt matmul"),
            });
            let mut result = fused_batch_matmul_f8(
                a,
                b,
                dequant_a_scale,
                dequant_b_scale,
                quantize_scale,
                out,
                alpha,
                beta,
                bias,
                inner_act,
                out_dtype,
                self.cublaslt.clone(),
            )?;

            if Some(CandleActivation::Swiglu) == act {
                result = candle_nn::ops::swiglu(&result)?;
            }
            Ok(result)
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle_core::bail!("`cuda` feature is not enabled")
        }
    }

    /// Fused batch matmul + add + Relu/Gelu activation using CublasLt.
    ///
    /// # Arguments
    ///
    /// * `a` - Input tensor of size BxMxK
    /// * `b` - Input tensor of size BxNxK
    /// * `out` - Optional output tensor of size BxNxM. If set and `beta` != 0, it is added to the result of A*B before `act`
    /// * `alpha` - Optional scaling factor for A*B
    /// * `beta` - Optional scaling factor for C
    /// * `bias` - Optional bias tensor of size M
    /// * `act` - Optional Gelu or Relu activation, applied to the result if set
    ///
    /// The resulting tensor is of shape BxNxM.
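    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest); the `dev`/`wrapper` bindings
    /// are illustrative, with `wrapper` assumed to come from [`CUBLASLT_CONTROLLER`].
    ///
    /// ```ignore
    /// let (b, m, n, k) = (1, 4, 8, 16);
    /// let a = Tensor::randn(0f32, 1f32, (b, m, k), &dev)?.to_dtype(DType::BF16)?;
    /// let bt = Tensor::randn(0f32, 1f32, (b, n, k), &dev)?.to_dtype(DType::BF16)?;
    /// let out = wrapper.batch_matmul(
    ///     &a, &bt, None, None, None, None,
    ///     Some(CandleActivation::Gelu),
    /// )?;
    /// // out has shape (b, n, m).
    /// ```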
    #[allow(clippy::too_many_arguments)]
    pub fn batch_matmul(
        &self,
        a: &Tensor,
        b: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<CandleActivation>,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            // As above: Swiglu is handled after the matmul, not by the fused kernel.
            let inner_act = act.and_then(|a| match a {
                CandleActivation::Relu => Some(matmul::Activation::Relu),
                CandleActivation::Gelu => Some(matmul::Activation::Gelu),
                CandleActivation::Swiglu => None,
                _ => unreachable!("Unsupported activation in cublaslt matmul"),
            });
            let mut result = fused_batch_matmul(
                a,
                b,
                out,
                alpha,
                beta,
                bias,
                inner_act,
                self.cublaslt.clone(),
            )?;

            if Some(CandleActivation::Swiglu) == act {
                result = candle_nn::ops::swiglu(&result)?;
            }
            Ok(result)
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle_core::bail!("`cuda` feature is not enabled")
        }
    }
}