mistralrs_core/pipeline/amoe.rs

1use std::{
2    any::Any,
3    fs::{self, File},
4    io::Read,
5    path::Path,
6    sync::Arc,
7};
8
9use base64::{engine::general_purpose, Engine};
10use candle_core::{DType, Device, Tensor};
11use candle_nn::{AdamW, Optimizer, ParamsAdamW};
12use either::Either;
13use image::DynamicImage;
14use indexmap::IndexMap;
15use indicatif::MultiProgress;
16use mistralrs_quant::IsqType;
17use rand::{rng, seq::SliceRandom};
18use rand_isaac::Isaac64Rng;
19use tracing::{info, warn};
20
21use crate::{
22    amoe::{AnyMoeConfig, AnyMoeTrainingInputRow, AnyMoeTrainingInputs, AnyMoeTrainingResult},
23    device_map::DeviceMapper,
24    get_mut_arcmutex,
25    prefix_cacher::PrefixCacheManagerV2,
26    sampler::Sampler,
27    sequence::{SeqStepType, Sequence, SequenceGroup, SequenceRecognizer},
28    utils::progress::NiceProgressBar,
29    DeviceMapSetting, Loader, ModelCategory, ModelKind, ModelPaths, PagedAttentionConfig, Pipeline,
30    Response, TokenSource, TryIntoDType,
31};
32
33use super::{
34    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
35    MetadataMixin, PreProcessingMixin,
36};
37
/// Loader that wraps another [`Loader`] and turns the model it loads into an AnyMoE pipeline.
///
/// After the target model is loaded, the training dataset at `path` is read and the AnyMoE
/// layers are injected (and, if configured, trained) by [`AnyMoePipeline::new`].
pub struct AnyMoeLoader {
    /// Loader for the base (target) model that AnyMoE layers are added to.
    pub target: Box<dyn Loader>,
    /// AnyMoE settings (learning rate, epochs, batch size, expert type, gate model id, ...).
    pub config: AnyMoeConfig,
    /// Path to the training dataset, parsed with `AnyMoeTrainingInputs::from_json`.
    pub path: String,
    /// First half of the `(prefix, mlp)` pair handed to `amoe_create_layers`; presumably the
    /// module-path prefix of the decoder layers — TODO confirm.
    pub prefix: String,
    /// Second half of the `(prefix, mlp)` pair handed to `amoe_create_layers`; presumably the
    /// name of the MLP submodule that gets replaced — TODO confirm.
    pub mlp: String,
    /// Model ids of the expert models.
    pub model_ids: Vec<String>,
    /// Layer indices forwarded to `amoe_create_layers`; presumably the layers that receive
    /// AnyMoE blocks — TODO confirm (the meaning of an empty list is not visible here).
    pub layers: Vec<usize>,
}
47
/// Pipeline wrapping a target pipeline whose layers have been augmented with AnyMoE
/// gating/expert layers. Nearly all calls are forwarded to the wrapped pipeline under its lock.
pub struct AnyMoePipeline {
    /// The wrapped pipeline that actually runs the model.
    target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// AnyMoE configuration used when injecting and training the layers.
    config: AnyMoeConfig,
}
52
53impl Loader for AnyMoeLoader {
54    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
55    fn load_model_from_hf(
56        &self,
57        revision: Option<String>,
58        token_source: TokenSource,
59        dtype: &dyn TryIntoDType,
60        device: &Device,
61        silent: bool,
62        mapper: DeviceMapSetting,
63        in_situ_quant: Option<IsqType>,
64        paged_attn_config: Option<PagedAttentionConfig>,
65    ) -> anyhow::Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
66        let paged_attn_config = if paged_attn_config.is_none() {
67            warn!("AnyMoE does not currently support PagedAttention, running without");
68            None
69        } else {
70            paged_attn_config
71        };
72
73        let target = self.target.load_model_from_hf(
74            revision.clone(),
75            token_source.clone(),
76            dtype,
77            device,
78            silent,
79            mapper.clone(),
80            in_situ_quant,
81            paged_attn_config,
82        )?;
83        Ok(Arc::new(tokio::sync::Mutex::new(AnyMoePipeline::new(
84            target,
85            self.config.clone(),
86            AnyMoeTrainingInputs::from_json(&self.path)?,
87            self.prefix.clone(),
88            self.mlp.clone(),
89            self.model_ids.clone(),
90            token_source,
91            revision,
92            self.layers.clone(),
93            silent,
94        )?)))
95    }
96
97    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
98    fn load_model_from_path(
99        &self,
100        paths: &Box<dyn ModelPaths>,
101        dtype: &dyn TryIntoDType,
102        device: &Device,
103        silent: bool,
104        mapper: DeviceMapSetting,
105        in_situ_quant: Option<IsqType>,
106        paged_attn_config: Option<PagedAttentionConfig>,
107    ) -> anyhow::Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
108        let paged_attn_config = if paged_attn_config.is_none() {
109            warn!("AnyMoE does not currently support PagedAttention, running without");
110            None
111        } else {
112            paged_attn_config
113        };
114
115        let target = self.target.load_model_from_path(
116            paths,
117            dtype,
118            device,
119            silent,
120            mapper.clone(),
121            in_situ_quant,
122            paged_attn_config,
123        )?;
124        Ok(Arc::new(tokio::sync::Mutex::new(AnyMoePipeline::new(
125            target,
126            self.config.clone(),
127            AnyMoeTrainingInputs::from_json(&self.path)?,
128            self.prefix.clone(),
129            self.mlp.clone(),
130            self.model_ids.clone(),
131            TokenSource::None,
132            None,
133            self.layers.clone(),
134            silent,
135        )?)))
136    }
137    fn get_id(&self) -> String {
138        format!("AnyMoE: tgt = `{}`", self.target.get_id(),)
139    }
140    fn get_kind(&self) -> ModelKind {
141        ModelKind::AnyMoe {
142            target: Box::new(self.target.get_kind()),
143        }
144    }
145}
146
147impl AnyMoePipeline {
148    #[allow(clippy::too_many_arguments)]
149    pub fn new(
150        target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
151        config: AnyMoeConfig,
152        inputs: AnyMoeTrainingInputs,
153        prefix: String,
154        mlp: String,
155        model_ids: Vec<String>,
156        token: TokenSource,
157        revision: Option<String>,
158        layers: Vec<usize>,
159        silent: bool,
160    ) -> anyhow::Result<Self> {
161        let this = Self { target, config };
162        info!("Loaded pretraining dataset of {} samples.", inputs.len());
163        match this.amoe_pre_train(
164            inputs,
165            (prefix, mlp),
166            model_ids,
167            token,
168            revision,
169            layers,
170            silent,
171        )? {
172            Some(AnyMoeTrainingResult { steps, final_loss }) => {
173                info!("Finished training in {steps} steps. Final losses per layer: {final_loss:?}")
174            }
175            None => {
176                info!("Not training gating layer, using trained gating layer specified in config")
177            }
178        }
179        Ok(this)
180    }
181}
182
183impl CacheManagerMixin for AnyMoePipeline {
184    fn cache(&self) -> &EitherCache {
185        unreachable!()
186    }
187    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
188        get_mut_arcmutex!(self.target).clone_in_cache(seqs)
189    }
190    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
191        get_mut_arcmutex!(self.target).clone_out_cache(seqs)
192    }
193    fn set_none_cache(
194        &self,
195        seqs: &mut [&mut Sequence],
196        reset_non_granular: bool,
197        modify_draft_cache: bool,
198        load_preallocated_cache: bool,
199    ) {
200        get_mut_arcmutex!(self.target).set_none_cache(
201            seqs,
202            reset_non_granular,
203            modify_draft_cache,
204            load_preallocated_cache,
205        )
206    }
207}
208
209impl IsqPipelineMixin for AnyMoePipeline {
210    fn re_isq_model(&mut self, dtype: IsqType) -> anyhow::Result<()> {
211        get_mut_arcmutex!(self.target).re_isq_model(dtype)
212    }
213}
214
215impl PreProcessingMixin for AnyMoePipeline {
216    fn get_chat_template(&self) -> Option<Arc<crate::ChatTemplate>> {
217        get_mut_arcmutex!(self.target).get_chat_template()
218    }
219    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
220        get_mut_arcmutex!(self.target).get_input_processor_config()
221    }
222    fn get_processor(&self) -> Arc<dyn super::Processor> {
223        get_mut_arcmutex!(self.target).get_processor()
224    }
225}
226
227impl MetadataMixin for AnyMoePipeline {
228    fn device(&self) -> Device {
229        get_mut_arcmutex!(self.target).device()
230    }
231    fn get_metadata(&self) -> Arc<super::GeneralMetadata> {
232        get_mut_arcmutex!(self.target).get_metadata()
233    }
234    fn name(&self) -> String {
235        get_mut_arcmutex!(self.target).name()
236    }
237    fn reset_non_granular_state(&self) {
238        get_mut_arcmutex!(self.target).reset_non_granular_state()
239    }
240    fn tokenizer(&self) -> Option<Arc<tokenizers::Tokenizer>> {
241        get_mut_arcmutex!(self.target).tokenizer()
242    }
243    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
244        None
245    }
246}
247
248#[async_trait::async_trait]
249impl Pipeline for AnyMoePipeline {
250    fn forward_inputs(
251        &mut self,
252        inputs: Box<dyn Any>,
253        _return_raw_logits: bool,
254    ) -> Result<ForwardInputsResult, candle_core::Error> {
255        get_mut_arcmutex!(self.target).forward_inputs(inputs, false)
256    }
257
258    async fn sample_causal_gen(
259        &self,
260        seqs: &mut [&mut Sequence],
261        logits: Vec<Tensor>,
262        prefix_cacher: &mut PrefixCacheManagerV2,
263        disable_eos_stop: bool,
264        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
265    ) -> Result<(), candle_core::Error> {
266        get_mut_arcmutex!(self.target)
267            .sample_causal_gen(seqs, logits, prefix_cacher, disable_eos_stop, rng)
268            .await
269    }
270
271    fn category(&self) -> ModelCategory {
272        get_mut_arcmutex!(self.target).category()
273    }
274}
275
276impl AnyMoePipelineMixin for AnyMoePipeline {
277    // Training result is None if inference
278    fn amoe_pre_train(
279        &self,
280        inputs: AnyMoeTrainingInputs,
281        (prefix, mlp): (String, String),
282        model_ids: Vec<String>,
283        token: TokenSource,
284        revision: Option<String>,
285        layers: Vec<usize>,
286        silent: bool,
287    ) -> anyhow::Result<Option<AnyMoeTrainingResult>, candle_core::Error> {
288        let mut target = get_mut_arcmutex!(self.target);
289        if !target.amoe_supported() {
290            candle_core::bail!("AnyMoE is not supported for this model.");
291        }
292
293        let device = target.device();
294        let processor = target.get_processor();
295        let inputs_processor = target.get_processor().inputs_processor();
296        let tokenizer = target.tokenizer();
297        let metadata = target.get_metadata().clone();
298        let input_processor_cfg = target.get_input_processor_config().clone();
299
300        let AnyMoeConfig {
301            hidden_size: _,
302            lr,
303            epochs,
304            batch_size,
305            expert_type,
306            gate_model_id,
307            training,
308            loss_csv_path,
309        } = self.config.clone();
310        let mut steps = 0;
311
312        info!("Expert type: {expert_type:?}");
313        info!("Expert model ids: {model_ids:?}");
314
315        // Inject the AnyMoE layers
316        target.amoe_create_layers(
317            model_ids,
318            &token,
319            revision,
320            &mlp.clone(),
321            self.config.clone(),
322            metadata.activation_dtype,
323            &device,
324            (prefix, mlp),
325            layers,
326            expert_type,
327            silent,
328            if !training {
329                gate_model_id.clone()
330            } else {
331                None
332            },
333        )?;
334        let layer_vars = target.amoe_layer_vars();
335
336        // If there are no trainable params, assume we got a gate model id so no training
337        if target.amoe_base_model_trainable_params() == 0 {
338            return Ok(None);
339        }
340
341        info!(
342            "{} gating layers, {} trainable parameters, lr = {lr}, {epochs} epochs, batch size = {batch_size}",
343            layer_vars.len(),
344            target.amoe_base_model_trainable_params()
345        );
346
347        let mut optimizers = layer_vars
348            .into_iter()
349            .map(|vars| {
350                AdamW::new(
351                    vars,
352                    ParamsAdamW {
353                        lr,
354                        beta1: 0.9,
355                        beta2: 0.999,
356                        eps: 1e-8,
357                        weight_decay: 0.0,
358                    },
359                )
360            })
361            .collect::<candle_core::Result<Vec<_>>>()?;
362
363        let mut rng = rng();
364        let mut samples = inputs.into_inner();
365
366        // Create several dummy objects for the sequences. No custom logits processors.
367        let (dummy_sender, _) = tokio::sync::mpsc::channel(10000);
368        let dummy_sampler = Sampler::new(
369            None,
370            0,
371            tokenizer.clone(),
372            None,
373            None,
374            None,
375            -1,
376            0.0,
377            0.0,
378            vec![],
379        )
380        .map_err(candle_core::Error::msg)?;
381
382        let dummy_group = Arc::new(tokio::sync::Mutex::new(SequenceGroup::new(
383            1, false, false, None,
384        )));
385
386        let mut latest_loss = vec![0.0; optimizers.len()];
387        let mut all_losses = Vec::new();
388
389        for _ in
390            NiceProgressBar::<_, 'g'>(0..epochs, "Training gating layers", &MultiProgress::new())
391        {
392            samples.as_mut_slice().shuffle(&mut rng);
393            for batch in samples.chunks(batch_size) {
394                steps += 1;
395
396                // === PREPARE INPUTS ==
397                let mut seqs = Vec::new();
398                for AnyMoeTrainingInputRow {
399                    prompt,
400                    expert: _,
401                    image_urls,
402                } in batch
403                {
404                    let tokens = processor
405                        .process(
406                            &*target,
407                            vec![IndexMap::from([
408                                ("role".to_string(), Either::Left("user".to_string())),
409                                ("content".to_string(), Either::Left(prompt.clone())),
410                            ])],
411                            true,
412                            true,
413                            Vec::new(),
414                        )
415                        .map_err(candle_core::Error::msg)?;
416                    let images = image_urls.as_ref().map(|urls| {
417                        urls.iter()
418                            .map(|url| -> anyhow::Result<DynamicImage> {
419                                let bytes = if url.contains("http") {
420                                    // Read from http
421                                    match reqwest::blocking::get(url.clone()) {
422                                        Ok(http_resp) => http_resp.bytes()?.to_vec(),
423                                        Err(e) => anyhow::bail!(e),
424                                    }
425                                } else if let Ok(mut f) = File::open(url) {
426                                    // Read from local file
427                                    let metadata = fs::metadata(url)?;
428                                    #[allow(clippy::cast_possible_truncation)]
429                                    let mut buffer = vec![0; metadata.len() as usize];
430                                    f.read_exact(&mut buffer)?;
431                                    buffer
432                                } else {
433                                    // Decode with base64
434                                    general_purpose::STANDARD.decode(url)?
435                                };
436                                Ok(image::load_from_memory(&bytes)?)
437                            })
438                            .collect::<anyhow::Result<Vec<_>>>()
439                    });
440                    let images = match images {
441                        Some(Ok(x)) => Some(x),
442                        Some(Err(e)) => {
443                            return anyhow::Result::Err(candle_core::Error::Msg(e.to_string()))
444                        }
445                        None => None,
446                    };
447                    seqs.push(new_dummy_seq(
448                        tokens,
449                        dummy_sender.clone(),
450                        dummy_sampler.clone(),
451                        dummy_group.clone(),
452                        images,
453                        target.get_metadata().eos_tok.clone(),
454                    ));
455                }
456                let mut input_seqs = seqs.iter_mut().collect::<Vec<_>>();
457
458                // Clear KV cache in prep for training
459                target.set_none_cache(&mut input_seqs, true, true, false);
460
461                let inputs = inputs_processor
462                    .process_inputs(
463                        tokenizer.clone(),
464                        &mut input_seqs,
465                        true, // Always a prompt
466                        metadata.is_xlora,
467                        &device,
468                        metadata.no_kv_cache,
469                        None,
470                        false,
471                        input_processor_cfg.clone(),
472                        None, // TODO: get block tables/handle it for PagedAttention
473                        None, // TODO: prompt chunking doesn't work.
474                        None,
475                    )
476                    .nth(0)
477                    .unwrap();
478
479                // === PREPARE AND RUN MODEL ==
480
481                // Run the model, ignoring the logits
482                let _ = target.forward_inputs(inputs.unwrap().inputs, false)?;
483
484                // Clear the KV cache
485                target.set_none_cache(&mut input_seqs, true, true, false);
486
487                // === BACKWARD STEP ==
488                #[allow(clippy::cast_possible_truncation)]
489                let labels = Tensor::from_vec(
490                    batch
491                        .iter()
492                        .map(
493                            |AnyMoeTrainingInputRow {
494                                 prompt: _,
495                                 expert,
496                                 image_urls: _,
497                             }| *expert as u32,
498                        )
499                        .collect::<Vec<_>>(),
500                    (batch.len(),),
501                    &device,
502                )?;
503
504                let cached = target.amoe_take_cached_gating_outputs();
505                for (layer, (optimizer, output)) in optimizers.iter_mut().zip(cached).enumerate() {
506                    let loss = candle_nn::loss::cross_entropy(
507                        &output,
508                        &labels.to_device(output.device())?,
509                    )?;
510                    let gradstore = loss.backward()?;
511                    optimizer.step(&gradstore)?;
512                    latest_loss[layer] = loss.to_dtype(DType::F32)?.to_scalar::<f32>()?;
513                }
514                all_losses.push(latest_loss.clone());
515            }
516        }
517
518        target.amoe_finish_training(gate_model_id)?;
519        assert_eq!(target.amoe_base_model_trainable_params(), 0);
520
521        if let Some(loss_csv_path) = loss_csv_path {
522            let path = Path::new(&loss_csv_path);
523            if path
524                .extension()
525                .is_none_or(|e| e.to_string_lossy() != *"csv")
526            {
527                candle_core::bail!("`loss_csv_path` must have an extension `csv`.");
528            }
529
530            let mut writer = csv::Writer::from_path(path).map_err(candle_core::Error::msg)?;
531
532            let mut header = vec![format!("Step")];
533            header.extend((0..all_losses[0].len()).map(|i| format!("Gating layer {i}")));
534            writer
535                .write_record(&header)
536                .map_err(candle_core::Error::msg)?;
537
538            for (i, row) in all_losses.into_iter().enumerate() {
539                let mut new_row = vec![format!("Step {i}")];
540                new_row.extend(row.iter().map(|x| format!("{x:.4}")));
541                writer
542                    .write_record(&new_row)
543                    .map_err(candle_core::Error::msg)?;
544            }
545
546            writer.flush().map_err(candle_core::Error::msg)?;
547        }
548
549        Ok(Some(AnyMoeTrainingResult {
550            steps,
551            final_loss: latest_loss,
552        }))
553    }
554}
555
/// Create a dummy sequence containing just the prompt. This is OK because we just want a sequence that
/// has no information other than the input tokens (and maybe images).
///
/// `(tokens, prompt)` is the tokenized and rendered prompt as produced by the processor.
/// `images` are optional images attached to the prompt, and `eos_toks` are the model's EOS
/// token ids. The sender, sampler and group are placeholders required to construct a
/// `Sequence`.
fn new_dummy_seq(
    (tokens, prompt): (Vec<u32>, String),
    dummy_sender: tokio::sync::mpsc::Sender<Response>,
    dummy_sampler: Sampler,
    dummy_group: Arc<tokio::sync::Mutex<SequenceGroup>>,
    images: Option<Vec<DynamicImage>>,
    eos_toks: Vec<u32>,
) -> Sequence {
    // NOTE(review): arguments are positional and must follow `Sequence::new_waiting`'s
    // parameter order (not visible here). The zero/`None`/`false`/empty values are
    // placeholders for a training-only sequence.
    Sequence::new_waiting(
        tokens,
        prompt,
        0,
        0,
        1,
        dummy_sender,
        dummy_sampler,
        vec![],
        vec![],
        None,
        false,
        false,
        dummy_group,
        0,
        0,
        // No constrained decoding for training prompts.
        SequenceRecognizer::None,
        None,
        None,
        images,
        None, // TODO incorrect for PagedAttention
        None,
        None,
        SeqStepType::PromptAndDecode,
        None,
        None,
        false,
        eos_toks,
    )
}
595}