use std::{
    any::Any,
    fs::{self, File},
    io::Read,
    path::Path,
    sync::Arc,
};

use base64::{engine::general_purpose, Engine};
use candle_core::{DType, Device, Tensor};
use candle_nn::{AdamW, Optimizer, ParamsAdamW};
use either::Either;
use image::DynamicImage;
use indexmap::IndexMap;
use mistralrs_quant::IsqType;
use rand::{rng, seq::SliceRandom};
use rand_isaac::Isaac64Rng;
use tracing::{info, warn};

use crate::{
    amoe::{AnyMoeConfig, AnyMoeTrainingInputRow, AnyMoeTrainingInputs, AnyMoeTrainingResult},
    device_map::DeviceMapper,
    get_mut_arcmutex,
    prefix_cacher::PrefixCacheManagerV2,
    sampler::Sampler,
    sequence::{SeqStepType, Sequence, SequenceGroup, SequenceRecognizer},
    utils::progress::{new_multi_progress, NiceProgressBar, ProgressScopeGuard},
    DeviceMapSetting, Loader, ModelCategory, ModelKind, ModelPaths, PagedAttentionConfig, Pipeline,
    Response, TokenSource, TryIntoDType,
};

use super::{
    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
    MetadataMixin, PreProcessingMixin,
};

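/// A [`Loader`] that wraps a `target` loader: it loads the underlying model and then turns it
/// into an AnyMoE pipeline, pretraining the per-layer gating networks on the JSON dataset at
/// `path`.
///
/// Minimal construction sketch; every value below is hypothetical and only illustrates how the
/// fields are used elsewhere in this file:
/// ```ignore
/// let loader = AnyMoeLoader {
///     target: base_loader,                  // any existing `Loader` implementation
///     config: anymoe_config,                // `AnyMoeConfig`: lr, epochs, batch_size, ...
///     path: "amoe_train.json".to_string(),  // read via `AnyMoeTrainingInputs::from_json`
///     prefix: "model.layers".to_string(),   // module prefix passed to `amoe_create_layers`
///     mlp: "mlp".to_string(),               // MLP submodule name passed alongside the prefix
///     model_ids: vec!["expert-a".to_string(), "expert-b".to_string()],
///     layers: vec![0, 1],                   // layer indices to convert
/// };
/// ```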
pub struct AnyMoeLoader {
    pub target: Box<dyn Loader>,
    pub config: AnyMoeConfig,
    pub path: String,
    pub prefix: String,
    pub mlp: String,
    pub model_ids: Vec<String>,
    pub layers: Vec<usize>,
}

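/// Pipeline that delegates inference and cache handling to the wrapped `target` pipeline; the
/// AnyMoE-specific work (expert creation and gating-layer pretraining) happens in
/// [`AnyMoePipeline::new`] via [`AnyMoePipelineMixin::amoe_pre_train`].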
pub struct AnyMoePipeline {
    target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    config: AnyMoeConfig,
}

impl Loader for AnyMoeLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> anyhow::Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paged_attn_config = if paged_attn_config.is_some() {
            warn!("AnyMoE does not currently support PagedAttention, running without");
            None
        } else {
            paged_attn_config
        };

        let target = self.target.load_model_from_hf(
            revision.clone(),
            token_source.clone(),
            dtype,
            device,
            silent,
            mapper.clone(),
            in_situ_quant,
            paged_attn_config,
        )?;
        Ok(Arc::new(tokio::sync::Mutex::new(AnyMoePipeline::new(
            target,
            self.config.clone(),
            AnyMoeTrainingInputs::from_json(&self.path)?,
            self.prefix.clone(),
            self.mlp.clone(),
            self.model_ids.clone(),
            token_source,
            revision,
            self.layers.clone(),
            silent,
        )?)))
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> anyhow::Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paged_attn_config = if paged_attn_config.is_some() {
            warn!("AnyMoE does not currently support PagedAttention, running without");
            None
        } else {
            paged_attn_config
        };

        let target = self.target.load_model_from_path(
            paths,
            dtype,
            device,
            silent,
            mapper.clone(),
            in_situ_quant,
            paged_attn_config,
        )?;
        Ok(Arc::new(tokio::sync::Mutex::new(AnyMoePipeline::new(
            target,
            self.config.clone(),
            AnyMoeTrainingInputs::from_json(&self.path)?,
            self.prefix.clone(),
            self.mlp.clone(),
            self.model_ids.clone(),
            TokenSource::None,
            None,
            self.layers.clone(),
            silent,
        )?)))
    }
    fn get_id(&self) -> String {
        format!("AnyMoE: tgt = `{}`", self.target.get_id())
    }
    fn get_kind(&self) -> ModelKind {
        ModelKind::AnyMoe {
            target: Box::new(self.target.get_kind()),
        }
    }
}

impl AnyMoePipeline {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        config: AnyMoeConfig,
        inputs: AnyMoeTrainingInputs,
        prefix: String,
        mlp: String,
        model_ids: Vec<String>,
        token: TokenSource,
        revision: Option<String>,
        layers: Vec<usize>,
        silent: bool,
    ) -> anyhow::Result<Self> {
        let this = Self { target, config };
        info!("Loaded pretraining dataset of {} samples.", inputs.len());
        match this.amoe_pre_train(
            inputs,
            (prefix, mlp),
            model_ids,
            token,
            revision,
            layers,
            silent,
        )? {
            Some(AnyMoeTrainingResult { steps, final_loss }) => {
                info!("Finished training in {steps} steps. Final losses per layer: {final_loss:?}")
            }
            None => {
                info!("Not training gating layer, using trained gating layer specified in config")
            }
        }
        Ok(this)
    }
}

impl CacheManagerMixin for AnyMoePipeline {
    fn cache(&self) -> &EitherCache {
        unreachable!()
    }
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        get_mut_arcmutex!(self.target).clone_in_cache(seqs)
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        get_mut_arcmutex!(self.target).clone_out_cache(seqs)
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        get_mut_arcmutex!(self.target).set_none_cache(
            seqs,
            reset_non_granular,
            modify_draft_cache,
            load_preallocated_cache,
        )
    }
}

impl IsqPipelineMixin for AnyMoePipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> anyhow::Result<()> {
        get_mut_arcmutex!(self.target).re_isq_model(dtype)
    }
}

impl PreProcessingMixin for AnyMoePipeline {
    fn get_chat_template(&self) -> Option<Arc<crate::ChatTemplate>> {
        get_mut_arcmutex!(self.target).get_chat_template()
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        get_mut_arcmutex!(self.target).get_input_processor_config()
    }
    fn get_processor(&self) -> Arc<dyn super::Processor> {
        get_mut_arcmutex!(self.target).get_processor()
    }
}

impl MetadataMixin for AnyMoePipeline {
    fn device(&self) -> Device {
        get_mut_arcmutex!(self.target).device()
    }
    fn get_metadata(&self) -> Arc<super::GeneralMetadata> {
        get_mut_arcmutex!(self.target).get_metadata()
    }
    fn name(&self) -> String {
        get_mut_arcmutex!(self.target).name()
    }
    fn reset_non_granular_state(&self) {
        get_mut_arcmutex!(self.target).reset_non_granular_state()
    }
    fn tokenizer(&self) -> Option<Arc<tokenizers::Tokenizer>> {
        get_mut_arcmutex!(self.target).tokenizer()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}

#[async_trait::async_trait]
impl Pipeline for AnyMoePipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        _return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        get_mut_arcmutex!(self.target).forward_inputs(inputs, false)
    }

    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        get_mut_arcmutex!(self.target)
            .sample_causal_gen(seqs, logits, prefix_cacher, disable_eos_stop, rng)
            .await
    }

    fn category(&self) -> ModelCategory {
        get_mut_arcmutex!(self.target).category()
    }
}

impl AnyMoePipelineMixin for AnyMoePipeline {
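    /// Pretrains the AnyMoE gating layers: create the expert layers, then for each epoch
    /// shuffle the dataset and, per batch, run the base model over dummy sequences so the
    /// gating outputs get cached, stepping one `AdamW` optimizer per gating layer against a
    /// cross-entropy loss on the expert labels. Returns `Ok(None)` when nothing is trainable
    /// (a pre-trained gating layer was supplied via the config instead).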
    fn amoe_pre_train(
        &self,
        inputs: AnyMoeTrainingInputs,
        (prefix, mlp): (String, String),
        model_ids: Vec<String>,
        token: TokenSource,
        revision: Option<String>,
        layers: Vec<usize>,
        silent: bool,
    ) -> anyhow::Result<Option<AnyMoeTrainingResult>, candle_core::Error> {
        let mut target = get_mut_arcmutex!(self.target);
        if !target.amoe_supported() {
            candle_core::bail!("AnyMoE is not supported for this model.");
        }

        let device = target.device();
        let processor = target.get_processor();
        let inputs_processor = target.get_processor().inputs_processor();
        let tokenizer = target.tokenizer();
        let metadata = target.get_metadata().clone();
        let input_processor_cfg = target.get_input_processor_config().clone();

        let AnyMoeConfig {
            hidden_size: _,
            lr,
            epochs,
            batch_size,
            expert_type,
            gate_model_id,
            training,
            loss_csv_path,
        } = self.config.clone();
        let mut steps = 0;

        info!("Expert type: {expert_type:?}");
        info!("Expert model ids: {model_ids:?}");

        target.amoe_create_layers(
            model_ids,
            &token,
            revision,
            &mlp.clone(),
            self.config.clone(),
            metadata.activation_dtype,
            &device,
            (prefix, mlp),
            layers,
            expert_type,
            silent,
            if !training {
                gate_model_id.clone()
            } else {
                None
            },
        )?;
        let layer_vars = target.amoe_layer_vars();

        if target.amoe_base_model_trainable_params() == 0 {
            return Ok(None);
        }

        info!(
            "{} gating layers, {} trainable parameters, lr = {lr}, {epochs} epochs, batch size = {batch_size}",
            layer_vars.len(),
            target.amoe_base_model_trainable_params()
        );

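        // One AdamW optimizer per gating layer; only `lr` comes from the config, the remaining
        // hyperparameters are fixed here (default betas/eps, no weight decay).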
        let mut optimizers = layer_vars
            .into_iter()
            .map(|vars| {
                AdamW::new(
                    vars,
                    ParamsAdamW {
                        lr,
                        beta1: 0.9,
                        beta2: 0.999,
                        eps: 1e-8,
                        weight_decay: 0.0,
                    },
                )
            })
            .collect::<candle_core::Result<Vec<_>>>()?;

        let mut rng = rng();
        let mut samples = inputs.into_inner();

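        // Dummy channel, sampler, and group: `Sequence` construction needs them, but nothing is
        // ever sampled or sent during pretraining.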
        let (dummy_sender, _) = tokio::sync::mpsc::channel(10000);
        let dummy_sampler = Sampler::new(
            None,
            0,
            tokenizer.clone(),
            None,
            None,
            None,
            None,
            -1,
            0.0,
            0.0,
            vec![],
        )
        .map_err(candle_core::Error::msg)?;

        let dummy_group = Arc::new(tokio::sync::Mutex::new(SequenceGroup::new(
            1, false, false, None,
        )));

        let mut latest_loss = vec![0.0; optimizers.len()];
        let mut all_losses = Vec::new();

        for _ in
            NiceProgressBar::<_, 'g'>(0..epochs, "Training gating layers", &new_multi_progress())
        {
            samples.as_mut_slice().shuffle(&mut rng);
            for batch in samples.chunks(batch_size) {
                steps += 1;

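                // Build one dummy sequence per training row, decoding any attached images
                // (http(s) URL, local file path, or base64 data).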
                let mut seqs = Vec::new();
                for AnyMoeTrainingInputRow {
                    prompt,
                    expert: _,
                    image_urls,
                } in batch
                {
                    let tokens = processor
                        .process(
                            &*target,
                            vec![IndexMap::from([
                                ("role".to_string(), Either::Left("user".to_string())),
                                ("content".to_string(), Either::Left(prompt.clone())),
                            ])],
                            true,
                            true,
                            None,
                            Vec::new(),
                        )
                        .map_err(candle_core::Error::msg)?;
                    let images = image_urls.as_ref().map(|urls| {
                        urls.iter()
                            .map(|url| -> anyhow::Result<DynamicImage> {
                                let bytes = if url.contains("http") {
                                    match reqwest::blocking::get(url.clone()) {
                                        Ok(http_resp) => http_resp.bytes()?.to_vec(),
                                        Err(e) => anyhow::bail!(e),
                                    }
                                } else if let Ok(mut f) = File::open(url) {
                                    let metadata = fs::metadata(url)?;
                                    #[allow(clippy::cast_possible_truncation)]
                                    let mut buffer = vec![0; metadata.len() as usize];
                                    f.read_exact(&mut buffer)?;
                                    buffer
                                } else {
                                    general_purpose::STANDARD.decode(url)?
                                };
                                Ok(image::load_from_memory(&bytes)?)
                            })
                            .collect::<anyhow::Result<Vec<_>>>()
                    });
                    let images = match images {
                        Some(Ok(x)) => Some(x),
                        Some(Err(e)) => {
                            return anyhow::Result::Err(candle_core::Error::Msg(e.to_string()))
                        }
                        None => None,
                    };
                    seqs.push(new_dummy_seq(
                        tokens,
                        dummy_sender.clone(),
                        dummy_sampler.clone(),
                        dummy_group.clone(),
                        images,
                        target.get_metadata().eos_tok.clone(),
                    ));
                }
                let mut input_seqs = seqs.iter_mut().collect::<Vec<_>>();

                target.set_none_cache(&mut input_seqs, true, true, false);

                let inputs = inputs_processor.process_inputs(
                    tokenizer.clone(),
                    &mut input_seqs,
                    true,
                    metadata.is_xlora,
                    &device,
                    metadata.no_kv_cache,
                    None,
                    false,
                    input_processor_cfg.clone(),
                    None,
                    None,
                );

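                // Run the base model over the batch; the logits are discarded, the forward pass
                // only serves to populate the cached gating outputs consumed below.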
                let _ = target.forward_inputs(inputs.unwrap().inputs, false)?;

                target.set_none_cache(&mut input_seqs, true, true, false);

                #[allow(clippy::cast_possible_truncation)]
                let labels = Tensor::from_vec(
                    batch
                        .iter()
                        .map(
                            |AnyMoeTrainingInputRow {
                                 prompt: _,
                                 expert,
                                 image_urls: _,
                             }| *expert as u32,
                        )
                        .collect::<Vec<_>>(),
                    (batch.len(),),
                    &device,
                )?;

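                // Take the gating logits cached during the forward pass and step each layer's
                // optimizer on a cross-entropy loss against the expert labels.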
                let cached = target.amoe_take_cached_gating_outputs();
                for (layer, (optimizer, output)) in optimizers.iter_mut().zip(cached).enumerate() {
                    let loss = candle_nn::loss::cross_entropy(
                        &output,
                        &labels.to_device(output.device())?,
                    )?;
                    let gradstore = loss.backward()?;
                    optimizer.step(&gradstore)?;
                    latest_loss[layer] = loss.to_dtype(DType::F32)?.to_scalar::<f32>()?;
                }
                all_losses.push(latest_loss.clone());
            }
        }

        target.amoe_finish_training(gate_model_id)?;
        assert_eq!(target.amoe_base_model_trainable_params(), 0);

        if let Some(loss_csv_path) = loss_csv_path {
            let path = Path::new(&loss_csv_path);
            if path
                .extension()
                .is_none_or(|e| e.to_string_lossy() != *"csv")
            {
                candle_core::bail!("`loss_csv_path` must have an extension `csv`.");
            }

            let mut writer = csv::Writer::from_path(path).map_err(candle_core::Error::msg)?;

            let mut header = vec![format!("Step")];
            header.extend((0..all_losses[0].len()).map(|i| format!("Gating layer {i}")));
            writer
                .write_record(&header)
                .map_err(candle_core::Error::msg)?;

            for (i, row) in all_losses.into_iter().enumerate() {
                let mut new_row = vec![format!("Step {i}")];
                new_row.extend(row.iter().map(|x| format!("{x:.4}")));
                writer
                    .write_record(&new_row)
                    .map_err(candle_core::Error::msg)?;
            }

            writer.flush().map_err(candle_core::Error::msg)?;
        }

        Ok(Some(AnyMoeTrainingResult {
            steps,
            final_loss: latest_loss,
        }))
    }
}

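/// Builds a throwaway [`Sequence`] used only to drive the input processor and forward pass
/// during gating-layer pretraining; the response channel, sampler, and group are dummies whose
/// output is never consumed.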
fn new_dummy_seq(
    (tokens, prompt): (Vec<u32>, String),
    dummy_sender: tokio::sync::mpsc::Sender<Response>,
    dummy_sampler: Sampler,
    dummy_group: Arc<tokio::sync::Mutex<SequenceGroup>>,
    images: Option<Vec<DynamicImage>>,
    eos_toks: Vec<u32>,
) -> Sequence {
    Sequence::new_waiting(
        tokens,
        prompt,
        0,
        0,
        1,
        dummy_sender,
        dummy_sampler,
        vec![],
        vec![],
        None,
        false,
        false,
        dummy_group,
        0,
        0,
        SequenceRecognizer::None,
        None,
        None,
        images,
        None,
        None,
        None,
        None,
        SeqStepType::PromptAndDecode,
        None,
        None,
        false,
        eos_toks,
    )
}