// mistralrs_core/utils/normal.rs
#[allow(dead_code)]
2use std::{fmt::Display, str::FromStr};
3
4use anyhow::Result;
5use candle_core::{DType, Device, Tensor};
6use serde::Deserialize;
7use tracing::info;
8
/// User-selectable model data type, as written on the command line or in a
/// config file (serde names: `auto`, `bf16`, `f16`, `f32`).
#[derive(Clone, Copy, Default, Debug, Deserialize, PartialEq)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
pub enum ModelDType {
    /// Probe the devices and pick the best supported dtype (the default).
    #[default]
    #[serde(rename = "auto")]
    Auto,
    /// bfloat16.
    #[serde(rename = "bf16")]
    BF16,
    /// float16 (half precision).
    #[serde(rename = "f16")]
    F16,
    /// float32 (single precision).
    #[serde(rename = "f32")]
    F32,
}
27
28impl Display for ModelDType {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 match self {
31 Self::Auto => write!(f, "auto"),
32 Self::BF16 => write!(f, "bf16"),
33 Self::F16 => write!(f, "f16"),
34 Self::F32 => write!(f, "f32"),
35 }
36 }
37}
38
39impl FromStr for ModelDType {
40 type Err = String;
41 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
42 match s.to_lowercase().as_str() {
43 "auto" => Ok(Self::Auto),
44 "bf16" => Ok(Self::BF16),
45 "f16" => Ok(Self::F16),
46 "f32" => Ok(Self::F32),
47 other => Err(format!("Model DType `{other}` is not supported.")),
48 }
49 }
50}
51
/// Resolve `self` into a concrete `candle_core::DType` for the given target
/// devices (used by `ModelDType::Auto` to probe per-device support).
pub trait TryIntoDType {
    fn try_into_dtype(&self, devices: &[&Device]) -> Result<DType>;
}
56
57impl TryIntoDType for DType {
58 fn try_into_dtype(&self, _: &[&Device]) -> Result<DType> {
59 info!("DType selected is {self:?}.");
60 if !matches!(self, DType::BF16 | DType::F32 | DType::F64 | DType::F16) {
61 anyhow::bail!("DType must be one of BF16, F16, F32, F64");
62 }
63 Ok(*self)
64 }
65}
66
/// CUDA builds: query `nvidia-smi` for every GPU's compute capability and
/// return the half-precision dtypes supported by the *weakest* GPU, in
/// preference order (BF16 before F16).
///
/// # Panics
/// Panics if `nvidia-smi` cannot be run, returns non-UTF-8/non-numeric
/// output, or reports no GPUs — all of which are unrecoverable when the
/// `cuda` feature is selected.
#[cfg(feature = "cuda")]
fn get_dtypes() -> Vec<DType> {
    use std::process::Command;

    // Minimum compute capabilities, scaled by 100 (8.0 -> 800, 5.3 -> 530).
    const MIN_BF16_CC: usize = 800;
    const MIN_F16_CC: usize = 530;

    let raw_out = Command::new("nvidia-smi")
        .arg("--query-gpu=compute_cap")
        .arg("--format=csv")
        .output()
        .expect("Failed to run `nvidia-smi` but CUDA is selected.")
        .stdout;
    let out = String::from_utf8(raw_out).expect("`nvidia-smi` did not return valid utf8");
    // Skip the CSV header line (`compute_cap`); the remaining non-empty lines
    // are one compute capability per GPU. Take the minimum across GPUs.
    let min_cc = out
        .split('\n')
        .skip(1)
        .filter(|cc| !cc.trim().is_empty())
        .map(|cc| {
            cc.trim()
                .parse::<f32>()
                .expect("`nvidia-smi` returned a non-numeric compute capability")
        })
        .reduce(|a, b| if a < b { a } else { b })
        .expect("`nvidia-smi` reported no GPUs but CUDA is selected.");
    info!("Detected minimum CUDA compute capability {min_cc}");
    // Scale to an integer before comparing. Round first: a plain truncating
    // cast could drop a value like 5.3 (stored as 5.29999… * 100 = 529.99…)
    // below its own threshold.
    #[allow(clippy::cast_possible_truncation)]
    let min_cc = (min_cc * 100.).round() as usize;

    let mut dtypes = Vec::new();
    if min_cc >= MIN_BF16_CC {
        dtypes.push(DType::BF16);
    } else {
        info!("Skipping BF16 because CC < 8.0");
    }
    if min_cc >= MIN_F16_CC {
        dtypes.push(DType::F16);
    } else {
        info!("Skipping F16 because CC < 5.3");
    }
    dtypes
}
109
110fn get_dtypes_non_cuda() -> Vec<DType> {
111 vec![DType::BF16, DType::F16]
112}
113
/// Non-CUDA builds: no compute-capability query is possible, so both
/// half-precision dtypes are offered as candidates; runtime probing in
/// `determine_auto_dtype_all` filters out unsupported ones.
#[cfg(not(feature = "cuda"))]
fn get_dtypes() -> Vec<DType> {
    get_dtypes_non_cuda()
}
118
/// Picks the best dtype that *every* device in `devices` supports, by
/// probing BF16 then F16 with a tiny matmul; falls back to `F32`.
fn determine_auto_dtype_all(devices: &[&Device]) -> candle_core::Result<DType> {
    // With Accelerate enabled, BF16 is returned unconditionally — no probing.
    #[cfg(feature = "accelerate")]
    return Ok(DType::BF16);
    #[cfg(not(feature = "accelerate"))]
    {
        // Dtypes the hardware claims to support (CUDA: from compute
        // capability; otherwise both BF16 and F16).
        let dev_dtypes = get_dtypes();
        // Iterate candidates in preference order (BF16 before F16), keeping
        // only those the hardware claims to support.
        for dtype in get_dtypes_non_cuda()
            .iter()
            .filter(|x| dev_dtypes.contains(x))
        {
            // Probe: run a 2x2 zero matmul in this dtype on each device.
            let mut results = Vec::new();
            for device in devices {
                let x = Tensor::zeros((2, 2), *dtype, device)?;
                results.push(x.matmul(&x));
            }
            // Only accept a dtype that works on ALL devices.
            if results.iter().all(|x| x.is_ok()) {
                return Ok(*dtype);
            } else {
                for result in results {
                    match result {
                        Ok(_) => (),
                        Err(e) => match e {
                            // These error kinds are treated as "dtype not
                            // supported here" — fall through to the next
                            // candidate dtype.
                            candle_core::Error::UnsupportedDTypeForOp(_, _) => continue,
                            candle_core::Error::Msg(_) => continue,
                            candle_core::Error::Metal(_) => continue,
                            candle_core::Error::WithBacktrace { .. } => continue,
                            // Anything else is a genuine failure; surface it.
                            other => return Err(other),
                        },
                    }
                }
            }
        }
        // No half-precision dtype worked everywhere; F32 is the safe default.
        Ok(DType::F32)
    }
}
161
162impl TryIntoDType for ModelDType {
163 fn try_into_dtype(&self, devices: &[&Device]) -> Result<DType> {
164 let dtype = match self {
165 Self::Auto => Ok(determine_auto_dtype_all(devices).map_err(anyhow::Error::msg)?),
166 Self::BF16 => Ok(DType::BF16),
167 Self::F16 => Ok(DType::F16),
168 Self::F32 => Ok(DType::F32),
169 };
170 info!("DType selected is {:?}.", dtype.as_ref().unwrap());
171 dtype
172 }
173}