use std::{
    any::Any,
    fs::{self, File},
    io::Read,
    path::Path,
    sync::Arc,
};

use base64::{engine::general_purpose, Engine};
use candle_core::{DType, Device, Tensor};
use candle_nn::{AdamW, Optimizer, ParamsAdamW};
use either::Either;
use image::DynamicImage;
use indexmap::IndexMap;
use indicatif::MultiProgress;
use mistralrs_quant::IsqType;
use rand::{rng, seq::SliceRandom};
use rand_isaac::Isaac64Rng;
use tracing::{info, warn};

use crate::{
    amoe::{AnyMoeConfig, AnyMoeTrainingInputRow, AnyMoeTrainingInputs, AnyMoeTrainingResult},
    device_map::DeviceMapper,
    get_mut_arcmutex,
    prefix_cacher::PrefixCacheManagerV2,
    sampler::Sampler,
    sequence::{SeqStepType, Sequence, SequenceGroup, SequenceRecognizer},
    utils::progress::NiceProgressBar,
    DeviceMapSetting, Loader, ModelCategory, ModelKind, ModelPaths, PagedAttentionConfig, Pipeline,
    Response, TokenSource, TryIntoDType,
};

use super::{
    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
    MetadataMixin, PreProcessingMixin,
};

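/// A [`Loader`] wrapper that loads a target model and converts it into an AnyMoE model:
/// `model_ids` name the expert models, `layers` selects the layers to convert, and `path`
/// points at the JSON pretraining dataset.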
pub struct AnyMoeLoader {
    pub target: Box<dyn Loader>,
    pub config: AnyMoeConfig,
    pub path: String,
    pub prefix: String,
    pub mlp: String,
    pub model_ids: Vec<String>,
    pub layers: Vec<usize>,
}

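/// A [`Pipeline`] wrapping the AnyMoE-converted target pipeline; gating layers are
/// pretrained at construction time and all trait methods delegate to the wrapped target.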
pub struct AnyMoePipeline {
    target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    config: AnyMoeConfig,
}

impl Loader for AnyMoeLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> anyhow::Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
        let paged_attn_config = if paged_attn_config.is_none() {
            warn!("AnyMoE does not currently support PagedAttention, running without");
            None
        } else {
            paged_attn_config
        };

        let target = self.target.load_model_from_hf(
            revision.clone(),
            token_source.clone(),
            dtype,
            device,
            silent,
            mapper.clone(),
            in_situ_quant,
            paged_attn_config,
        )?;
        Ok(Arc::new(tokio::sync::Mutex::new(AnyMoePipeline::new(
            target,
            self.config.clone(),
            AnyMoeTrainingInputs::from_json(&self.path)?,
            self.prefix.clone(),
            self.mlp.clone(),
            self.model_ids.clone(),
            token_source,
            revision,
            self.layers.clone(),
            silent,
        )?)))
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> anyhow::Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
        let paged_attn_config = if paged_attn_config.is_none() {
            warn!("AnyMoE does not currently support PagedAttention, running without");
            None
        } else {
            paged_attn_config
        };

        let target = self.target.load_model_from_path(
            paths,
            dtype,
            device,
            silent,
            mapper.clone(),
            in_situ_quant,
            paged_attn_config,
        )?;
        Ok(Arc::new(tokio::sync::Mutex::new(AnyMoePipeline::new(
            target,
            self.config.clone(),
            AnyMoeTrainingInputs::from_json(&self.path)?,
            self.prefix.clone(),
            self.mlp.clone(),
            self.model_ids.clone(),
            TokenSource::None,
            None,
            self.layers.clone(),
            silent,
        )?)))
    }
    fn get_id(&self) -> String {
        format!("AnyMoE: tgt = `{}`", self.target.get_id(),)
    }
    fn get_kind(&self) -> ModelKind {
        ModelKind::AnyMoe {
            target: Box::new(self.target.get_kind()),
        }
    }
}

impl AnyMoePipeline {
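    /// Create the pipeline and immediately pretrain the gating layers on `inputs`,
    /// unless the config specifies an already-trained gating layer.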
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        config: AnyMoeConfig,
        inputs: AnyMoeTrainingInputs,
        prefix: String,
        mlp: String,
        model_ids: Vec<String>,
        token: TokenSource,
        revision: Option<String>,
        layers: Vec<usize>,
        silent: bool,
    ) -> anyhow::Result<Self> {
        let this = Self { target, config };
        info!("Loaded pretraining dataset of {} samples.", inputs.len());
        match this.amoe_pre_train(
            inputs,
            (prefix, mlp),
            model_ids,
            token,
            revision,
            layers,
            silent,
        )? {
            Some(AnyMoeTrainingResult { steps, final_loss }) => {
                info!("Finished training in {steps} steps. Final losses per layer: {final_loss:?}")
            }
            None => {
                info!("Not training gating layer, using trained gating layer specified in config")
            }
        }
        Ok(this)
    }
}

impl CacheManagerMixin for AnyMoePipeline {
    fn cache(&self) -> &EitherCache {
        unreachable!()
    }
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        get_mut_arcmutex!(self.target).clone_in_cache(seqs)
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        get_mut_arcmutex!(self.target).clone_out_cache(seqs)
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        get_mut_arcmutex!(self.target).set_none_cache(
            seqs,
            reset_non_granular,
            modify_draft_cache,
            load_preallocated_cache,
        )
    }
}

impl IsqPipelineMixin for AnyMoePipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> anyhow::Result<()> {
        get_mut_arcmutex!(self.target).re_isq_model(dtype)
    }
}

impl PreProcessingMixin for AnyMoePipeline {
    fn get_chat_template(&self) -> Option<Arc<crate::ChatTemplate>> {
        get_mut_arcmutex!(self.target).get_chat_template()
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        get_mut_arcmutex!(self.target).get_input_processor_config()
    }
    fn get_processor(&self) -> Arc<dyn super::Processor> {
        get_mut_arcmutex!(self.target).get_processor()
    }
}

impl MetadataMixin for AnyMoePipeline {
    fn device(&self) -> Device {
        get_mut_arcmutex!(self.target).device()
    }
    fn get_metadata(&self) -> Arc<super::GeneralMetadata> {
        get_mut_arcmutex!(self.target).get_metadata()
    }
    fn name(&self) -> String {
        get_mut_arcmutex!(self.target).name()
    }
    fn reset_non_granular_state(&self) {
        get_mut_arcmutex!(self.target).reset_non_granular_state()
    }
    fn tokenizer(&self) -> Option<Arc<tokenizers::Tokenizer>> {
        get_mut_arcmutex!(self.target).tokenizer()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}

#[async_trait::async_trait]
impl Pipeline for AnyMoePipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        _return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        get_mut_arcmutex!(self.target).forward_inputs(inputs, false)
    }

    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        get_mut_arcmutex!(self.target)
            .sample_causal_gen(seqs, logits, prefix_cacher, disable_eos_stop, rng)
            .await
    }

    fn category(&self) -> ModelCategory {
        get_mut_arcmutex!(self.target).category()
    }
}

impl AnyMoePipelineMixin for AnyMoePipeline {
    fn amoe_pre_train(
        &self,
        inputs: AnyMoeTrainingInputs,
        (prefix, mlp): (String, String),
        model_ids: Vec<String>,
        token: TokenSource,
        revision: Option<String>,
        layers: Vec<usize>,
        silent: bool,
    ) -> anyhow::Result<Option<AnyMoeTrainingResult>, candle_core::Error> {
        let mut target = get_mut_arcmutex!(self.target);
        if !target.amoe_supported() {
            candle_core::bail!("AnyMoE is not supported for this model.");
        }

        let device = target.device();
        let processor = target.get_processor();
        let inputs_processor = target.get_processor().inputs_processor();
        let tokenizer = target.tokenizer();
        let metadata = target.get_metadata().clone();
        let input_processor_cfg = target.get_input_processor_config().clone();

        let AnyMoeConfig {
            hidden_size: _,
            lr,
            epochs,
            batch_size,
            expert_type,
            gate_model_id,
            training,
            loss_csv_path,
        } = self.config.clone();
        let mut steps = 0;

        info!("Expert type: {expert_type:?}");
        info!("Expert model ids: {model_ids:?}");

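        // Convert the target model: create the expert and gating layers. The pretrained
        // gate model id is only forwarded when `training` is disabled.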
        target.amoe_create_layers(
            model_ids,
            &token,
            revision,
            &mlp.clone(),
            self.config.clone(),
            metadata.activation_dtype,
            &device,
            (prefix, mlp),
            layers,
            expert_type,
            silent,
            if !training {
                gate_model_id.clone()
            } else {
                None
            },
        )?;
        let layer_vars = target.amoe_layer_vars();

        if target.amoe_base_model_trainable_params() == 0 {
            return Ok(None);
        }

        info!(
            "{} gating layers, {} trainable parameters, lr = {lr}, {epochs} epochs, batch size = {batch_size}",
            layer_vars.len(),
            target.amoe_base_model_trainable_params()
        );

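        // One AdamW optimizer per gating layer, each owning that layer's trainable variables.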
        let mut optimizers = layer_vars
            .into_iter()
            .map(|vars| {
                AdamW::new(
                    vars,
                    ParamsAdamW {
                        lr,
                        beta1: 0.9,
                        beta2: 0.999,
                        eps: 1e-8,
                        weight_decay: 0.0,
                    },
                )
            })
            .collect::<candle_core::Result<Vec<_>>>()?;

        let mut rng = rng();
        let mut samples = inputs.into_inner();

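        // Dummy response channel, sampler, and sequence group: required by the `Sequence`
        // constructor, but unused since no decoding happens during pretraining.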
        let (dummy_sender, _) = tokio::sync::mpsc::channel(10000);
        let dummy_sampler = Sampler::new(
            None,
            0,
            tokenizer.clone(),
            None,
            None,
            None,
            -1,
            0.0,
            0.0,
            vec![],
        )
        .map_err(candle_core::Error::msg)?;

        let dummy_group = Arc::new(tokio::sync::Mutex::new(SequenceGroup::new(
            1, false, false, None,
        )));

        let mut latest_loss = vec![0.0; optimizers.len()];
        let mut all_losses = Vec::new();

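        // Training loop: shuffle the dataset every epoch and take one optimizer step per batch.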
        for _ in
            NiceProgressBar::<_, 'g'>(0..epochs, "Training gating layers", &MultiProgress::new())
        {
            samples.as_mut_slice().shuffle(&mut rng);
            for batch in samples.chunks(batch_size) {
                steps += 1;

                let mut seqs = Vec::new();
                for AnyMoeTrainingInputRow {
                    prompt,
                    expert: _,
                    image_urls,
                } in batch
                {
                    let tokens = processor
                        .process(
                            &*target,
                            vec![IndexMap::from([
                                ("role".to_string(), Either::Left("user".to_string())),
                                ("content".to_string(), Either::Left(prompt.clone())),
                            ])],
                            true,
                            true,
                            Vec::new(),
                        )
                        .map_err(candle_core::Error::msg)?;
                    let images = image_urls.as_ref().map(|urls| {
                        urls.iter()
                            .map(|url| -> anyhow::Result<DynamicImage> {
                                let bytes = if url.contains("http") {
                                    match reqwest::blocking::get(url.clone()) {
                                        Ok(http_resp) => http_resp.bytes()?.to_vec(),
                                        Err(e) => anyhow::bail!(e),
                                    }
                                } else if let Ok(mut f) = File::open(url) {
                                    let metadata = fs::metadata(url)?;
                                    #[allow(clippy::cast_possible_truncation)]
                                    let mut buffer = vec![0; metadata.len() as usize];
                                    f.read_exact(&mut buffer)?;
                                    buffer
                                } else {
                                    general_purpose::STANDARD.decode(url)?
                                };
                                Ok(image::load_from_memory(&bytes)?)
                            })
                            .collect::<anyhow::Result<Vec<_>>>()
                    });
                    let images = match images {
                        Some(Ok(x)) => Some(x),
                        Some(Err(e)) => {
                            return anyhow::Result::Err(candle_core::Error::Msg(e.to_string()))
                        }
                        None => None,
                    };
                    seqs.push(new_dummy_seq(
                        tokens,
                        dummy_sender.clone(),
                        dummy_sampler.clone(),
                        dummy_group.clone(),
                        images,
                        target.get_metadata().eos_tok.clone(),
                    ));
                }
                let mut input_seqs = seqs.iter_mut().collect::<Vec<_>>();

                target.set_none_cache(&mut input_seqs, true, true, false);

                let inputs = inputs_processor
                    .process_inputs(
                        tokenizer.clone(),
                        &mut input_seqs,
                        true,
                        metadata.is_xlora,
                        &device,
                        metadata.no_kv_cache,
                        None,
                        false,
                        input_processor_cfg.clone(),
                        None,
                        None,
                        None,
                    )
                    .nth(0)
                    .unwrap();

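                // Forward pass only: this populates the cached gating outputs consumed below.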
                let _ = target.forward_inputs(inputs.unwrap().inputs, false)?;

                target.set_none_cache(&mut input_seqs, true, true, false);

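                // Class labels: the index of the correct expert for each row of the batch.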
                #[allow(clippy::cast_possible_truncation)]
                let labels = Tensor::from_vec(
                    batch
                        .iter()
                        .map(
                            |AnyMoeTrainingInputRow {
                                 prompt: _,
                                 expert,
                                 image_urls: _,
                             }| *expert as u32,
                        )
                        .collect::<Vec<_>>(),
                    (batch.len(),),
                    &device,
                )?;

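                // Cross-entropy between each gating layer's cached output and the expert
                // labels, followed by an optimizer step for that layer.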
                let cached = target.amoe_take_cached_gating_outputs();
                for (layer, (optimizer, output)) in optimizers.iter_mut().zip(cached).enumerate() {
                    let loss = candle_nn::loss::cross_entropy(
                        &output,
                        &labels.to_device(output.device())?,
                    )?;
                    let gradstore = loss.backward()?;
                    optimizer.step(&gradstore)?;
                    latest_loss[layer] = loss.to_dtype(DType::F32)?.to_scalar::<f32>()?;
                }
                all_losses.push(latest_loss.clone());
            }
        }

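        // Finalize the gating layers; afterwards no parameters may remain trainable.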
        target.amoe_finish_training(gate_model_id)?;
        assert_eq!(target.amoe_base_model_trainable_params(), 0);

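        // Optionally dump the per-step loss of every gating layer to a CSV file.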
        if let Some(loss_csv_path) = loss_csv_path {
            let path = Path::new(&loss_csv_path);
            if path
                .extension()
                .is_none_or(|e| e.to_string_lossy() != *"csv")
            {
                candle_core::bail!("`loss_csv_path` must have an extension `csv`.");
            }

            let mut writer = csv::Writer::from_path(path).map_err(candle_core::Error::msg)?;

            let mut header = vec![format!("Step")];
            header.extend((0..all_losses[0].len()).map(|i| format!("Gating layer {i}")));
            writer
                .write_record(&header)
                .map_err(candle_core::Error::msg)?;

            for (i, row) in all_losses.into_iter().enumerate() {
                let mut new_row = vec![format!("Step {i}")];
                new_row.extend(row.iter().map(|x| format!("{x:.4}")));
                writer
                    .write_record(&new_row)
                    .map_err(candle_core::Error::msg)?;
            }

            writer.flush().map_err(candle_core::Error::msg)?;
        }

        Ok(Some(AnyMoeTrainingResult {
            steps,
            final_loss: latest_loss,
        }))
    }
}

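/// Build a minimal [`Sequence`] for pretraining: real tokens and images, but dummy
/// response channel, sampler, and group, since nothing is ever decoded or returned.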
fn new_dummy_seq(
    (tokens, prompt): (Vec<u32>, String),
    dummy_sender: tokio::sync::mpsc::Sender<Response>,
    dummy_sampler: Sampler,
    dummy_group: Arc<tokio::sync::Mutex<SequenceGroup>>,
    images: Option<Vec<DynamicImage>>,
    eos_toks: Vec<u32>,
) -> Sequence {
    Sequence::new_waiting(
        tokens,
        prompt,
        0,
        0,
        1,
        dummy_sender,
        dummy_sampler,
        vec![],
        vec![],
        None,
        false,
        false,
        dummy_group,
        0,
        0,
        SequenceRecognizer::None,
        None,
        None,
        images,
        None,
        None,
        None,
        SeqStepType::PromptAndDecode,
        None,
        None,
        false,
        eos_toks,
    )
}