mistralrs_core/utils/normal.rs

#[allow(dead_code)]
use std::{fmt::Display, str::FromStr};

use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use serde::Deserialize;
use tracing::info;

#[derive(Clone, Copy, Default, Debug, Deserialize, PartialEq)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
/// DType for the model.
///
/// If the model is quantized, this is ignored so it is reasonable to use the [`Default`] impl.
///
/// Note: When using `Auto`, the fallback pattern is: BF16 -> F16 -> F32
pub enum ModelDType {
    #[default]
    #[serde(rename = "auto")]
    Auto,
    #[serde(rename = "bf16")]
    BF16,
    #[serde(rename = "f16")]
    F16,
    #[serde(rename = "f32")]
    F32,
}

impl Display for ModelDType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Auto => write!(f, "auto"),
            Self::BF16 => write!(f, "bf16"),
            Self::F16 => write!(f, "f16"),
            Self::F32 => write!(f, "f32"),
        }
    }
}

impl FromStr for ModelDType {
    type Err = String;
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "auto" => Ok(Self::Auto),
            "bf16" => Ok(Self::BF16),
            "f16" => Ok(Self::F16),
            "f32" => Ok(Self::F32),
            other => Err(format!("Model DType `{other}` is not supported.")),
        }
    }
}
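
// A minimal sketch of round-tripping `ModelDType` through `FromStr`/`Display`.
// The string forms mirror the `serde` renames above; the module and test names
// are illustrative only, not part of the original file.
#[cfg(test)]
mod model_dtype_parse_sketch {
    use super::*;

    #[test]
    fn parses_case_insensitively_and_displays() {
        // Parsing lowercases the input, so "BF16" and "bf16" both work.
        let dtype: ModelDType = "BF16".parse().expect("bf16 should parse");
        assert_eq!(dtype, ModelDType::BF16);
        assert_eq!(dtype.to_string(), "bf16");
        // Unknown dtypes are rejected with a descriptive error.
        assert!("i8".parse::<ModelDType>().is_err());
    }
}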

/// Type which can be converted to a DType
pub trait TryIntoDType {
    fn try_into_dtype(&self, devices: &[&Device]) -> Result<DType>;
}

impl TryIntoDType for DType {
    fn try_into_dtype(&self, _: &[&Device]) -> Result<DType> {
        info!("DType selected is {self:?}.");
        if !matches!(self, DType::BF16 | DType::F32 | DType::F64 | DType::F16) {
            anyhow::bail!("DType must be one of BF16, F16, F32, F64");
        }
        Ok(*self)
    }
}
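
// A small sketch of what the `DType` impl above does: a float dtype passes
// through unchanged regardless of the devices, while non-float dtypes are
// rejected. The module name is illustrative only.
#[cfg(test)]
mod concrete_dtype_sketch {
    use super::*;

    #[test]
    fn float_dtypes_pass_through() {
        assert_eq!(
            DType::F16.try_into_dtype(&[&Device::Cpu]).unwrap(),
            DType::F16
        );
        // U8 is not in the allowed set (BF16, F16, F32, F64).
        assert!(DType::U8.try_into_dtype(&[&Device::Cpu]).is_err());
    }
}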

#[cfg(feature = "cuda")]
fn get_dtypes() -> Vec<DType> {
    use std::process::Command;

    // >= is supported
    const MIN_BF16_CC: usize = 800;
    // >= is supported
    const MIN_F16_CC: usize = 530;

    let raw_out = Command::new("nvidia-smi")
        .arg("--query-gpu=compute_cap")
        .arg("--format=csv")
        .output()
        .expect("Failed to run `nvidia-smi` but CUDA is selected.")
        .stdout;
    let out = String::from_utf8(raw_out).expect("`nvidia-smi` did not return valid utf8");
    // This reduce-min will always return at least one value so unwrap is OK.
    let min_cc = out
        .split('\n')
        .skip(1)
        .filter(|cc| !cc.trim().is_empty())
        .map(|cc| cc.trim().parse::<f32>().unwrap())
        .reduce(|a, b| if a < b { a } else { b })
        .unwrap();
    info!("Detected minimum CUDA compute capability {min_cc}");
    // 7.5 -> 750
    #[allow(clippy::cast_possible_truncation)]
    let min_cc = (min_cc * 100.) as usize;

    let mut dtypes = Vec::new();
    if min_cc >= MIN_BF16_CC {
        dtypes.push(DType::BF16);
    } else {
        info!("Skipping BF16 because CC < 8.0");
    }
    if min_cc >= MIN_F16_CC {
        dtypes.push(DType::F16);
    } else {
        info!("Skipping F16 because CC < 5.3");
    }
    dtypes
}
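
// A hedged sketch of the compute-capability parsing above, run against a
// hypothetical `nvidia-smi --query-gpu=compute_cap --format=csv` output
// (header row plus one value per GPU). The minimum, 7.5, scales to 750,
// which is enough for F16 (>= 530) but not BF16 (>= 800). Purely
// illustrative; the sample data is made up.
#[cfg(test)]
mod compute_cap_parsing_sketch {
    #[test]
    fn picks_minimum_compute_cap() {
        let out = "compute_cap\n8.6\n7.5\n";
        let min_cc = out
            .split('\n')
            .skip(1)
            .filter(|cc| !cc.trim().is_empty())
            .map(|cc| cc.trim().parse::<f32>().unwrap())
            .reduce(|a, b| if a < b { a } else { b })
            .unwrap();
        assert_eq!((min_cc * 100.) as usize, 750);
    }
}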

fn get_dtypes_non_cuda() -> Vec<DType> {
    vec![DType::BF16, DType::F16]
}

#[cfg(not(feature = "cuda"))]
fn get_dtypes() -> Vec<DType> {
    get_dtypes_non_cuda()
}

fn determine_auto_dtype_all(devices: &[&Device]) -> candle_core::Result<DType> {
    // We can safely use bf16 for accelerate because we cast up to f32 in all matmuls anyway.
    #[cfg(feature = "accelerate")]
    return Ok(DType::BF16);
    #[cfg(not(feature = "accelerate"))]
    {
        let dev_dtypes = get_dtypes();
        for dtype in get_dtypes_non_cuda()
            .iter()
            .filter(|x| dev_dtypes.contains(x))
        {
            let mut results = Vec::new();
            for device in devices {
                // Try a matmul
                let x = Tensor::zeros((2, 2), *dtype, device)?;
                results.push(x.matmul(&x));
            }
            if results.iter().all(|x| x.is_ok()) {
                return Ok(*dtype);
            } else {
                for result in results {
                    match result {
                        Ok(_) => (),
                        Err(e) => match e {
                            // For CUDA
                            candle_core::Error::UnsupportedDTypeForOp(_, _) => continue,
                            // Accelerate backend doesn't support f16/bf16
                            // Metal backend doesn't support f16
                            candle_core::Error::Msg(_) => continue,
                            // This is when the metal backend doesn't support bf16
                            candle_core::Error::Metal(_) => continue,
                            // If running with RUST_BACKTRACE=1
                            candle_core::Error::WithBacktrace { .. } => continue,
                            other => return Err(other),
                        },
                    }
                }
            }
        }
        Ok(DType::F32)
    }
}

impl TryIntoDType for ModelDType {
    fn try_into_dtype(&self, devices: &[&Device]) -> Result<DType> {
        let dtype = match self {
            Self::Auto => Ok(determine_auto_dtype_all(devices).map_err(anyhow::Error::msg)?),
            Self::BF16 => Ok(DType::BF16),
            Self::F16 => Ok(DType::F16),
            Self::F32 => Ok(DType::F32),
        };
        info!("DType selected is {:?}.", dtype.as_ref().unwrap());
        dtype
    }
}
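
// A minimal usage sketch (assuming only a CPU device is available): explicit
// variants map directly to their candle `DType`, while `Auto` defers to the
// probing in `determine_auto_dtype_all`. The module name is illustrative only.
#[cfg(test)]
mod model_dtype_selection_sketch {
    use super::*;

    #[test]
    fn explicit_variants_map_directly() {
        let dev = Device::Cpu;
        assert_eq!(
            ModelDType::F32.try_into_dtype(&[&dev]).unwrap(),
            DType::F32
        );
        assert_eq!(
            ModelDType::F16.try_into_dtype(&[&dev]).unwrap(),
            DType::F16
        );
    }
}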