use std::{
    collections::HashMap,
    fs,
    path::Path,
    sync::{Arc, RwLock},
};

use candle_core::{safetensors, DType, Device, Result, Tensor, Var, D};
use candle_nn::{Linear, ModuleT, VarMap};
use mistralrs_quant::{QuantMethod, ShardedSafeTensors, ShardedVarBuilder};
use serde::{Deserialize, Serialize};

mod inputs;
mod macros;
pub use inputs::{AnyMoeTrainingInputRow, AnyMoeTrainingInputs, AnyMoeTrainingResult};
use tracing::info;

use crate::{
    layers::{linear, Activation},
    ops::{TopKLastDimOp, TopKOutput},
    serde_default_fn,
};

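/// Hook trait a model implements to participate in AnyMoE. The provided methods
/// handle training bookkeeping over the model's MoE layers (collecting trainable
/// `Var`s, counting parameters, caching gating outputs, and saving trained gates);
/// models that actually support AnyMoE must override `create_anymoe_layers`,
/// `get_mlps`, `get_mlps_mut`, and `amoe_supported`.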
pub trait AnyMoeBaseModelMixin {
    fn get_vars(&self) -> Vec<Vec<Var>> {
        self.get_mlps()
            .iter()
            .filter(|mlp| mlp.is_moe_layer())
            .map(|mlp| mlp.get_vars())
            .collect::<Vec<_>>()
    }
    fn finish_training(&mut self, gate_model_id: Option<String>) -> Result<()> {
        let mut out = HashMap::new();
        for mlp in self
            .get_mlps_mut()
            .iter_mut()
            .filter(|mlp| mlp.is_moe_layer())
        {
            let out_accum = if gate_model_id.is_some() {
                Some(&mut out)
            } else {
                None
            };
            mlp.finish_training(out_accum);
        }
        if let Some(gate_model_id) = gate_model_id {
            if !Path::new(&gate_model_id).exists() {
                fs::create_dir_all(&gate_model_id)?;
            }
            let save_path = Path::new(&gate_model_id).join("gate.safetensors");
            safetensors::save(&out, &save_path)?;
            info!("Saved gating layers to `{}`", save_path.display());
        }
        Ok(())
    }
    fn trainable_params(&self) -> usize {
        self.get_mlps()
            .iter()
            .filter(|mlp| mlp.is_moe_layer())
            .map(|mlp| mlp.trainable_params())
            .sum()
    }
    fn take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.get_mlps_mut()
            .iter_mut()
            .filter(|mlp| mlp.is_moe_layer())
            .map(|mlp| mlp.take_cached_gating_output())
            .collect::<Vec<_>>()
    }

    #[allow(clippy::too_many_arguments)]
    fn create_anymoe_layers(
        &mut self,
        _additional_vbs: Vec<ShardedVarBuilder>,
        _config: AnyMoeConfig,
        (_prefix, _mlp): (String, String),
        _layers: Vec<usize>,
        _expert_type: AnyMoeExpertType,
        _gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        candle_core::bail!("Model does not support AnyMoE layers");
    }
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        panic!("Model does not support AnyMoE layers");
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        panic!("Model does not support AnyMoE layers");
    }
    fn amoe_supported(&self) -> bool {
        false
    }
}

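/// Object-safe abstraction over a model's MLP block, letting a plain MLP and the
/// `MoeMlp` wrapper below be used interchangeably by the model implementation.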
pub trait MlpLayer: Send + Sync + AnyMoeTrainableLayer {
    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>>;
    fn clone(&self) -> Box<dyn MlpLayer>;
    fn get_params(&self) -> &[usize];
    fn hidden_act(&self) -> Activation;
    fn is_moe_layer(&self) -> bool {
        false
    }
    fn new_added_delta(&self, _deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>>;
    fn dtype_device(&self) -> (DType, Device);
}

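/// Training-time hooks for layers that may own trainable AnyMoE parameters.
/// The defaults describe a layer with nothing to train.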
pub trait AnyMoeTrainableLayer {
    fn get_vars(&self) -> Vec<Var> {
        vec![]
    }
    fn finish_training(&mut self, _out: Option<&mut HashMap<String, Tensor>>) {}
    fn trainable_params(&self) -> usize {
        0
    }
    fn take_cached_gating_output(&mut self) -> Tensor {
        panic!("Gating output is not applicable to this layer.")
    }
}

serde_default_fn!(f64, default_lr, 1e-3);
serde_default_fn!(usize, default_epochs, 100);
serde_default_fn!(usize, default_bs, 4);
serde_default_fn!(bool, default_true, true);

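/// How each AnyMoE expert is derived from the base MLP: either a fully
/// fine-tuned copy, or a LoRA adapter over the listed target modules.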
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum AnyMoeExpertType {
    #[serde(rename = "fine_tuned")]
    FineTuned,
    #[serde(rename = "lora_adapter")]
    LoraAdapter {
        rank: usize,
        alpha: f64,
        target_modules: Vec<String>,
    },
}

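/// Configuration for building and training AnyMoE layers.
///
/// A minimal construction sketch; the field values below are illustrative only:
/// ```ignore
/// let config = AnyMoeConfig {
///     hidden_size: 4096,
///     lr: 1e-3,
///     epochs: 100,
///     batch_size: 4,
///     expert_type: AnyMoeExpertType::FineTuned,
///     gate_model_id: None,
///     training: true,
///     loss_csv_path: None,
/// };
/// ```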
#[derive(Serialize, Deserialize, Clone)]
pub struct AnyMoeConfig {
    pub hidden_size: usize,
    #[serde(default = "default_lr")]
    pub lr: f64,
    #[serde(default = "default_epochs")]
    pub epochs: usize,
    #[serde(default = "default_bs")]
    pub batch_size: usize,
    pub expert_type: AnyMoeExpertType,
    pub gate_model_id: Option<String>,
    #[serde(default = "default_true")]
    pub training: bool,
    pub loss_csv_path: Option<String>,
}

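/// A single linear gating layer mapping hidden states `[b, s, hidden_size]` to
/// per-expert scores `[b, s, n_experts]`, normalized with a softmax over the
/// last dimension.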
#[derive(Clone)]
pub struct MoeGate {
    lin: Linear,
}

impl ModuleT for MoeGate {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        let hidden_states = xs.apply(&self.lin)?;
        if train {
            // Training path: standard softmax over the expert dimension.
            candle_nn::ops::softmax(&hidden_states, D::Minus1)
        } else {
            // Inference path: fused last-dim softmax.
            candle_nn::ops::softmax_last_dim(&hidden_states)
        }
    }
}

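/// An MLP that routes inputs through one of several expert MLPs, selected by a
/// trained `MoeGate`. While `training` is true, the gating output is cached so it
/// can be retrieved with `take_cached_gating_output`.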
pub struct MoeMlp {
    experts: Vec<Box<dyn MlpLayer>>,
    gate: MoeGate,
    training: bool,
    vars: Vec<Var>,
    gating_output: Arc<RwLock<Option<Tensor>>>,
    layer_idx: usize,
}

impl MoeMlp {
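    /// Create a new `MoeMlp` for layer `layer`. If `gate_vb` is `Some`, trained gate
    /// weights are loaded from it (inference); otherwise fresh gate variables are
    /// created in a local `VarMap` for training.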
    pub fn new(
        experts: Vec<Box<dyn MlpLayer>>,
        config: AnyMoeConfig,
        dtype: DType,
        dev: &Device,
        layer: usize,
        gate_vb: Option<&ShardedVarBuilder>,
    ) -> Result<Self> {
        let n_experts = experts.len();
        let var_map = VarMap::new();

        let inference = gate_vb.is_some();
        let empty_map = ShardedSafeTensors::wrap(Box::new(var_map.clone()), dtype, dev.clone());
        let vb = gate_vb.unwrap_or(&empty_map);
        let vb = vb
            .pp("moe_gate")
            .pp(layer)
            .set_device(dev.clone())
            .set_dtype(dtype);

        let lin = linear(config.hidden_size, n_experts, vb)?;

        let vars = var_map.all_vars();
        if vars.is_empty() && !inference {
            candle_core::bail!("No vars to train in MoeMlp, perhaps there are no layers?");
        }
        Ok(Self {
            experts,
            gate: MoeGate { lin },
            training: true,
            vars,
            gating_output: Arc::new(RwLock::new(None)),
            layer_idx: layer,
        })
    }
}

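// Training hooks: once training finishes the gate weights are detached (frozen) and,
// when requested by the caller, exported so the mixin can save them to `gate.safetensors`.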
impl AnyMoeTrainableLayer for MoeMlp {
    fn finish_training(&mut self, out: Option<&mut HashMap<String, Tensor>>) {
        self.training = false;
        // Detach the gate parameters so they are no longer tracked as trainable vars.
        let w = self.gate.lin.weight().detach();
        let b = self.gate.lin.bias().map(|b| b.detach());
        self.gate = MoeGate {
            lin: Linear::new(w.clone(), b.clone()),
        };
        if let Some(out) = out {
            out.insert(format!("moe_gate.{}.weight", self.layer_idx), w);
            if let Some(b) = b {
                out.insert(format!("moe_gate.{}.bias", self.layer_idx), b);
            }
        }
    }
    fn trainable_params(&self) -> usize {
        let mut sum = 0;
        if self.gate.lin.weight().is_variable() {
            sum += self.gate.lin.weight().elem_count();
        }
        // Count the bias only if it exists and is still a trainable variable.
        if let Some(b) = self.gate.lin.bias() {
            if b.is_variable() {
                sum += b.elem_count();
            }
        }
        sum
    }
    fn get_vars(&self) -> Vec<Var> {
        self.vars.clone()
    }
    fn take_cached_gating_output(&mut self) -> Tensor {
        // Panics if `forward` has not yet cached a gating output in training mode.
        self.gating_output.read().unwrap().clone().unwrap()
    }
}

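// Routing strategy: gate scores are averaged over the sequence dimension and the
// top-1 expert is chosen per batch item, so each sequence is served by a single expert.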
impl MlpLayer for MoeMlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // xs: [b, s, h] -> gate scores: [b, s, n_experts]
        let gate = self.gate.forward_t(xs, self.training)?;
        // Average the scores over the sequence dimension: [b, n_experts]
        let gate = gate.mean(1)?;
        // Pick the top-1 expert per batch item.
        let TopKOutput { values: _, indices } = gate.topk(1)?;

        if self.training {
            // Cache the gating output so the trainer can compute a loss against it.
            *self.gating_output.write().unwrap() = Some(gate.clone());
        }

        let mut expert_outputs = Vec::new();
        for expert in &self.experts {
            expert_outputs.push(expert.forward(xs)?);
        }
        // Stack to [b, n_experts, s, h] and gather the selected expert's output.
        let stacked_outputs = Tensor::stack(&expert_outputs, 1)?;
        let (b, _e, s, h) = stacked_outputs.dims4()?;
        let indices = indices.reshape((b, 1, 1, 1))?.expand((b, 1, s, h))?;
        let gathered_outputs = stacked_outputs
            .contiguous()?
            .gather(&indices.contiguous()?, 1)?;
        gathered_outputs.squeeze(1)
    }

    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        if self.training {
            unreachable!("Should not be applying ISQ before training is complete.");
        }

        let mut accum = Vec::new();
        for expert in &mut self.experts {
            accum.extend(expert.get_isq_layers());
        }
        accum
    }

    fn clone(&self) -> Box<dyn MlpLayer> {
        let mut experts = Vec::new();
        for e in &self.experts {
            experts.push((*e).clone());
        }
        Box::new(Self {
            experts,
            gate: self.gate.clone(),
            training: self.training,
            vars: self.vars.clone(),
            gating_output: self.gating_output.clone(),
            layer_idx: self.layer_idx,
        })
    }

    fn get_params(&self) -> &[usize] {
        self.experts[0].get_params()
    }

    fn hidden_act(&self) -> Activation {
        self.experts[0].hidden_act()
    }

    fn is_moe_layer(&self) -> bool {
        true
    }

    fn new_added_delta(&self, _deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        unreachable!()
    }

    fn dtype_device(&self) -> (DType, Device) {
        (
            self.gate.lin.weight().dtype(),
            self.gate.lin.weight().device().clone(),
        )
    }
}