mistralrs_core/amoe/mod.rs
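//! AnyMoE support: build mixture-of-experts layers out of a base model's
//! existing MLP layers. A small trainable linear gate (`MoeGate`) scores the
//! experts; during training the per-layer gating softmax is cached for the
//! loss, and `MoeMlp::forward` routes each batch element to its top-1 expert.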

use std::{
    collections::HashMap,
    fs,
    path::Path,
    sync::{Arc, RwLock},
};

use candle_core::{safetensors, DType, Device, Result, Tensor, Var, D};
use candle_nn::{Linear, ModuleT, VarMap};
use mistralrs_quant::{QuantMethod, ShardedSafeTensors, ShardedVarBuilder};
use serde::{Deserialize, Serialize};

mod inputs;
mod macros;
pub use inputs::{AnyMoeTrainingInputRow, AnyMoeTrainingInputs, AnyMoeTrainingResult};
use tracing::info;

use crate::{
    layers::{linear, Activation},
    ops::{TopKLastDimOp, TopKOutput},
    serde_default_fn,
};

/// Implemented by the base model of an AnyMoe.
pub trait AnyMoeBaseModelMixin {
    fn get_vars(&self) -> Vec<Vec<Var>> {
        self.get_mlps()
            .iter()
            .filter(|mlp| mlp.is_moe_layer())
            .map(|mlp| mlp.get_vars())
            .collect::<Vec<_>>()
    }
    fn finish_training(&mut self, gate_model_id: Option<String>) -> Result<()> {
        let mut out = HashMap::new();
        for mlp in self
            .get_mlps_mut()
            .iter_mut()
            .filter(|mlp| mlp.is_moe_layer())
        {
            let out_accum = if gate_model_id.is_some() {
                Some(&mut out)
            } else {
                None
            };
            mlp.finish_training(out_accum);
        }
        if let Some(gate_model_id) = gate_model_id {
            if !Path::new(&gate_model_id).exists() {
                fs::create_dir_all(&gate_model_id)?;
            }
            let save_path = Path::new(&gate_model_id).join("gate.safetensors");
            safetensors::save(&out, &save_path)?;
            info!("Saved gating layers to `{}`", save_path.display());
        }
        Ok(())
    }
    fn trainable_params(&self) -> usize {
        self.get_mlps()
            .iter()
            .filter(|mlp| mlp.is_moe_layer())
            .map(|mlp| mlp.trainable_params())
            .sum()
    }
    fn take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.get_mlps_mut()
            .iter_mut()
            .filter(|mlp| mlp.is_moe_layer())
            .map(|mlp| mlp.take_cached_gating_output())
            .collect::<Vec<_>>()
    }

    #[allow(clippy::too_many_arguments)]
    fn create_anymoe_layers(
        &mut self,
        _additional_vbs: Vec<ShardedVarBuilder>,
        _config: AnyMoeConfig,
        (_prefix, _mlp): (String, String),
        _layers: Vec<usize>,
        _expert_type: AnyMoeExpertType,
        _gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        candle_core::bail!("Model does not support AnyMoE layers");
    }
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        panic!("Model does not support AnyMoE layers");
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        panic!("Model does not support AnyMoE layers");
    }
    fn amoe_supported(&self) -> bool {
        false
    }
}
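// A sketch of how a model opts in to AnyMoE (illustrative only; `MyModel` and
// its `layers`/`mlp` fields are hypothetical):
//
//     impl AnyMoeBaseModelMixin for MyModel {
//         fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
//             self.layers.iter().map(|l| &*l.mlp).collect()
//         }
//         fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
//             self.layers.iter_mut().map(|l| &mut l.mlp).collect()
//         }
//         fn amoe_supported(&self) -> bool {
//             true
//         }
//         // `create_anymoe_layers` would also be overridden to replace each
//         // selected MLP with a `MoeMlp` built from the expert weights.
//     }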

pub trait MlpLayer: Send + Sync + AnyMoeTrainableLayer {
    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>>;
    fn clone(&self) -> Box<dyn MlpLayer>;
    /// WARNING: The returned params are not a struct but are instead assumed to
    /// be correctly ordered for that model and its implementation details.
    fn get_params(&self) -> &[usize];
    fn hidden_act(&self) -> Activation;
    fn is_moe_layer(&self) -> bool {
        false
    }
    /// This is for LoRA experts and completes the merging process.
    /// WARNING: The deltas are not a struct but are instead assumed to
    /// be correctly ordered for that model and its implementation details.
    fn new_added_delta(&self, _deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>>;
    fn dtype_device(&self) -> (DType, Device);
}

pub trait AnyMoeTrainableLayer {
    fn get_vars(&self) -> Vec<Var> {
        vec![]
    }
    fn finish_training(&mut self, _out: Option<&mut HashMap<String, Tensor>>) {}
    fn trainable_params(&self) -> usize {
        0
    }
    fn take_cached_gating_output(&mut self) -> Tensor {
        panic!("Gating output is not applicable to this layer.")
    }
}
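// The no-op defaults above let ordinary (non-MoE) MLP layers implement
// `AnyMoeTrainableLayer` for free; `MoeMlp` below overrides all four methods.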

serde_default_fn!(f64, default_lr, 1e-3);
serde_default_fn!(usize, default_epochs, 100);
serde_default_fn!(usize, default_bs, 4);
serde_default_fn!(bool, default_true, true);

#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum AnyMoeExpertType {
    #[serde(rename = "fine_tuned")]
    FineTuned,
    #[serde(rename = "lora_adapter")]
    LoraAdapter {
        rank: usize,
        alpha: f64,
        target_modules: Vec<String>,
    },
}
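// With serde's default externally tagged representation and the renames above,
// an expert type deserializes from JSON shaped roughly like (illustrative values):
//
//     "fine_tuned"
//     {"lora_adapter": {"rank": 16, "alpha": 16.0, "target_modules": ["gate_proj"]}}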

#[derive(Serialize, Deserialize, Clone)]
pub struct AnyMoeConfig {
    pub hidden_size: usize,
    #[serde(default = "default_lr")]
    pub lr: f64,
    #[serde(default = "default_epochs")]
    pub epochs: usize,
    #[serde(default = "default_bs")]
    pub batch_size: usize,
    pub expert_type: AnyMoeExpertType,
    pub gate_model_id: Option<String>,
    #[serde(default = "default_true")]
    pub training: bool,
    /// If `training == true`, `loss_csv_path` has no effect and nothing is saved.
    /// Otherwise, a `.csv` loss file will be saved to this path.
    pub loss_csv_path: Option<String>,
}
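// A minimal construction sketch (illustrative values; the literals shown for
// `lr`, `epochs`, `batch_size`, and `training` mirror the serde defaults above):
//
//     let cfg = AnyMoeConfig {
//         hidden_size: 4096,
//         lr: 1e-3,
//         epochs: 100,
//         batch_size: 4,
//         expert_type: AnyMoeExpertType::FineTuned,
//         gate_model_id: None,
//         training: true,
//         loss_csv_path: None,
//     };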

#[derive(Clone)]
pub struct MoeGate {
    lin: Linear,
}

impl ModuleT for MoeGate {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        let hidden_states = xs.apply(&self.lin)?;
        if train {
            // Composed softmax, so the op stays on the autograd path during training.
            candle_nn::ops::softmax(&hidden_states, D::Minus1)
        } else {
            // Fused softmax kernel for inference.
            candle_nn::ops::softmax_last_dim(&hidden_states)
        }
    }
}

pub struct MoeMlp {
    experts: Vec<Box<dyn MlpLayer>>,
    gate: MoeGate,
    training: bool,
    vars: Vec<Var>,
    gating_output: Arc<RwLock<Option<Tensor>>>,
    layer_idx: usize,
}

impl MoeMlp {
    /// Create a new MoeMlp layer. By default this is in training mode.
    pub fn new(
        experts: Vec<Box<dyn MlpLayer>>,
        config: AnyMoeConfig,
        dtype: DType,
        dev: &Device,
        layer: usize,
        gate_vb: Option<&ShardedVarBuilder>,
    ) -> Result<Self> {
        let n_experts = experts.len();
        let var_map = VarMap::new();

        // If a gate VarBuilder is provided, pretrained gating weights are loaded
        // (inference); otherwise fresh trainable vars are created in `var_map`.
        let inference = gate_vb.is_some();
        let empty_map = ShardedSafeTensors::wrap(Box::new(var_map.clone()), dtype, dev.clone());
        let vb = gate_vb.unwrap_or(&empty_map);
        let vb = vb
            .pp("moe_gate")
            .pp(layer)
            .set_device(dev.clone())
            .set_dtype(dtype);

        let lin = linear(config.hidden_size, n_experts, vb)?;

        let vars = var_map.all_vars();
        if vars.is_empty() && !inference {
            candle_core::bail!("No vars to train in MoeMlp, perhaps there are no layers?");
        }
        Ok(Self {
            experts,
            gate: MoeGate { lin },
            training: true,
            vars,
            gating_output: Arc::new(RwLock::new(None)),
            layer_idx: layer,
        })
    }
}
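// Construction sketch (illustrative; `expert_a`/`expert_b` stand for any
// `Box<dyn MlpLayer>` values and `cfg` for an `AnyMoeConfig` like the one above):
//
//     let moe = MoeMlp::new(
//         vec![expert_a, expert_b],
//         cfg,
//         DType::BF16,
//         &Device::Cpu,
//         layer_idx,  // index of the decoder layer being converted
//         None,       // `None` => fresh trainable gate (training mode)
//     )?;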

impl AnyMoeTrainableLayer for MoeMlp {
    fn finish_training(&mut self, out: Option<&mut HashMap<String, Tensor>>) {
        self.training = false;
        // Freeze the gate by detaching its weights; optionally export them under
        // the same `moe_gate.{layer}` prefix used when reloading via `gate_vb`.
        let w = self.gate.lin.weight().detach();
        let b = self.gate.lin.bias().map(|b| b.detach());
        self.gate = MoeGate {
            lin: Linear::new(w.clone(), b.clone()),
        };
        if let Some(out) = out {
            out.insert(format!("moe_gate.{}.weight", self.layer_idx), w);
            if let Some(b) = b {
                out.insert(format!("moe_gate.{}.bias", self.layer_idx), b);
            }
        }
    }
    fn trainable_params(&self) -> usize {
        let mut sum = 0;
        if self.gate.lin.weight().is_variable() {
            sum += self.gate.lin.weight().elem_count();
        }
        // The gate is built with a bias, but avoid unwrapping in case it is absent.
        if let Some(b) = self.gate.lin.bias() {
            if b.is_variable() {
                sum += b.elem_count();
            }
        }
        sum
    }
    fn get_vars(&self) -> Vec<Var> {
        self.vars.clone()
    }
    fn take_cached_gating_output(&mut self) -> Tensor {
        self.gating_output.read().unwrap().clone().unwrap()
    }
}

impl MlpLayer for MoeMlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // ^ [b, s, h]
        let gate = self.gate.forward_t(xs, self.training)?;
        // ^ [b, s, n_e]
        // Mean across the sequence dimension
        let gate = gate.mean(1)?;
        // ^ [b, n_e]

        // Gate with topk 1 to get the highest ranked expert
        let TopKOutput { values: _, indices } = gate.topk(1)?;

        if self.training {
            *self.gating_output.write().unwrap() = Some(gate.clone());
        }

        let mut expert_outputs = Vec::new();
        for expert in &self.experts {
            expert_outputs.push(expert.forward(xs)?);
        }
        let stacked_outputs = Tensor::stack(&expert_outputs, 1)?;
        // ^ [b, n_e, s, h]
        let (b, _e, s, h) = stacked_outputs.dims4()?;
        // Select the top-1 expert's output for each batch element.
        let indices = indices.reshape((b, 1, 1, 1))?.expand((b, 1, s, h))?;
        let gathered_outputs = stacked_outputs
            .contiguous()?
            .gather(&indices.contiguous()?, 1)?;
        gathered_outputs.squeeze(1)
    }

    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        if self.training {
            unreachable!("Should not be applying ISQ before training is complete.");
        }

        let mut accum = Vec::new();
        for expert in &mut self.experts {
            accum.extend(expert.get_isq_layers());
        }
        accum
    }

    fn clone(&self) -> Box<dyn MlpLayer> {
        let mut experts = Vec::new();
        for e in &self.experts {
            experts.push((*e).clone());
        }
        Box::new(Self {
            experts,
            gate: self.gate.clone(),
            training: self.training,
            vars: self.vars.clone(),
            gating_output: self.gating_output.clone(),
            layer_idx: self.layer_idx,
        })
    }

    fn get_params(&self) -> &[usize] {
        self.experts[0].get_params()
    }

    fn hidden_act(&self) -> Activation {
        self.experts[0].hidden_act()
    }

    fn is_moe_layer(&self) -> bool {
        true
    }

    fn new_added_delta(&self, _deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        unreachable!()
    }

    fn dtype_device(&self) -> (DType, Device) {
        (
            self.gate.lin.weight().dtype(),
            self.gate.lin.weight().device().clone(),
        )
    }
}
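// End-to-end training workflow implied by the traits above (illustrative; the
// loss computation and optimizer are assumptions, not part of this module):
//
//     let vars = model.get_vars();  // one Vec<Var> of gate weights per MoE layer
//     for _epoch in 0..cfg.epochs {
//         // ...run batches through the model, then read the cached gating softmaxes:
//         let gating = model.take_cached_gating_outputs();  // one [b, n_e] tensor per layer
//         // ...compute a routing loss against expert labels and step an optimizer over `vars`
//     }
//     // Detaches the gates and writes `gate.safetensors` into the given directory:
//     model.finish_training(Some("trained_gates".to_string()))?;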