// mistralrs_core/topology/mod.rs

use std::{fs, io::Read, ops::Range, path::Path};

use anyhow::Context;
use candle_core::Device;
use indexmap::IndexMap;
use itertools::Itertools;
use mistralrs_quant::IsqType;
use regex::Regex;
use serde::Deserialize;

use crate::parse_isq_value;

/// Accepted device specifiers: `cpu`, `cuda[ORD]` or `metal[ORD]`.
///
/// Capture group 2 holds the CUDA ordinal and capture group 3 the Metal ordinal;
/// neither participates when the specifier is `cpu`.
const DEVICE_PATTERN: &str = r"^(cpu|cuda\[(\d+)\]|metal\[(\d+)\])$";

14#[derive(Deserialize)]
15pub struct DeserLayerTopology {
16    isq: Option<String>,
17    device: Option<String>,
18}
19
20#[derive(Deserialize)]
21pub struct DeserTopology(IndexMap<String, DeserLayerTopology>);
22
23#[derive(Clone, Debug)]
24pub struct LayerTopology {
25    pub isq: Option<IsqType>,
26    pub device: Option<Device>,
27}
28
/// A layer range parsed from a `START-END` (or single `N`) topology selector.
///
/// `start..end` is half-open. `index` is the selector's position among all
/// entries of the topology document and is used to order ranges that share
/// the same end point.
#[derive(PartialEq, Eq, Debug)]
struct CustomRange {
    start: usize,
    end: usize,
    index: usize,
}

impl From<CustomRange> for Range<usize> {
    fn from(value: CustomRange) -> Self {
        value.start..value.end
    }
}

impl Ord for CustomRange {
    /// Orders by `end`, then by declaration `index`, so that applying ranges in
    /// ascending order lets later declarations override earlier ones.
    ///
    /// `start` is compared last purely to keep `Ord` consistent with the derived
    /// `Eq` (which compares every field); with unique indices it never changes
    /// the resulting order.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.end
            .cmp(&other.end)
            .then_with(|| self.index.cmp(&other.index))
            .then_with(|| self.start.cmp(&other.start))
    }
}

