mistralrs_core/utils/normal.rs

#![allow(dead_code, unused)]

use std::{fmt::Display, str::FromStr};

use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use serde::Deserialize;
use tracing::info;

#[derive(Clone, Copy, Default, Debug, Deserialize, PartialEq)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
/// DType for the model.
///
/// If the model is quantized, this is ignored, so it is reasonable to use the [`Default`] impl.
///
/// Note: When using `Auto`, the fallback pattern is: BF16 -> F16 -> F32
pub enum ModelDType {
    #[default]
    #[serde(rename = "auto")]
    Auto,
    #[serde(rename = "bf16")]
    BF16,
    #[serde(rename = "f16")]
    F16,
    #[serde(rename = "f32")]
    F32,
}
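
// Illustrative sketch (not part of the original file): because of the `serde(rename)` attributes
// above, `ModelDType` deserializes from the same lowercase strings it displays and parses, e.g.
// when read from a JSON config. `serde_json` is assumed here purely for illustration and may not
// be a dependency of this crate:
//
//     let dtype: ModelDType = serde_json::from_str("\"bf16\"").unwrap();
//     assert_eq!(dtype, ModelDType::BF16);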

impl Display for ModelDType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Auto => write!(f, "auto"),
            Self::BF16 => write!(f, "bf16"),
            Self::F16 => write!(f, "f16"),
            Self::F32 => write!(f, "f32"),
        }
    }
}

impl FromStr for ModelDType {
    type Err = String;
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "auto" => Ok(Self::Auto),
            "bf16" => Ok(Self::BF16),
            "f16" => Ok(Self::F16),
            "f32" => Ok(Self::F32),
            other => Err(format!("Model DType `{other}` is not supported.")),
        }
    }
}

/// Type which can be converted to a DType
pub trait TryIntoDType {
    fn try_into_dtype(&self, devices: &[&Device]) -> Result<DType>;
}
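
// Usage sketch (illustrative only; the variable names below are assumptions, not code from this
// crate): loading code can accept either the user-facing `ModelDType` or an explicit `DType`,
// since both implement this trait:
//
//     let dtype = ModelDType::Auto.try_into_dtype(&[&device])?;
//     let weights = weights.to_dtype(dtype)?;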

impl TryIntoDType for DType {
    fn try_into_dtype(&self, _: &[&Device]) -> Result<DType> {
        info!("DType selected is {self:?}.");
        if !matches!(self, DType::BF16 | DType::F32 | DType::F64 | DType::F16) {
            anyhow::bail!("DType must be one of BF16, F16, F32, F64");
        }
        Ok(*self)
    }
}

#[cfg(feature = "cuda")]
fn get_dtypes() -> Vec<DType> {
    use std::process::Command;

    // Compute capability >= 8.0 (800) supports BF16.
    const MIN_BF16_CC: usize = 800;
    // Compute capability >= 5.3 (530) supports F16.
    const MIN_F16_CC: usize = 530;

    let raw_out = Command::new("nvidia-smi")
        .arg("--query-gpu=compute_cap")
        .arg("--format=csv")
        .output()
        .expect("Failed to run `nvidia-smi` but CUDA is selected.")
        .stdout;
    let out = String::from_utf8(raw_out).expect("`nvidia-smi` did not return valid utf8");
    // Skip the `compute_cap` CSV header and take the minimum compute capability across all
    // GPUs. This reduce-min will always return at least one value so the unwrap is OK.
    let min_cc = out
        .split('\n')
        .skip(1)
        .filter(|cc| !cc.trim().is_empty())
        .map(|cc| cc.trim().parse::<f32>().unwrap())
        .reduce(|a, b| if a < b { a } else { b })
        .unwrap();
    info!("Detected minimum CUDA compute capability {min_cc}");
    // 7.5 -> 750
    #[allow(clippy::cast_possible_truncation)]
    let min_cc = (min_cc * 100.) as usize;

    let mut dtypes = Vec::new();
    if min_cc >= MIN_BF16_CC {
        dtypes.push(DType::BF16);
    } else {
        info!("Skipping BF16 because CC < 8.0");
    }
    if min_cc >= MIN_F16_CC {
        dtypes.push(DType::F16);
    } else {
        info!("Skipping F16 because CC < 5.3");
    }
    dtypes
}

fn get_dtypes_non_cuda() -> Vec<DType> {
    vec![DType::BF16, DType::F16]
}

#[cfg(not(feature = "cuda"))]
fn get_dtypes() -> Vec<DType> {
    get_dtypes_non_cuda()
}

fn determine_auto_dtype_all(devices: &[&Device]) -> candle_core::Result<DType> {
    // We can safely use bf16 for accelerate because we cast up to f32 in all matmuls anyway.
    #[cfg(feature = "accelerate")]
    return Ok(DType::BF16);
    #[cfg(not(feature = "accelerate"))]
    {
        let dev_dtypes = get_dtypes();
        for dtype in get_dtypes_non_cuda()
            .iter()
            .filter(|x| dev_dtypes.contains(x))
        {
            let mut results = Vec::new();
            for device in devices {
                // Try a matmul
                let x = Tensor::zeros((2, 2), *dtype, device)?;
                results.push(x.matmul(&x));
            }
            if results.iter().all(|x| x.is_ok()) {
                return Ok(*dtype);
            } else {
                for result in results {
                    match result {
                        Ok(_) => (),
                        Err(e) => match e {
                            // For CUDA
                            candle_core::Error::UnsupportedDTypeForOp(_, _) => continue,
                            // Accelerate backend doesn't support f16/bf16
                            // Metal backend doesn't support f16
                            candle_core::Error::Msg(_) => continue,
                            // This is when the metal backend doesn't support bf16
                            candle_core::Error::Metal(_) => continue,
                            // If running with RUST_BACKTRACE=1
                            candle_core::Error::WithBacktrace { .. } => continue,
                            other => return Err(other),
                        },
                    }
                }
            }
        }
        Ok(DType::F32)
    }
}

impl TryIntoDType for ModelDType {
    fn try_into_dtype(&self, devices: &[&Device]) -> Result<DType> {
        let dtype = match self {
            Self::Auto => determine_auto_dtype_all(devices).map_err(anyhow::Error::msg)?,
            Self::BF16 => DType::BF16,
            Self::F16 => DType::F16,
            Self::F32 => DType::F32,
        };
        info!("DType selected is {dtype:?}.");
        Ok(dtype)
    }
}
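
// A minimal test sketch (an addition, not part of the original file) exercising the string
// round-trip and the `Auto` selection path on CPU. Which dtype `Auto` resolves to depends on
// the host, so only device-independent behavior is asserted.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_and_display_round_trip() {
        for s in ["auto", "bf16", "f16", "f32"] {
            let dtype: ModelDType = s.parse().expect("known dtype string should parse");
            assert_eq!(dtype.to_string(), s);
        }
        assert!("f64".parse::<ModelDType>().is_err());
    }

    #[test]
    fn auto_resolves_to_a_supported_dtype() {
        let dtype = ModelDType::Auto
            .try_into_dtype(&[&Device::Cpu])
            .expect("auto selection should not fail on CPU");
        assert!(matches!(dtype, DType::BF16 | DType::F16 | DType::F32));
    }
}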