mistralrs_core/topology/
mod.rs1use std::{collections::HashMap, fs, io::Read, ops::Range, path::Path};
2
3use candle_core::Device;
4use itertools::Itertools;
5use mistralrs_quant::IsqType;
6use regex::Regex;
7use serde::Deserialize;
8
9use crate::parse_isq_value;
10
/// Accepted device specifiers: `cpu`, `cuda[ORD]`, or `metal[ORD]`.
///
/// Capture group 2 holds the CUDA ordinal and capture group 3 the Metal
/// ordinal; neither group participates in a `cpu` match.
const DEVICE_PATTERN: &str = r"^(cpu|cuda\[(\d+)\]|metal\[(\d+)\])$";
12
13#[derive(Deserialize)]
14pub struct DeserLayerTopology {
15 isq: Option<String>,
16 device: Option<String>,
17}
18
19#[derive(Deserialize)]
20pub struct DeserTopology(HashMap<String, DeserLayerTopology>);
21
/// Validated settings for one layer of a [`Topology`].
#[derive(Clone, Debug)]
pub struct LayerTopology {
    /// ISQ type parsed from the entry's `isq` key, or `None` if the key is absent.
    pub isq: Option<IsqType>,
    /// Device parsed from the entry's `device` key (`cpu`, `cuda[ORD]`,
    /// `metal[ORD]`), or `None` if the key is absent.
    pub device: Option<Device>,
}
27
/// Half-open layer range `start..end` parsed from a topology key.
///
/// Ordered primarily by `end`, so the last element of a sorted list has the
/// largest `end`, and secondarily by `start`, so the ordering agrees with the
/// derived `Eq`.
#[derive(PartialEq, Eq, Debug)]
struct CustomRange {
    start: usize,
    end: usize,
}

impl From<CustomRange> for Range<usize> {
    fn from(value: CustomRange) -> Self {
        value.start..value.end
    }
}

impl Ord for CustomRange {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Comparing `end` alone reported `Equal` for distinct ranges such as
        // 0..5 and 3..5. That contradicts the derived `Eq` and leaves any sort
        // over such ranges dependent on input order. Tie-breaking on `start`
        // makes the order total and consistent.
        self.end
            .cmp(&other.end)
            .then(self.start.cmp(&other.start))
    }
}

impl PartialOrd for CustomRange {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
55
/// Per-layer settings indexed by layer number.
///
/// Entry `i` holds the settings for layer `i`, or `None` if no topology entry
/// covers that layer.
#[derive(Clone, Debug)]
pub struct Topology(pub Vec<Option<LayerTopology>>);
58
59impl Topology {
60 pub fn empty() -> Self {
62 Topology(Vec::new())
63 }
64
65 pub fn with_capacity(cap: usize) -> Self {
66 Topology(vec![None; cap])
67 }
68
69 pub fn is_dummy_device_map(&self) -> bool {
70 self.0
71 .iter()
72 .all(|l| l.is_none() || l.as_ref().is_some_and(|l| l.device.is_none()))
73 }
74
75 pub fn with_range(mut self, range: Range<usize>, layer: LayerTopology) -> Self {
76 if self.0.len() < range.end {
77 self.0.extend(vec![None; range.end - self.0.len()]);
78 }
79 for i in range.start..range.end {
80 self.0[i] = Some(layer.clone());
81 }
82 self
83 }
84
85 #[allow(clippy::should_implement_trait)]
86 pub fn from_str(topology: &str) -> anyhow::Result<Self> {
87 let deser: DeserTopology = serde_yaml::from_str(topology)?;
88 let device_regex = Regex::new(DEVICE_PATTERN)?;
89
90 let mut layers = Vec::new();
91 for (range, DeserLayerTopology { isq, device }) in deser.0 {
92 let (start, end) = if range.contains('-') {
94 let Some((start, end)) = range.splitn(2, '-').collect_tuple() else {
96 anyhow::bail!("Topology range segment must follow the format START-END")
97 };
98 (start.parse::<usize>()?, end.parse::<usize>()?)
99 } else {
100 let layer = range.parse::<usize>()?;
102 (layer, layer + 1)
103 };
104
105 if end <= start {
106 anyhow::bail!("Topology range end must be > start, got {end} <= {start}");
107 }
108 let range = CustomRange { start, end };
109 let isq = if let Some(isq) = isq {
110 Some(parse_isq_value(&isq).map_err(anyhow::Error::msg)?)
111 } else {
112 None
113 };
114
115 let device = if let Some(device) = device {
117 let Some(captures) = device_regex.captures(&device) else {
118 anyhow::bail!("Device specifier must match regex {DEVICE_PATTERN}. Examples: `cpu`, `cuda[ORD]`, `metal[ORD]`");
119 };
120 let device = if let Some(val) = captures.get(2).or(captures.get(3)) {
121 let ord = val.as_str().parse::<usize>()?;
122 let device = device.split('[').collect::<Vec<_>>()[0];
123 match device {
124 "cuda" => Device::new_cuda(ord)?,
125 "metal" => Device::new_metal(ord)?,
126 _ => unreachable!(),
127 }
128 } else {
129 Device::Cpu
130 };
131
132 Some(device)
133 } else {
134 None
135 };
136
137 let layer_topo = LayerTopology { isq, device };
138 layers.push((range, layer_topo));
139 }
140 layers.sort_by(|(r1, _), (r2, _)| r1.cmp(r2));
142
143 let mut this = Self::with_capacity(layers.last().unwrap().0.end);
144 for (range, layer) in layers {
145 for i in range.start..range.end {
146 this.0[i] = Some(layer.clone());
147 }
148 }
149 Ok(this)
150 }
151
152 pub fn from_reader<R: Read>(mut reader: R) -> anyhow::Result<Self> {
153 let mut buf = String::new();
154 reader.read_to_string(&mut buf)?;
155 Self::from_str(&buf)
156 }
157
158 pub fn from_path<P: AsRef<Path>>(path: P) -> anyhow::Result<Self> {
159 let buf = fs::read_to_string(path)?;
160 Self::from_str(&buf)
161 }
162
163 pub fn from_option_path<P: AsRef<Path>>(path: Option<P>) -> anyhow::Result<Option<Self>> {
164 if let Some(path) = path {
165 let buf = fs::read_to_string(path)?;
166 Ok(Some(Self::from_str(&buf)?))
167 } else {
168 Ok(None)
169 }
170 }
171}