impl PartialOrd for CustomRange {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

60#[derive(Clone, Debug)]
61pub struct Topology {
62    pub layers: Vec<Option<LayerTopology>>,
63    pub patterns: Vec<(Regex, LayerTopology)>,
64}
65
66impl Topology {
67    /// Create an empty topology.
68    pub fn empty() -> Self {
69        Topology {
70            layers: Vec::new(),
71            patterns: Vec::new(),
72        }
73    }
74
75    pub fn with_capacity(cap: usize) -> Self {
76        Topology {
77            layers: vec![None; cap],
78            patterns: Vec::new(),
79        }
80    }
81
82    pub fn is_dummy_device_map(&self) -> bool {
83        self.layers
84            .iter()
85            .all(|l| l.is_none() || l.as_ref().is_some_and(|l| l.device.is_none()))
86            && self
87                .patterns
88                .iter()
89                .all(|(_, topo)| topo.device.as_ref().is_none())
90    }
91
92    pub fn with_range(mut self, range: Range<usize>, layer: LayerTopology) -> Self {
93        if self.layers.len() < range.end {
94            self.layers
95                .extend(vec![None; range.end - self.layers.len()]);
96        }
97        for i in range.start..range.end {
98            self.layers[i] = Some(layer.clone());
99        }
100        self
101    }
102
103    #[allow(clippy::should_implement_trait)]
104    pub fn from_str(topology: &str) -> anyhow::Result<Self> {
105        let deser: DeserTopology = serde_yaml::from_str(topology)?;
106        let device_regex = Regex::new(DEVICE_PATTERN)?;
107
108        let mut range_layers = Vec::new();
109        let mut pattern_layers = Vec::new();
110        for (index, (selector, DeserLayerTopology { isq, device })) in
111            deser.0.into_iter().enumerate()
112        {
113            let parsed_isq = if let Some(isq) = isq {
114                Some(parse_isq_value(&isq, None).map_err(anyhow::Error::msg)?)
115            } else {
116                None
117            };
118
119            let parsed_device = if let Some(device) = device {
120                let Some(captures) = device_regex.captures(&device) else {
121                    anyhow::bail!(
122                        "Device specifier must match regex {DEVICE_PATTERN}. Examples: `cpu`, `cuda[ORD]`, `metal[ORD]`"
123                    );
124                };
125                let device = if let Some(val) = captures.get(2).or(captures.get(3)) {
126                    let ord = val.as_str().parse::<usize>()?;
127                    let device = device.split('[').collect::<Vec<_>>()[0];
128                    match device {
129                        "cuda" => Device::new_cuda(ord)?,
130                        "metal" => Device::new_metal(ord)?,
131                        _ => unreachable!(),
132                    }
133                } else {
134                    Device::Cpu
135                };
136
137                Some(device)
138            } else {
139                None
140            };
141
142            if selector.starts_with('/') && selector.ends_with('/') && selector.len() >= 2 {
143                let pattern = &selector[1..selector.len() - 1];
144                let regex = Regex::new(pattern)
145                    .map_err(|err| anyhow::anyhow!("Invalid topology regex `{pattern}`: {err}"))?;
146                pattern_layers.push((
147                    regex,
148                    LayerTopology {
149                        isq: parsed_isq,
150                        device: parsed_device,
151                    },
152                ));
153                continue;
154            }
155
156            let (start, end) = if selector.contains('-') {
157                // Range (inclusive, exclusive)
158                let Some((start, end)) = selector.splitn(2, '-').collect_tuple() else {
159                    anyhow::bail!("Topology range segment must follow the format START-END")
160                };
161                (start.parse::<usize>()?, end.parse::<usize>()?)
162            } else {
163                // Single layer here
164                let layer = selector.parse::<usize>()?;
165                (layer, layer + 1)
166            };
167
168            if end <= start {
169                anyhow::bail!("Topology range end must be > start, got {end} <= {start}");
170            }
171            let range = CustomRange { start, end, index };
172
173            range_layers.push((
174                range,
175                LayerTopology {
176                    isq: parsed_isq,
177                    device: parsed_device,
178                },
179            ));
180        }
181        // Sort so that we increase in end points
182        range_layers.sort_by(|(r1, _), (r2, _)| r1.cmp(r2));
183
184        let capacity = range_layers.iter().map(|(r, _)| r.end).max().unwrap_or(0);
185        let mut this = if capacity == 0 {
186            Self::empty()
187        } else {
188            Self::with_capacity(capacity)
189        };
190        for (range, layer) in range_layers {
191            for i in range.start..range.end {
192                this.layers[i] = Some(layer.clone());
193            }
194        }
195        this.patterns = pattern_layers;
196        Ok(this)
197    }
198
199    pub fn from_reader<R: Read>(mut reader: R) -> anyhow::Result<Self> {
200        let mut buf = String::new();
201        reader.read_to_string(&mut buf)?;
202        Self::from_str(&buf)
203    }
204
205    pub fn from_path<P: AsRef<Path>>(path: P) -> anyhow::Result<Self> {
206        let buf = fs::read_to_string(path)?;
207        Self::from_str(&buf)
208    }
209
210    pub fn from_option_path<P: AsRef<Path>>(path: Option<P>) -> anyhow::Result<Option<Self>> {
211        if let Some(path) = path {
212            let buf = fs::read_to_string(path)?;
213            Ok(Some(Self::from_str(&buf)?))
214        } else {
215            Ok(None)
216        }
217    }
218
219    pub fn layer_for(&self, layer: usize) -> Option<&LayerTopology> {
220        self.layers.get(layer).and_then(|lt| lt.as_ref())
221    }
222
223    pub fn match_for_name(&self, name: &str) -> Option<LayerTopology> {
224        for (regex, layer) in self.patterns.iter().rev() {
225            if regex.is_match(name) {
226                return Some(layer.clone());
227            }
228        }
229        None
230    }
231
232    pub fn pattern_overrides(&self) -> Vec<(Regex, LayerTopology)> {
233        self.patterns
234            .iter()
235            .rev()
236            .map(|(regex, topo)| (regex.clone(), topo.clone()))
237            .collect()
238    }
239
240    pub fn requires_post_quantization(&self) -> bool {
241        self.layers.iter().any(|layer| {
242            layer
243                .as_ref()
244                .is_some_and(|layer| layer.isq.is_some() || layer.device.is_some())
245        })
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// ISQ type assigned to `layer`, or `None` if the layer has no entry or no ISQ.
    fn layer_isq(topology: &Topology, layer: usize) -> Option<IsqType> {
        topology
            .layer_for(layer)
            .and_then(|lt| lt.isq.as_ref().copied())
    }

    #[test]
    fn highest_end_range_overrides_lower_end() {
        let yaml = "0-4:\n  isq: Q4K\n2-6:\n  isq: Q6K\n";
        let topology = Topology::from_str(yaml).expect("topology parses");

        assert_eq!(layer_isq(&topology, 0), Some(IsqType::Q4K));
        assert_eq!(layer_isq(&topology, 2), Some(IsqType::Q6K));
        assert_eq!(layer_isq(&topology, 5), Some(IsqType::Q6K));
    }

    #[test]
    fn later_range_with_same_end_wins() {
        let yaml = "0-4:\n  isq: Q4K\n2-4:\n  isq: Q3K\n";
        let topology = Topology::from_str(yaml).expect("topology parses");

        assert_eq!(layer_isq(&topology, 1), Some(IsqType::Q4K));
        assert_eq!(layer_isq(&topology, 2), Some(IsqType::Q3K));
        assert_eq!(layer_isq(&topology, 3), Some(IsqType::Q3K));
    }

    #[test]
    fn regex_overrides_respect_declaration_order() {
        let yaml = r#"'/ffn\./':
  isq: Q4K
'/ffn\.weight$/':
  isq: Q6K
"#;
        let topology = Topology::from_str(yaml).expect("topology parses");

        let match_exact = topology
            .match_for_name("model.layers.2.ffn.weight")
            .expect("regex match");
        assert_eq!(match_exact.isq, Some(IsqType::Q6K));

        let overrides = topology.pattern_overrides();
        assert_eq!(overrides.len(), 2);
        assert_eq!(overrides[0].0.as_str(), "ffn\\.weight$");
        assert_eq!(overrides[1].0.as_str(), "ffn\\.");
    }

    #[test]
    fn match_for_name_returns_none_when_unmatched() {
        let yaml = "0-2:\n  isq: Q4K\n";
        let topology = Topology::from_str(yaml).expect("topology parses");
        assert!(topology.match_for_name("transformer.wte.weight").is_none());
    }

    #[test]
    fn single_layer_selector_sets_only_that_layer() {
        let topology = Topology::from_str("'3':\n  isq: Q4K\n").expect("topology parses");
        assert_eq!(topology.layers.len(), 4);
        assert_eq!(layer_isq(&topology, 2), None);
        assert_eq!(layer_isq(&topology, 3), Some(IsqType::Q4K));
        assert!(topology.layer_for(4).is_none());
    }

    #[test]
    fn empty_or_reversed_range_is_rejected() {
        assert!(Topology::from_str("4-2:\n  isq: Q4K\n").is_err());
        assert!(Topology::from_str("3-3:\n  isq: Q4K\n").is_err());
    }

    #[test]
    fn malformed_selectors_are_rejected() {
        assert!(Topology::from_str("abc:\n  isq: Q4K\n").is_err());
        assert!(Topology::from_str("0-x:\n  isq: Q4K\n").is_err());
        assert!(Topology::from_str("1-2-3:\n  isq: Q4K\n").is_err());
    }

    #[test]
    fn invalid_regex_selector_is_rejected() {
        assert!(Topology::from_str("'/(/':\n  isq: Q4K\n").is_err());
    }

    #[test]
    fn cpu_device_is_parsed_and_invalid_devices_rejected() {
        let topology = Topology::from_str("0-2:\n  device: cpu\n").expect("topology parses");
        assert!(matches!(
            topology.layer_for(1).and_then(|lt| lt.device.as_ref()),
            Some(Device::Cpu)
        ));
        assert!(!topology.is_dummy_device_map());

        assert!(Topology::from_str("0-2:\n  device: tpu\n").is_err());
        assert!(Topology::from_str("0-2:\n  device: cuda[x]\n").is_err());
    }

    #[test]
    fn isq_only_topology_is_dummy_device_map() {
        let topology = Topology::from_str("0-2:\n  isq: Q4K\n").expect("topology parses");
        assert!(topology.is_dummy_device_map());
        assert!(topology.requires_post_quantization());
    }

    #[test]
    fn empty_topology_has_no_entries() {
        let topology = Topology::from_str("{}").expect("topology parses");
        assert!(topology.layers.is_empty());
        assert!(topology.patterns.is_empty());
        assert!(topology.is_dummy_device_map());
        assert!(!topology.requires_post_quantization());
    }

    #[test]
    fn with_range_grows_layers() {
        let layer = LayerTopology {
            isq: Some(IsqType::Q4K),
            device: None,
        };
        let topology = Topology::empty().with_range(2..4, layer);
        assert_eq!(topology.layers.len(), 4);
        assert_eq!(layer_isq(&topology, 1), None);
        assert_eq!(layer_isq(&topology, 3), Some(IsqType::Q4K));
    }
}