mistralrs_core/
matformer.rs1use anyhow::{Context, Result};
2use serde::Deserialize;
3use std::collections::HashMap;
4use std::fs::File;
5use std::io::BufReader;
6use std::path::Path;
7use std::sync::Arc;
8
/// Parameters of one MatFormer slice, as loaded from a row of the config CSV.
#[derive(Clone, Debug)]
pub struct Slice {
    /// Value of the `# Effective Params (B)` column for this slice.
    pub effective_params: f64,
    /// Hidden sizes from the `FFN Hidden Dims` column, in listed order
    /// (presumably one entry per layer — confirm against the model loader).
    pub ffn_hidden_dimensions: Vec<usize>,
    /// Layer indices from the `Layers Skipped` column; `None` when that cell is blank.
    pub layers_skipped: Option<Vec<usize>>,
}
15
16#[derive(Debug)]
17pub struct MatformerConfig {
18 pub slices: HashMap<String, Slice>,
19}
20
21#[derive(Debug, Clone)]
22pub struct MatformerSliceConfig {
23 pub slice_name: String,
24 pub config: Arc<MatformerConfig>,
25}
26
27impl MatformerSliceConfig {
28 pub fn new(slice_name: String, config: Arc<MatformerConfig>) -> Self {
29 Self { slice_name, config }
30 }
31
32 pub fn get_slicing(&self) -> Option<&Slice> {
33 self.config.get_slicing(&self.slice_name)
34 }
35}
36
37#[derive(Debug, Deserialize)]
38struct CsvRecord {
39 name: String,
40 #[serde(rename = "# Layers")]
41 #[allow(dead_code)]
42 num_layers: u32,
43 #[serde(rename = "# Effective Params (B)")]
44 effective_params: f64,
45 #[serde(rename = "MMLU PT accuracy")]
46 #[allow(dead_code)]
47 mmlu_accuracy: String,
48 #[serde(rename = "FFN Hidden Dims")]
49 ffn_hidden_dims: String,
50 #[serde(rename = "Layers Skipped")]
51 layers_skipped: Option<String>,
52}
53
54impl MatformerConfig {
55 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
56 let file = File::open(&path).with_context(|| {
57 format!("Failed to open matformer config file: {:?}", path.as_ref())
58 })?;
59 let reader = BufReader::new(file);
60
61 let mut rdr = csv::Reader::from_reader(reader);
62 let mut slices = HashMap::new();
63
64 for result in rdr.deserialize() {
65 let record: CsvRecord = result.context("Failed to parse CSV record")?;
66
67 let ffn_hidden_dimensions = parse_ffn_hidden_dims(&record.ffn_hidden_dims)
68 .with_context(|| format!("Failed to parse FFN hidden dims for {}", record.name))?;
69
70 let layers_skipped = record
71 .layers_skipped
72 .as_ref()
73 .filter(|s| !s.is_empty())
74 .map(|s| parse_layers_skipped(s))
75 .transpose()
76 .with_context(|| format!("Failed to parse layers skipped for {}", record.name))?;
77
78 let slicing = Slice {
79 effective_params: record.effective_params,
80 ffn_hidden_dimensions,
81 layers_skipped,
82 };
83
84 slices.insert(record.name, slicing);
85 }
86
87 Ok(MatformerConfig { slices })
88 }
89
90 pub fn get_slicing(&self, name: &str) -> Option<&Slice> {
91 self.slices.get(name)
92 }
93}
94
95fn parse_ffn_hidden_dims(s: &str) -> Result<Vec<usize>> {
96 let s = s.trim();
97 if !s.starts_with('[') || !s.ends_with(']') {
98 anyhow::bail!("FFN hidden dims must be enclosed in brackets");
99 }
100
101 let inner = &s[1..s.len() - 1];
102 let parts: Vec<&str> = inner.split(',').collect();
103
104 let mut dimensions = Vec::with_capacity(parts.len());
105 for part in parts {
106 let dim = evaluate_expression(part.trim())
107 .with_context(|| format!("Failed to evaluate expression: {part}"))?;
108 dimensions.push(dim);
109 }
110
111 Ok(dimensions)
112}
113
114fn parse_layers_skipped(s: &str) -> Result<Vec<usize>> {
115 let s = s.trim();
116 if !s.starts_with('[') || !s.ends_with(']') {
117 anyhow::bail!("Layers skipped must be enclosed in brackets");
118 }
119
120 let inner = &s[1..s.len() - 1];
121 let parts: Vec<&str> = inner.split(',').collect();
122
123 let mut layers = Vec::with_capacity(parts.len());
124 for part in parts {
125 let layer = part
126 .trim()
127 .parse::<usize>()
128 .with_context(|| format!("Failed to parse layer number: {part}"))?;
129 layers.push(layer);
130 }
131
132 Ok(layers)
133}
134
135fn evaluate_expression(expr: &str) -> Result<usize> {
136 let expr = expr.trim();
137
138 if let Ok(num) = expr.replace('_', "").parse::<usize>() {
140 return Ok(num);
141 }
142
143 if expr.contains('*') {
145 let parts: Vec<&str> = expr.split('*').collect();
146 if parts.len() != 2 {
147 anyhow::bail!("Invalid multiplication expression: {}", expr);
148 }
149
150 let left = parts[0]
151 .trim()
152 .replace('_', "")
153 .parse::<usize>()
154 .with_context(|| format!("Failed to parse left operand: {}", parts[0]))?;
155 let right = parts[1]
156 .trim()
157 .parse::<usize>()
158 .with_context(|| format!("Failed to parse right operand: {}", parts[1]))?;
159
160 return Ok(left * right);
161 }
162
163 anyhow::bail!("Unsupported expression format: {}", expr)
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_evaluate_expression() {
        // Plain integers, underscore separators, and simple products.
        let cases: [(&str, usize); 4] = [
            ("2048", 2048),
            ("2_048", 2048),
            ("2_048 * 4", 8192),
            ("2048 * 8", 16384),
        ];
        for (expr, expected) in cases {
            let actual = evaluate_expression(expr).unwrap();
            assert_eq!(actual, expected, "expression {expr:?}");
        }
    }

    #[test]
    fn test_parse_ffn_hidden_dims() {
        let expected = vec![8192, 16384, 12288];
        let actual = parse_ffn_hidden_dims("[2_048 * 4, 2_048 * 8, 2048 * 6]").unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_parse_layers_skipped() {
        let expected: Vec<usize> = (20..=22).collect();
        assert_eq!(parse_layers_skipped("[20, 21, 22]").unwrap(), expected);
    }
}