// mistralrs_core/topology/mod.rs

1use std::{collections::HashMap, fs, io::Read, ops::Range, path::Path};
2
3use candle_core::Device;
4use itertools::Itertools;
5use mistralrs_quant::IsqType;
6use regex::Regex;
7use serde::Deserialize;
8
9use crate::parse_isq_value;
10
/// Regular expression that a device specifier in a topology entry must match.
///
/// Accepted forms are `cpu`, `cuda[ORD]` and `metal[ORD]`, where `ORD` is a
/// run of decimal digits. Capture group 2 holds the CUDA ordinal and capture
/// group 3 holds the Metal ordinal; neither is set for `cpu`.
const DEVICE_PATTERN: &str = r"^(cpu|cuda\[(\d+)\]|metal\[(\d+)\])$";
12
/// Raw, still-unparsed settings of a single topology entry as they appear in
/// the serialized document.
#[derive(Deserialize)]
pub struct DeserLayerTopology {
    /// ISQ type as text; `Topology::from_str` converts it with `parse_isq_value`.
    isq: Option<String>,
    /// Device specifier as text; must match `DEVICE_PATTERN`
    /// (`cpu`, `cuda[ORD]` or `metal[ORD]`).
    device: Option<String>,
}
18
/// Raw topology document: a map from a layer key to that entry's settings.
///
/// `Topology::from_str` interprets each key as either a single layer index
/// (`N`) or a range written as `START-END`.
#[derive(Deserialize)]
pub struct DeserTopology(HashMap<String, DeserLayerTopology>);
21
/// Parsed settings that apply to one model layer.
#[derive(Clone, Debug)]
pub struct LayerTopology {
    /// ISQ type requested for the layer, if one was given.
    pub isq: Option<IsqType>,
    /// Device requested for the layer, if one was given.
    pub device: Option<Device>,
}
27
/// Half-open layer range (`start..end`) parsed from a topology key.
///
/// Ranges are ordered primarily by their end position, so sorting a list of
/// ranges leaves the one reaching the furthest layer last. The start position
/// breaks ties, which keeps the ordering consistent with the derived equality
/// (two ranges compare `Equal` exactly when they are `==`) and makes the sort
/// order of ranges that share an end deterministic.
#[derive(PartialEq, Eq, Debug)]
struct CustomRange {
    /// First layer index covered by the range (inclusive).
    start: usize,
    /// One past the last layer index covered by the range (exclusive).
    end: usize,
}

impl From<CustomRange> for Range<usize> {
    /// Converts into the equivalent standard half-open range.
    fn from(value: CustomRange) -> Self {
        value.start..value.end
    }
}

impl Ord for CustomRange {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // End position first; start only decides between ranges with equal ends.
        self.end
            .cmp(&other.end)
            .then_with(|| self.start.cmp(&other.start))
    }
}

impl PartialOrd for CustomRange {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
55
56#[derive(Clone, Debug)]
57pub struct Topology(pub Vec<Option<LayerTopology>>);
58
59impl Topology {
60    /// Create an empty topology.
61    pub fn empty() -> Self {
62        Topology(Vec::new())
63    }
64
65    pub fn with_capacity(cap: usize) -> Self {
66        Topology(vec![None; cap])
67    }
68
69    pub fn is_dummy_device_map(&self) -> bool {
70        self.0
71            .iter()
72            .all(|l| l.is_none() || l.as_ref().is_some_and(|l| l.device.is_none()))
73    }
74
75    pub fn with_range(mut self, range: Range<usize>, layer: LayerTopology) -> Self {
76        if self.0.len() < range.end {
77            self.0.extend(vec![None; range.end - self.0.len()]);
78        }
79        for i in range.start..range.end {
80            self.0[i] = Some(layer.clone());
81        }
82        self
83    }
84
85    #[allow(clippy::should_implement_trait)]
86    pub fn from_str(topology: &str) -> anyhow::Result<Self> {
87        let deser: DeserTopology = serde_yaml::from_str(topology)?;
88        let device_regex = Regex::new(DEVICE_PATTERN)?;
89
90        let mut layers = Vec::new();
91        for (range, DeserLayerTopology { isq, device }) in deser.0 {
92            // Parse isq
93            let (start, end) = if range.contains('-') {
94                // Range (inclusive, exclusive)
95                let Some((start, end)) = range.splitn(2, '-').collect_tuple() else {
96                    anyhow::bail!("Topology range segment must follow the format START-END")
97                };
98                (start.parse::<usize>()?, end.parse::<usize>()?)
99            } else {
100                // Single layer here
101                let layer = range.parse::<usize>()?;
102                (layer, layer + 1)
103            };
104
105            if end <= start {
106                anyhow::bail!("Topology range end must be > start, got {end} <= {start}");
107            }
108            let range = CustomRange { start, end };
109            let isq = if let Some(isq) = isq {
110                Some(parse_isq_value(&isq).map_err(anyhow::Error::msg)?)
111            } else {
112                None
113            };
114
115            // Parse device
116            let device = if let Some(device) = device {
117                let Some(captures) = device_regex.captures(&device) else {
118                    anyhow::bail!("Device specifier must match regex {DEVICE_PATTERN}. Examples: `cpu`, `cuda[ORD]`, `metal[ORD]`");
119                };
120                let device = if let Some(val) = captures.get(2).or(captures.get(3)) {
121                    let ord = val.as_str().parse::<usize>()?;
122                    let device = device.split('[').collect::<Vec<_>>()[0];
123                    match device {
124                        "cuda" => Device::new_cuda(ord)?,
125                        "metal" => Device::new_metal(ord)?,
126                        _ => unreachable!(),
127                    }
128                } else {
129                    Device::Cpu
130                };
131
132                Some(device)
133            } else {
134                None
135            };
136
137            let layer_topo = LayerTopology { isq, device };
138            layers.push((range, layer_topo));
139        }
140        // Sort so that we increase in end points
141        layers.sort_by(|(r1, _), (r2, _)| r1.cmp(r2));
142
143        let mut this = Self::with_capacity(layers.last().unwrap().0.end);
144        for (range, layer) in layers {
145            for i in range.start..range.end {
146                this.0[i] = Some(layer.clone());
147            }
148        }
149        Ok(this)
150    }
151
152    pub fn from_reader<R: Read>(mut reader: R) -> anyhow::Result<Self> {
153        let mut buf = String::new();
154        reader.read_to_string(&mut buf)?;
155        Self::from_str(&buf)
156    }
157
158    pub fn from_path<P: AsRef<Path>>(path: P) -> anyhow::Result<Self> {
159        let buf = fs::read_to_string(path)?;
160        Self::from_str(&buf)
161    }
162
163    pub fn from_option_path<P: AsRef<Path>>(path: Option<P>) -> anyhow::Result<Option<Self>> {
164        if let Some(path) = path {
165            let buf = fs::read_to_string(path)?;
166            Ok(Some(Self::from_str(&buf)?))
167        } else {
168            Ok(None)
169        }
170    }
171}