mistralrs_quant/cublaslt/mod.rs

// https://github.com/huggingface/text-embeddings-inference/blob/cc1c510e8d8af8447c01e6b14c417473cf2dfda9/backends/candle/src/layers/cublaslt.rs

#![allow(unused_variables, unused_imports, dead_code)]

use candle_core::{Device, Result, Tensor};
use candle_nn::Activation as CandleActivation;
use once_cell::sync::Lazy;
use std::sync::{Mutex, Once};

#[cfg(feature = "cuda")]
mod api;
#[cfg(feature = "cuda")]
mod matmul;
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests;

#[cfg(feature = "cuda")]
pub use api::{fused_batch_matmul, fused_batch_matmul_f8, CublasLt};

pub enum F8MatmulOutType {
    F8,
    BF16,
}

static INIT: Once = Once::new();
static mut CUBLASLT: Option<CublasLtWrapper> = None;
pub static CUBLASLT_HANDLE: Lazy<Mutex<Option<&'static CublasLtWrapper>>> =
    Lazy::new(|| Mutex::new(None));

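/// One-time, thread-safe initialization of the global cuBLASLt handle for `device`.
/// On non-CUDA builds, or when `device` is not a CUDA device, `CUBLASLT_HANDLE` is
/// left as `None`.
///
/// Illustrative usage sketch (not from the original source); assumes the `cuda`
/// feature is enabled and a CUDA device is available at ordinal 0:
///
/// ```ignore
/// use candle_core::Device;
///
/// let device = Device::new_cuda(0)?;
/// maybe_init_cublas_lt_wrapper(device);
///
/// // Later, consumers grab the shared wrapper (if any) from the global handle.
/// if let Some(wrapper) = *CUBLASLT_HANDLE.lock().unwrap() {
///     // wrapper.batch_matmul(...) / wrapper.batch_matmul_f8(...)
/// }
/// ```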
pub fn maybe_init_cublas_lt_wrapper(device: Device) {
    unsafe {
        INIT.call_once(|| {
            #[cfg(not(feature = "cuda"))]
            {
                CUBLASLT = None;
            }

            #[cfg(feature = "cuda")]
            {
                // Check if we can call the driver
                // Then check if we can create a device
                // Then check that the device is CUDA
                use candle_core::cuda_backend::cudarc::driver;
                CUBLASLT = match device {
                    Device::Cuda(_) => Some(CublasLtWrapper {
                        cublaslt: CublasLt::new(&device).unwrap(),
                    }),
                    _ => None,
                }
            }
            #[allow(static_mut_refs)]
            let cublaslt: Option<&'static CublasLtWrapper> = CUBLASLT.as_ref();
            *CUBLASLT_HANDLE.lock().unwrap() = cublaslt;
        });
    }
}

#[derive(Debug, Clone)]
pub struct CublasLtWrapper {
    #[cfg(feature = "cuda")]
    pub cublaslt: CublasLt,
}

impl CublasLtWrapper {
    /// Fused batch matmul + add + Relu/Gelu activation using CublasLt for F8 dtypes.
    ///
    /// # Arguments
    ///
    /// * `a` - Input tensor of size BxMxK
    /// * `b` - Input tensor of size BxNxK
    /// * `dequant_a_scale` - F32 scalar tensor, used to dequantize `a`.
    /// * `dequant_b_scale` - F32 scalar tensor, used to dequantize `b`.
    /// * `quantize_scale` - F32 scalar tensor, used to requantize the result.
    /// * `out` - Optional output tensor of size BxNxM. If set and beta != 0, it is added to the result of A*B before `act`.
    /// * `alpha` - Optional scaling factor for A*B
    /// * `beta` - Optional scaling factor for C
    /// * `bias` - Optional bias tensor of size M
    /// * `act` - Optional Gelu or Relu activation, applied to the result if set
    /// * `out_dtype` - Output dtype: `F8MatmulOutType::F8` or `F8MatmulOutType::BF16`
    ///
    /// The resulting tensor is of shape BxNxM.
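    ///
    /// # Example
    ///
    /// A minimal sketch (not from the original source), assuming the `cuda` feature is
    /// enabled, `wrapper` was obtained from `CUBLASLT_HANDLE`, and your candle build
    /// can allocate `DType::F8E4M3` tensors on the CUDA device:
    ///
    /// ```ignore
    /// use candle_core::{DType, Device, Tensor};
    ///
    /// let device = Device::new_cuda(0)?;
    /// // a: (B, M, K), b: (B, N, K), both already quantized to FP8 (E4M3).
    /// let a = Tensor::zeros((1, 4, 8), DType::F8E4M3, &device)?;
    /// let b = Tensor::zeros((1, 16, 8), DType::F8E4M3, &device)?;
    /// // Per-tensor scales as F32 scalars (1.0 here as placeholders).
    /// let scale = Tensor::new(1.0f32, &device)?;
    ///
    /// let out = wrapper.batch_matmul_f8(
    ///     &a,
    ///     &b,
    ///     &scale,            // dequant_a_scale
    ///     &scale,            // dequant_b_scale
    ///     &scale,            // quantize_scale
    ///     None,              // out
    ///     None,              // alpha
    ///     None,              // beta
    ///     None,              // bias
    ///     None,              // act
    ///     F8MatmulOutType::BF16,
    /// )?;
    /// assert_eq!(out.dims(), &[1, 16, 4]); // (B, N, M)
    /// ```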
    #[allow(clippy::too_many_arguments)]
    pub fn batch_matmul_f8(
        &self,
        a: &Tensor,
        b: &Tensor,
        dequant_a_scale: &Tensor,
        dequant_b_scale: &Tensor,
        quantize_scale: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<CandleActivation>,
        out_dtype: F8MatmulOutType,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            // Only Relu/Gelu can be fused into the cuBLASLt epilogue; Swiglu is
            // handled separately after the matmul.
            let inner_act = act.and_then(|a| match a {
                CandleActivation::Relu => Some(matmul::Activation::Relu),
                CandleActivation::Gelu => Some(matmul::Activation::Gelu),
                CandleActivation::Swiglu => None,
                _ => unreachable!("Unsupported activation in cublaslt matmul"),
            });
            let mut result = fused_batch_matmul_f8(
                a,
                b,
                dequant_a_scale,
                dequant_b_scale,
                quantize_scale,
                out,
                alpha,
                beta,
                bias,
                inner_act,
                out_dtype,
                self.cublaslt.clone(),
            )?;

            if Some(CandleActivation::Swiglu) == act {
                result = candle_nn::ops::swiglu(&result)?;
            }
            Ok(result)
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle_core::bail!("`cuda` feature is not enabled")
        }
    }

    /// Fused batch matmul + add + Relu/Gelu activation using CublasLt.
    ///
    /// # Arguments
    ///
    /// * `a` - Input tensor of size BxMxK
    /// * `b` - Input tensor of size BxNxK
    /// * `out` - Optional output tensor of size BxNxM. If set and beta != 0, it is added to the result of A*B before `act`.
    /// * `alpha` - Optional scaling factor for A*B
    /// * `beta` - Optional scaling factor for C
    /// * `bias` - Optional bias tensor of size M
    /// * `act` - Optional Gelu or Relu activation, applied to the result if set
    ///
    /// The resulting tensor is of shape BxNxM.
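    ///
    /// # Example
    ///
    /// A minimal sketch (not from the original source), assuming the `cuda` feature is
    /// enabled and `wrapper` was obtained from `CUBLASLT_HANDLE`:
    ///
    /// ```ignore
    /// use candle_core::{DType, Device, Tensor};
    ///
    /// let device = Device::new_cuda(0)?;
    /// let a = Tensor::randn(0f32, 1f32, (1, 4, 8), &device)?.to_dtype(DType::BF16)?;  // (B, M, K)
    /// let b = Tensor::randn(0f32, 1f32, (1, 16, 8), &device)?.to_dtype(DType::BF16)?; // (B, N, K)
    ///
    /// let out = wrapper.batch_matmul(&a, &b, None, None, None, None, None)?;
    /// assert_eq!(out.dims(), &[1, 16, 4]); // (B, N, M)
    /// ```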
    #[allow(clippy::too_many_arguments)]
    pub fn batch_matmul(
        &self,
        a: &Tensor,
        b: &Tensor,
        out: Option<&Tensor>,
        alpha: Option<f32>,
        beta: Option<f32>,
        bias: Option<&Tensor>,
        act: Option<CandleActivation>,
    ) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            // Only Relu/Gelu can be fused into the cuBLASLt epilogue; Swiglu is
            // handled separately after the matmul.
            let inner_act = act.and_then(|a| match a {
                CandleActivation::Relu => Some(matmul::Activation::Relu),
                CandleActivation::Gelu => Some(matmul::Activation::Gelu),
                CandleActivation::Swiglu => None,
                _ => unreachable!("Unsupported activation in cublaslt matmul"),
            });
            let mut result = fused_batch_matmul(
                a,
                b,
                out,
                alpha,
                beta,
                bias,
                inner_act,
                self.cublaslt.clone(),
            )?;

            if Some(CandleActivation::Swiglu) == act {
                result = candle_nn::ops::swiglu(&result)?;
            }
            Ok(result)
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle_core::bail!("`cuda` feature is not enabled")
        }
    }
}