mistralrs_core/topology/
mod.rs1use std::{fs, io::Read, ops::Range, path::Path};
2
3use candle_core::Device;
4use indexmap::IndexMap;
5use itertools::Itertools;
6use mistralrs_quant::IsqType;
7use regex::Regex;
8use serde::Deserialize;
9
10use crate::parse_isq_value;
11
/// Accepted `device` specifiers: `cpu`, `cuda[ORD]` or `metal[ORD]`, where `ORD` is a
/// decimal ordinal. Capture group 2 holds the CUDA ordinal and group 3 the Metal ordinal.
const DEVICE_PATTERN: &str = r"^(cpu|cuda\[(\d+)\]|metal\[(\d+)\])$";
13
/// One raw entry of the topology YAML, before validation.
///
/// Both values are kept as strings here; [`Topology::from_str`] parses `isq` with
/// `parse_isq_value` and checks `device` against [`DEVICE_PATTERN`].
#[derive(Deserialize)]
pub struct DeserLayerTopology {
    /// ISQ type name, if the entry sets one.
    isq: Option<String>,
    /// Device specifier (`cpu`, `cuda[ORD]` or `metal[ORD]`), if the entry sets one.
    device: Option<String>,
}
19
/// The whole topology YAML: a mapping from selector string to its entry.
///
/// An `IndexMap` is used so the entries keep their order; [`Topology::from_str`] uses the
/// enumeration position as the declaration index for tie-breaking between layer ranges,
/// and keeps regex selectors in that same order.
#[derive(Deserialize)]
pub struct DeserTopology(IndexMap<String, DeserLayerTopology>);
22
/// Parsed settings for one layer, or for names matched by a regex selector.
#[derive(Clone, Debug)]
pub struct LayerTopology {
    /// ISQ type to apply; `None` when the topology entry did not set `isq`.
    pub isq: Option<IsqType>,
    /// Target device; `None` when the topology entry did not set `device`.
    pub device: Option<Device>,
}
28
/// A half-open layer range `start..end` parsed from a selector, tagged with the
/// selector's position in the YAML (`index`) so overlapping ranges can be applied in a
/// deterministic order.
#[derive(PartialEq, Eq, Debug)]
struct CustomRange {
    /// First layer covered (inclusive).
    start: usize,
    /// One past the last layer covered (exclusive).
    end: usize,
    /// Declaration order of the selector in the topology file.
    index: usize,
}
35
36impl From<CustomRange> for Range<usize> {
37 fn from(value: CustomRange) -> Self {
38 Self {
39 start: value.start,
40 end: value.end,
41 }
42 }
43}
44
45impl Ord for CustomRange {
46 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
47 self.end
49 .cmp(&other.end)
50 .then_with(|| self.index.cmp(&other.index))
51 }
52}
53
54impl PartialOrd for CustomRange {
55 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
56 Some(self.cmp(other))
57 }
58}
59
/// A resolved device/ISQ topology.
#[derive(Clone, Debug)]
pub struct Topology {
    /// Per-layer settings indexed by layer number; `None` for layers no range covers.
    pub layers: Vec<Option<LayerTopology>>,
    /// Regex selectors and their settings; `from_str` stores them in declaration order.
    pub patterns: Vec<(Regex, LayerTopology)>,
}
65
66impl Topology {
67 pub fn empty() -> Self {
69 Topology {
70 layers: Vec::new(),
71 patterns: Vec::new(),
72 }
73 }
74
75 pub fn with_capacity(cap: usize) -> Self {
76 Topology {
77 layers: vec![None; cap],
78 patterns: Vec::new(),
79 }
80 }
81
82 pub fn is_dummy_device_map(&self) -> bool {
83 self.layers
84 .iter()
85 .all(|l| l.is_none() || l.as_ref().is_some_and(|l| l.device.is_none()))
86 && self
87 .patterns
88 .iter()
89 .all(|(_, topo)| topo.device.as_ref().is_none())
90 }
91
92 pub fn with_range(mut self, range: Range<usize>, layer: LayerTopology) -> Self {
93 if self.layers.len() < range.end {
94 self.layers
95 .extend(vec![None; range.end - self.layers.len()]);
96 }
97 for i in range.start..range.end {
98 self.layers[i] = Some(layer.clone());
99 }
100 self
101 }
102
103 #[allow(clippy::should_implement_trait)]
104 pub fn from_str(topology: &str) -> anyhow::Result<Self> {
105 let deser: DeserTopology = serde_yaml::from_str(topology)?;
106 let device_regex = Regex::new(DEVICE_PATTERN)?;
107
108 let mut range_layers = Vec::new();
109 let mut pattern_layers = Vec::new();
110 for (index, (selector, DeserLayerTopology { isq, device })) in
111 deser.0.into_iter().enumerate()
112 {
113 let parsed_isq = if let Some(isq) = isq {
114 Some(parse_isq_value(&isq, None).map_err(anyhow::Error::msg)?)
115 } else {
116 None
117 };
118
119 let parsed_device = if let Some(device) = device {
120 let Some(captures) = device_regex.captures(&device) else {
121 anyhow::bail!(
122 "Device specifier must match regex {DEVICE_PATTERN}. Examples: `cpu`, `cuda[ORD]`, `metal[ORD]`"
123 );
124 };
125 let device = if let Some(val) = captures.get(2).or(captures.get(3)) {
126 let ord = val.as_str().parse::<usize>()?;
127 let device = device.split('[').collect::<Vec<_>>()[0];
128 match device {
129 "cuda" => Device::new_cuda(ord)?,
130 "metal" => Device::new_metal(ord)?,
131 _ => unreachable!(),
132 }
133 } else {
134 Device::Cpu
135 };
136
137 Some(device)
138 } else {
139 None
140 };
141
142 if selector.starts_with('/') && selector.ends_with('/') && selector.len() >= 2 {
143 let pattern = &selector[1..selector.len() - 1];
144 let regex = Regex::new(pattern)
145 .map_err(|err| anyhow::anyhow!("Invalid topology regex `{pattern}`: {err}"))?;
146 pattern_layers.push((
147 regex,
148 LayerTopology {
149 isq: parsed_isq,
150 device: parsed_device,
151 },
152 ));
153 continue;
154 }
155
156 let (start, end) = if selector.contains('-') {
157 let Some((start, end)) = selector.splitn(2, '-').collect_tuple() else {
159 anyhow::bail!("Topology range segment must follow the format START-END")
160 };
161 (start.parse::<usize>()?, end.parse::<usize>()?)
162 } else {
163 let layer = selector.parse::<usize>()?;
165 (layer, layer + 1)
166 };
167
168 if end <= start {
169 anyhow::bail!("Topology range end must be > start, got {end} <= {start}");
170 }
171 let range = CustomRange { start, end, index };
172
173 range_layers.push((
174 range,
175 LayerTopology {
176 isq: parsed_isq,
177 device: parsed_device,
178 },
179 ));
180 }
181 range_layers.sort_by(|(r1, _), (r2, _)| r1.cmp(r2));
183
184 let capacity = range_layers.iter().map(|(r, _)| r.end).max().unwrap_or(0);
185 let mut this = if capacity == 0 {
186 Self::empty()
187 } else {
188 Self::with_capacity(capacity)
189 };
190 for (range, layer) in range_layers {
191 for i in range.start..range.end {
192 this.layers[i] = Some(layer.clone());
193 }
194 }
195 this.patterns = pattern_layers;
196 Ok(this)
197 }
198
199 pub fn from_reader<R: Read>(mut reader: R) -> anyhow::Result<Self> {
200 let mut buf = String::new();
201 reader.read_to_string(&mut buf)?;
202 Self::from_str(&buf)
203 }
204
205 pub fn from_path<P: AsRef<Path>>(path: P) -> anyhow::Result<Self> {
206 let buf = fs::read_to_string(path)?;
207 Self::from_str(&buf)
208 }
209
210 pub fn from_option_path<P: AsRef<Path>>(path: Option<P>) -> anyhow::Result<Option<Self>> {
211 if let Some(path) = path {
212 let buf = fs::read_to_string(path)?;
213 Ok(Some(Self::from_str(&buf)?))
214 } else {
215 Ok(None)
216 }
217 }
218
219 pub fn layer_for(&self, layer: usize) -> Option<&LayerTopology> {
220 self.layers.get(layer).and_then(|lt| lt.as_ref())
221 }
222
223 pub fn match_for_name(&self, name: &str) -> Option<LayerTopology> {
224 for (regex, layer) in self.patterns.iter().rev() {
225 if regex.is_match(name) {
226 return Some(layer.clone());
227 }
228 }
229 None
230 }
231
232 pub fn pattern_overrides(&self) -> Vec<(Regex, LayerTopology)> {
233 self.patterns
234 .iter()
235 .rev()
236 .map(|(regex, topo)| (regex.clone(), topo.clone()))
237 .collect()
238 }
239
240 pub fn requires_post_quantization(&self) -> bool {
241 self.layers.iter().any(|layer| {
242 layer
243 .as_ref()
244 .is_some_and(|layer| layer.isq.is_some() || layer.device.is_some())
245 })
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `yaml`, panicking if the topology is rejected.
    fn parse(yaml: &str) -> Topology {
        Topology::from_str(yaml).expect("topology parses")
    }

    /// ISQ type assigned to `layer` by the range entries, if any.
    fn layer_isq(topology: &Topology, layer: usize) -> Option<IsqType> {
        let entry = topology.layer_for(layer)?;
        entry.isq
    }

    #[test]
    fn highest_end_range_overrides_lower_end() {
        let topology = parse("0-4:\n isq: Q4K\n2-6:\n isq: Q6K\n");

        let expected = [(0, IsqType::Q4K), (2, IsqType::Q6K), (5, IsqType::Q6K)];
        for (layer, isq) in expected {
            assert_eq!(layer_isq(&topology, layer), Some(isq));
        }
    }

    #[test]
    fn later_range_with_same_end_wins() {
        let topology = parse("0-4:\n isq: Q4K\n2-4:\n isq: Q3K\n");

        let expected = [(1, IsqType::Q4K), (2, IsqType::Q3K), (3, IsqType::Q3K)];
        for (layer, isq) in expected {
            assert_eq!(layer_isq(&topology, layer), Some(isq));
        }
    }

    #[test]
    fn regex_overrides_respect_declaration_order() {
        let yaml = r#"'/ffn\./':
  isq: Q4K
'/ffn\.weight$/':
  isq: Q6K
"#;
        let topology = parse(yaml);

        let exact = topology
            .match_for_name("model.layers.2.ffn.weight")
            .expect("regex match");
        assert_eq!(exact.isq, Some(IsqType::Q6K));

        // Highest precedence (last declared) comes first.
        let overrides = topology.pattern_overrides();
        let sources: Vec<&str> = overrides.iter().map(|(re, _)| re.as_str()).collect();
        assert_eq!(sources, ["ffn\\.weight$", "ffn\\."]);
    }

    #[test]
    fn match_for_name_returns_none_when_unmatched() {
        let topology = parse("0-2:\n isq: Q4K\n");
        assert!(topology.match_for_name("transformer.wte.weight").is_none());
    }
}