mistralrs_core/
matformer.rs

1use anyhow::{Context, Result};
2use serde::Deserialize;
3use std::collections::HashMap;
4use std::fs::File;
5use std::io::BufReader;
6use std::path::Path;
7use std::sync::Arc;
8
/// One named entry of a MatFormer slicing table (one row of the slicing CSV).
#[derive(Clone, Debug)]
pub struct Slice {
    /// Value of the `# Effective Params (B)` column for this row
    /// (the header suggests the unit is billions of parameters).
    pub effective_params: f64,
    /// FFN hidden sizes parsed from the `FFN Hidden Dims` column, in the order
    /// they are listed there (presumably one entry per layer — confirm against
    /// the model loader).
    pub ffn_hidden_dimensions: Vec<usize>,
    /// Layer numbers listed in the `Layers Skipped` column, or `None` when the
    /// row gave no value for that column.
    pub layers_skipped: Option<Vec<usize>>,
}
15
16#[derive(Debug)]
17pub struct MatformerConfig {
18    pub slices: HashMap<String, Slice>,
19}
20
21#[derive(Debug, Clone)]
22pub struct MatformerSliceConfig {
23    pub slice_name: String,
24    pub config: Arc<MatformerConfig>,
25}
26
27impl MatformerSliceConfig {
28    pub fn new(slice_name: String, config: Arc<MatformerConfig>) -> Self {
29        Self { slice_name, config }
30    }
31
32    pub fn get_slicing(&self) -> Option<&Slice> {
33        self.config.get_slicing(&self.slice_name)
34    }
35}
36
/// Raw, string-level view of one data row of the slicing CSV.
///
/// Columns are matched by the header names given in the `serde(rename)`
/// attributes below. The loader in `MatformerConfig` converts this into a
/// [`Slice`].
#[derive(Debug, Deserialize)]
struct CsvRecord {
    /// Slice identifier; used as the key of `MatformerConfig::slices`.
    name: String,
    /// Layer count from the CSV. It is never read after parsing, which is why
    /// `dead_code` is allowed; keeping the field still requires the column to
    /// deserialize as a `u32`.
    #[serde(rename = "# Layers")]
    #[allow(dead_code)]
    num_layers: u32,
    /// Effective parameter count (the header suggests billions).
    #[serde(rename = "# Effective Params (B)")]
    effective_params: f64,
    /// Reported MMLU accuracy, kept as free text. Never read after parsing,
    /// hence the `dead_code` allowance.
    #[serde(rename = "MMLU PT accuracy")]
    #[allow(dead_code)]
    mmlu_accuracy: String,
    /// Bracketed list of FFN sizes, e.g. `[2_048 * 4, 2_048 * 8]`; evaluated by
    /// `parse_ffn_hidden_dims`.
    #[serde(rename = "FFN Hidden Dims")]
    ffn_hidden_dims: String,
    /// Bracketed list of layer numbers, e.g. `[20, 21, 22]`. Optional: the
    /// loader treats an absent or empty cell as "no layers skipped".
    #[serde(rename = "Layers Skipped")]
    layers_skipped: Option<String>,
}
53
54impl MatformerConfig {
55    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
56        let file = File::open(&path).with_context(|| {
57            format!("Failed to open matformer config file: {:?}", path.as_ref())
58        })?;
59        let reader = BufReader::new(file);
60
61        let mut rdr = csv::Reader::from_reader(reader);
62        let mut slices = HashMap::new();
63
64        for result in rdr.deserialize() {
65            let record: CsvRecord = result.context("Failed to parse CSV record")?;
66
67            let ffn_hidden_dimensions = parse_ffn_hidden_dims(&record.ffn_hidden_dims)
68                .with_context(|| format!("Failed to parse FFN hidden dims for {}", record.name))?;
69
70            let layers_skipped = record
71                .layers_skipped
72                .as_ref()
73                .filter(|s| !s.is_empty())
74                .map(|s| parse_layers_skipped(s))
75                .transpose()
76                .with_context(|| format!("Failed to parse layers skipped for {}", record.name))?;
77
78            let slicing = Slice {
79                effective_params: record.effective_params,
80                ffn_hidden_dimensions,
81                layers_skipped,
82            };
83
84            slices.insert(record.name, slicing);
85        }
86
87        Ok(MatformerConfig { slices })
88    }
89
90    pub fn get_slicing(&self, name: &str) -> Option<&Slice> {
91        self.slices.get(name)
92    }
93}
94
95fn parse_ffn_hidden_dims(s: &str) -> Result<Vec<usize>> {
96    let s = s.trim();
97    if !s.starts_with('[') || !s.ends_with(']') {
98        anyhow::bail!("FFN hidden dims must be enclosed in brackets");
99    }
100
101    let inner = &s[1..s.len() - 1];
102    let parts: Vec<&str> = inner.split(',').collect();
103
104    let mut dimensions = Vec::with_capacity(parts.len());
105    for part in parts {
106        let dim = evaluate_expression(part.trim())
107            .with_context(|| format!("Failed to evaluate expression: {part}"))?;
108        dimensions.push(dim);
109    }
110
111    Ok(dimensions)
112}
113
114fn parse_layers_skipped(s: &str) -> Result<Vec<usize>> {
115    let s = s.trim();
116    if !s.starts_with('[') || !s.ends_with(']') {
117        anyhow::bail!("Layers skipped must be enclosed in brackets");
118    }
119
120    let inner = &s[1..s.len() - 1];
121    let parts: Vec<&str> = inner.split(',').collect();
122
123    let mut layers = Vec::with_capacity(parts.len());
124    for part in parts {
125        let layer = part
126            .trim()
127            .parse::<usize>()
128            .with_context(|| format!("Failed to parse layer number: {part}"))?;
129        layers.push(layer);
130    }
131
132    Ok(layers)
133}
134
135fn evaluate_expression(expr: &str) -> Result<usize> {
136    let expr = expr.trim();
137
138    // Handle simple number (with potential underscores)
139    if let Ok(num) = expr.replace('_', "").parse::<usize>() {
140        return Ok(num);
141    }
142
143    // Handle multiplication expressions like "2_048 * 4"
144    if expr.contains('*') {
145        let parts: Vec<&str> = expr.split('*').collect();
146        if parts.len() != 2 {
147            anyhow::bail!("Invalid multiplication expression: {}", expr);
148        }
149
150        let left = parts[0]
151            .trim()
152            .replace('_', "")
153            .parse::<usize>()
154            .with_context(|| format!("Failed to parse left operand: {}", parts[0]))?;
155        let right = parts[1]
156            .trim()
157            .parse::<usize>()
158            .with_context(|| format!("Failed to parse right operand: {}", parts[1]))?;
159
160        return Ok(left * right);
161    }
162
163    anyhow::bail!("Unsupported expression format: {}", expr)
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_evaluate_expression() {
        assert_eq!(evaluate_expression("2048").unwrap(), 2048);
        assert_eq!(evaluate_expression("2_048").unwrap(), 2048);
        assert_eq!(evaluate_expression("2_048 * 4").unwrap(), 8192);
        assert_eq!(evaluate_expression("2048 * 8").unwrap(), 16384);
        assert_eq!(evaluate_expression("  2_048  ").unwrap(), 2048);
    }

    #[test]
    fn test_evaluate_expression_rejects_invalid_input() {
        assert!(evaluate_expression("").is_err());
        assert!(evaluate_expression("abc").is_err());
        assert!(evaluate_expression("2048 * x").is_err());
        assert!(evaluate_expression("* 4").is_err());
    }

    #[test]
    fn test_parse_ffn_hidden_dims() {
        let dims = parse_ffn_hidden_dims("[2_048 * 4, 2_048 * 8, 2048 * 6]").unwrap();
        assert_eq!(dims, vec![8192, 16384, 12288]);

        // Surrounding whitespace is tolerated.
        let dims = parse_ffn_hidden_dims("  [2048]  ").unwrap();
        assert_eq!(dims, vec![2048]);
    }

    #[test]
    fn test_parse_ffn_hidden_dims_rejects_invalid_input() {
        assert!(parse_ffn_hidden_dims("2048").is_err());
        assert!(parse_ffn_hidden_dims("[2048").is_err());
        assert!(parse_ffn_hidden_dims("[2048, abc]").is_err());
    }

    #[test]
    fn test_parse_layers_skipped() {
        let layers = parse_layers_skipped("[20, 21, 22]").unwrap();
        assert_eq!(layers, vec![20, 21, 22]);
    }

    #[test]
    fn test_parse_layers_skipped_rejects_invalid_input() {
        assert!(parse_layers_skipped("20, 21").is_err());
        assert!(parse_layers_skipped("[20, x]").is_err());
        assert!(parse_layers_skipped("[-1]").is_err());
    }
}