// mistralrs_core/utils/normal.rs
#![allow(dead_code, unused)]
2
3use std::{fmt::Display, str::FromStr};
4
5use anyhow::Result;
6use candle_core::{DType, Device, Tensor};
7use serde::Deserialize;
8use tracing::info;
9
/// User-selectable weight dtype for loading a model.
///
/// Deserializes from the lowercase names `"auto"`, `"bf16"`, `"f16"` and
/// `"f32"` (see the `serde` renames below); the same names are produced by the
/// `Display` impl and accepted by `FromStr`. Exposed to Python as an enum when
/// the `pyo3_macros` feature is enabled.
#[derive(Clone, Copy, Default, Debug, Deserialize, PartialEq)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
pub enum ModelDType {
    /// Pick a dtype automatically (resolved per-device in `try_into_dtype`).
    #[default]
    #[serde(rename = "auto")]
    Auto,
    /// bfloat16.
    #[serde(rename = "bf16")]
    BF16,
    /// float16 (half precision).
    #[serde(rename = "f16")]
    F16,
    /// float32 (single precision).
    #[serde(rename = "f32")]
    F32,
}
28
29impl Display for ModelDType {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 match self {
32 Self::Auto => write!(f, "auto"),
33 Self::BF16 => write!(f, "bf16"),
34 Self::F16 => write!(f, "f16"),
35 Self::F32 => write!(f, "f32"),
36 }
37 }
38}
39
40impl FromStr for ModelDType {
41 type Err = String;
42 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
43 match s.to_lowercase().as_str() {
44 "auto" => Ok(Self::Auto),
45 "bf16" => Ok(Self::BF16),
46 "f16" => Ok(Self::F16),
47 "f32" => Ok(Self::F32),
48 other => Err(format!("Model DType `{other}` is not supported.")),
49 }
50 }
51}
52
/// Conversion of a dtype selection into a concrete `candle_core::DType`,
/// given the set of devices the model will run on (used to probe support
/// when the selection is `ModelDType::Auto`).
pub trait TryIntoDType {
    /// Resolve `self` to a concrete `DType` usable on all of `devices`.
    fn try_into_dtype(&self, devices: &[&Device]) -> Result<DType>;
}
57
58impl TryIntoDType for DType {
59 fn try_into_dtype(&self, _: &[&Device]) -> Result<DType> {
60 info!("DType selected is {self:?}.");
61 if !matches!(self, DType::BF16 | DType::F32 | DType::F64 | DType::F16) {
62 anyhow::bail!("DType must be one of BF16, F16, F32, F64");
63 }
64 Ok(*self)
65 }
66}
67
/// Query `nvidia-smi` for the compute capability of every visible GPU and
/// return the reduced-precision dtypes supported by ALL of them, in
/// preference order (BF16 before F16).
///
/// # Panics
/// Panics if `nvidia-smi` cannot be run, produces non-UTF-8 output, reports a
/// non-numeric compute capability, or reports no GPUs at all — all of which
/// are unrecoverable misconfigurations when the `cuda` feature is selected.
#[cfg(feature = "cuda")]
fn get_dtypes() -> Vec<DType> {
    use std::process::Command;

    // Minimum compute capabilities (scaled by 100, e.g. 8.0 -> 800) required
    // for hardware BF16 and F16 support respectively.
    const MIN_BF16_CC: usize = 800;
    const MIN_F16_CC: usize = 530;

    let raw_out = Command::new("nvidia-smi")
        .arg("--query-gpu=compute_cap")
        .arg("--format=csv")
        .output()
        .expect("Failed to run `nvidia-smi` but CUDA is selected.")
        .stdout;
    let out = String::from_utf8(raw_out).expect("`nvidia-smi` did not return valid utf8");
    // First line is the CSV header ("compute_cap"); each following non-empty
    // line is one GPU's capability, e.g. "8.6". Take the minimum across GPUs
    // so the chosen dtype works everywhere.
    let min_cc = out
        .lines()
        .skip(1)
        .map(str::trim)
        .filter(|cc| !cc.is_empty())
        .map(|cc| {
            cc.parse::<f32>().unwrap_or_else(|_| {
                panic!("`nvidia-smi` reported a non-numeric compute capability: `{cc}`")
            })
        })
        .reduce(f32::min)
        .expect("`nvidia-smi` reported no GPUs");
    info!("Detected minimum CUDA compute capability {min_cc}");
    // Scale to an integer (e.g. 8.6 -> 860) for the threshold comparisons.
    #[allow(clippy::cast_possible_truncation)]
    let min_cc = (min_cc * 100.) as usize;

    let mut dtypes = Vec::new();
    if min_cc >= MIN_BF16_CC {
        dtypes.push(DType::BF16);
    } else {
        info!("Skipping BF16 because CC < 8.0");
    }
    if min_cc >= MIN_F16_CC {
        dtypes.push(DType::F16);
    } else {
        info!("Skipping F16 because CC < 5.3");
    }
    dtypes
}
110
111fn get_dtypes_non_cuda() -> Vec<DType> {
112 vec![DType::BF16, DType::F16]
113}
114
/// Without CUDA there is no per-GPU capability to probe, so assume the full
/// non-CUDA dtype set.
#[cfg(not(feature = "cuda"))]
fn get_dtypes() -> Vec<DType> {
    get_dtypes_non_cuda()
}
119
/// Pick the best dtype supported by ALL `devices`, trying BF16 first, then
/// F16, and falling back to F32 when neither works everywhere.
///
/// Support is tested empirically: a tiny 2x2 matmul is attempted on each
/// device and "unsupported"-shaped errors demote the dtype, while any other
/// error is propagated to the caller.
fn determine_auto_dtype_all(devices: &[&Device]) -> candle_core::Result<DType> {
    // With Accelerate, BF16 is assumed supported unconditionally; the probe
    // below is skipped entirely.
    #[cfg(feature = "accelerate")]
    return Ok(DType::BF16);
    #[cfg(not(feature = "accelerate"))]
    {
        let dev_dtypes = get_dtypes();
        // Iterate candidates in preference order (BF16, F16), restricted to
        // what the detected hardware supports at all.
        for dtype in get_dtypes_non_cuda()
            .iter()
            .filter(|x| dev_dtypes.contains(x))
        {
            let mut results = Vec::new();
            for device in devices {
                // Probe: a trivial matmul exercises the device's kernel
                // support for this dtype.
                let x = Tensor::zeros((2, 2), *dtype, device)?;
                results.push(x.matmul(&x));
            }
            if results.iter().all(|x| x.is_ok()) {
                return Ok(*dtype);
            } else {
                for result in results {
                    match result {
                        Ok(_) => (),
                        Err(e) => match e {
                            // These variants are treated as "dtype not
                            // supported here": move on to the next candidate.
                            candle_core::Error::UnsupportedDTypeForOp(_, _) => continue,
                            candle_core::Error::Msg(_) => continue,
                            candle_core::Error::Metal(_) => continue,
                            candle_core::Error::WithBacktrace { .. } => continue,
                            // Anything else is a real failure.
                            other => return Err(other),
                        },
                    }
                }
            }
        }
        // No reduced-precision dtype worked on every device.
        Ok(DType::F32)
    }
}
162
163impl TryIntoDType for ModelDType {
164 fn try_into_dtype(&self, devices: &[&Device]) -> Result<DType> {
165 let dtype = match self {
166 Self::Auto => Ok(determine_auto_dtype_all(devices).map_err(anyhow::Error::msg)?),
167 Self::BF16 => Ok(DType::BF16),
168 Self::F16 => Ok(DType::F16),
169 Self::F32 => Ok(DType::F32),
170 };
171 info!("DType selected is {:?}.", dtype.as_ref().unwrap());
172 dtype
173 }
174}