mistralrs_core/pipeline/
amoe.rs

1use std::{
2    any::Any,
3    fs::{self, File},
4    io::Read,
5    path::Path,
6    sync::Arc,
7};
8
9use base64::{engine::general_purpose, Engine};
10use candle_core::{DType, Device, Tensor};
11use candle_nn::{AdamW, Optimizer, ParamsAdamW};
12use either::Either;
13use image::DynamicImage;
14use indexmap::IndexMap;
15use indicatif::MultiProgress;
16use mistralrs_quant::IsqType;
17use rand::{rng, seq::SliceRandom};
18use rand_isaac::Isaac64Rng;
19use tracing::{info, warn};
20
21use crate::{
22    amoe::{AnyMoeConfig, AnyMoeTrainingInputRow, AnyMoeTrainingInputs, AnyMoeTrainingResult},
23    device_map::DeviceMapper,
24    get_mut_arcmutex,
25    prefix_cacher::PrefixCacheManagerV2,
26    sampler::Sampler,
27    sequence::{SeqStepType, Sequence, SequenceGroup, SequenceRecognizer},
28    utils::progress::NiceProgressBar,
29    DeviceMapSetting, Loader, ModelCategory, ModelKind, ModelPaths, PagedAttentionConfig, Pipeline,
30    Response, TokenSource, TryIntoDType,
31};
32
33use super::{
34    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
35    MetadataMixin, PreProcessingMixin,
36};
37
38pub struct AnyMoeLoader {
39    pub target: Box<dyn Loader>,
40    pub config: AnyMoeConfig,
41    pub path: String,
42    pub prefix: String,
43    pub mlp: String,
44    pub model_ids: Vec<String>,
45    pub layers: Vec<usize>,
46}
47
48pub struct AnyMoePipeline {
49    target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
50    config: AnyMoeConfig,
51}
52
53impl Loader for AnyMoeLoader {
54    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
55    fn load_model_from_hf(
56        &self,
57        revision: Option<String>,
58        token_source: TokenSource,
59        dtype: &dyn TryIntoDType,
60        device: &Device,
61        silent: bool,
62        mapper: DeviceMapSetting,
63        in_situ_quant: Option<IsqType>,
64        paged_attn_config: Option<PagedAttentionConfig>,
65    ) -> anyhow::Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
66        let paged_attn_config = if paged_attn_config.is_none() {
67            warn!("AnyMoE does not currently support PagedAttention, running without");
68            None
69        } else {
70            paged_attn_config
71        };
72
73        let target = self.target.load_model_from_hf(
74            revision.clone(),
75            token_source.clone(),
76            dtype,
77            device,
78            silent,
79            mapper.clone(),
80            in_situ_quant,
81            paged_attn_config,
82        )?;
83        Ok(Arc::new(tokio::sync::Mutex::new(AnyMoePipeline::new(
84            target,
85            self.config.clone(),
86            AnyMoeTrainingInputs::from_json(&self.path)?,
87            self.prefix.clone(),
88            self.mlp.clone(),
89            self.model_ids.clone(),
90            token_source,
91            revision,
92            self.layers.clone(),
93            silent,
94        )?)))
95    }
96
97    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
98    fn load_model_from_path(
99        &self,
100        paths: &Box<dyn ModelPaths>,
101        dtype: &dyn TryIntoDType,
102        device: &Device,
103        silent: bool,
104        mapper: DeviceMapSetting,
105        in_situ_quant: Option<IsqType>,
106        paged_attn_config: Option<PagedAttentionConfig>,
107    ) -> anyhow::Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
108        let paged_attn_config = if paged_attn_config.is_none() {
109            warn!("AnyMoE does not currently support PagedAttention, running without");
110            None
111        } else {
112            paged_attn_config
113        };
114
115        let target = self.target.load_model_from_path(
116            paths,
117            dtype,
118            device,
119            silent,
120            mapper.clone(),
121            in_situ_quant,
122            paged_attn_config,
123        )?;
124        Ok(Arc::new(tokio::sync::Mutex::new(AnyMoePipeline::new(
125            target,
126            self.config.clone(),
127            AnyMoeTrainingInputs::from_json(&self.path)?,
128            self.prefix.clone(),
129            self.mlp.clone(),
130            self.model_ids.clone(),
131            TokenSource::None,
132            None,
133            self.layers.clone(),
134            silent,
135        )?)))
136    }
137    fn get_id(&self) -> String {
138        format!("AnyMoE: tgt = `{}`", self.target.get_id(),)
139    }
140    fn get_kind(&self) -> ModelKind {
141        ModelKind::AnyMoe {
142            target: Box::new(self.target.get_kind()),
143        }
144    }
145}
146
147impl AnyMoePipeline {
148    #[allow(clippy::too_many_arguments)]
149    pub fn new(
150        target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
151        config: AnyMoeConfig,
152        inputs: AnyMoeTrainingInputs,
153        prefix: String,
154        mlp: String,
155        model_ids: Vec<String>,
156        token: TokenSource,
157        revision: Option<String>,
158        layers: Vec<usize>,
159        silent: bool,
160    ) -> anyhow::Result<Self> {
161        let this = Self { target, config };
162        info!("Loaded pretraining dataset of {} samples.", inputs.len());
163        match this.amoe_pre_train(
164            inputs,
165            (prefix, mlp),
166            model_ids,
167            token,
168            revision,
169            layers,
170            silent,
171        )? {
172            Some(AnyMoeTrainingResult { steps, final_loss }) => {
173                info!("Finished training in {steps} steps. Final losses per layer: {final_loss:?}")
174            }
175            None => {
176                info!("Not training gating layer, using trained gating layer specified in config")
177            }
178        }
179        Ok(this)
180    }
181}
182
183impl CacheManagerMixin for AnyMoePipeline {
184    fn cache(&self) -> &EitherCache {
185        unreachable!()
186    }
187    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
188        get_mut_arcmutex!(self.target).clone_in_cache(seqs)
189    }
190    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
191        get_mut_arcmutex!(self.target).clone_out_cache(seqs)
192    }
193    fn set_none_cache(
194        &self,
195        seqs: &mut [&mut Sequence],
196        reset_non_granular: bool,
197        modify_draft_cache: bool,
198        load_preallocated_cache: bool,
199    ) {
200        get_mut_arcmutex!(self.target).set_none_cache(
201            seqs,
202            reset_non_granular,
203            modify_draft_cache,
204            load_preallocated_cache,
205        )
206    }
207}
208
209impl IsqPipelineMixin for AnyMoePipeline {
210    fn re_isq_model(&mut self, dtype: IsqType) -> anyhow::Result<()> {
211        get_mut_arcmutex!(self.target).re_isq_model(dtype)
212    }
213}
214
215impl PreProcessingMixin for AnyMoePipeline {
216    fn get_chat_template(&self) -> Option<Arc<crate::ChatTemplate>> {
217        get_mut_arcmutex!(self.target).get_chat_template()
218    }
219    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
220        get_mut_arcmutex!(self.target).get_input_processor_config()
221    }
222    fn get_processor(&self) -> Arc<dyn super::Processor> {
223        get_mut_arcmutex!(self.target).get_processor()
224    }
225}
226
227impl MetadataMixin for AnyMoePipeline {
228    fn device(&self) -> Device {
229        get_mut_arcmutex!(self.target).device()
230    }
231    fn get_metadata(&self) -> Arc<super::GeneralMetadata> {
232        get_mut_arcmutex!(self.target).get_metadata()
233    }
234    fn name(&self) -> String {
235        get_mut_arcmutex!(self.target).name()
236    }
237    fn reset_non_granular_state(&self) {
238        get_mut_arcmutex!(self.target).reset_non_granular_state()
239    }
240    fn tokenizer(&self) -> Option<Arc<tokenizers::Tokenizer>> {
241        get_mut_arcmutex!(self.target).tokenizer()
242    }
243    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
244        None
245    }
246}
247
248#[async_trait::async_trait]
249impl Pipeline for AnyMoePipeline {
250    fn forward_inputs(
251        &mut self,
252        inputs: Box<dyn Any>,
253        _return_raw_logits: bool,
254    ) -> Result<ForwardInputsResult, candle_core::Error> {
255        get_mut_arcmutex!(self.target).forward_inputs(inputs, false)
256    }
257
258    async fn sample_causal_gen(
259        &self,
260        seqs: &mut [&mut Sequence],
261        logits: Vec<Tensor>,
262        prefix_cacher: &mut PrefixCacheManagerV2,
263        disable_eos_stop: bool,
264        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
265    ) -> Result<(), candle_core::Error> {
266        get_mut_arcmutex!(self.target)
267            .sample_causal_gen(seqs, logits, prefix_cacher, disable_eos_stop, rng)
268            .await
269    }
270
271    fn category(&self) -> ModelCategory {
272        get_mut_arcmutex!(self.target).category()
273    }
274}
275
276impl AnyMoePipelineMixin for AnyMoePipeline {
277    // Training result is None if inference
278    fn amoe_pre_train(
279        &self,
280        inputs: AnyMoeTrainingInputs,
281        (prefix, mlp): (String, String),
282        model_ids: Vec<String>,
283        token: TokenSource,
284        revision: Option<String>,
285        layers: Vec<usize>,
286        silent: bool,
287    ) -> anyhow::Result<Option<AnyMoeTrainingResult>, candle_core::Error> {
288        let mut target = get_mut_arcmutex!(self.target);
289        if !target.amoe_supported() {
290            candle_core::bail!("AnyMoE is not supported for this model.");
291        }
292
293        let device = target.device();
294        let processor = target.get_processor();
295        let inputs_processor = target.get_processor().inputs_processor();
296        let tokenizer = target.tokenizer();
297        let metadata = target.get_metadata().clone();
298        let input_processor_cfg = target.get_input_processor_config().clone();
299
300        let AnyMoeConfig {
301            hidden_size: _,
302            lr,
303            epochs,
304            batch_size,
305            expert_type,
306            gate_model_id,
307            training,
308            loss_csv_path,
309        } = self.config.clone();
310        let mut steps = 0;
311
312        info!("Expert type: {expert_type:?}");
313        info!("Expert model ids: {model_ids:?}");
314
315        // Inject the AnyMoE layers
316        target.amoe_create_layers(
317            model_ids,
318            &token,
319            revision,
320            &mlp.clone(),
321            self.config.clone(),
322            metadata.activation_dtype,
323            &device,
324            (prefix, mlp),
325            layers,
326            expert_type,
327            silent,
328            if !training {
329                gate_model_id.clone()
330            } else {
331                None
332            },
333        )?;
334        let layer_vars = target.amoe_layer_vars();
335
336        // If there are no trainable params, assume we got a gate model id so no training
337        if target.amoe_base_model_trainable_params() == 0 {
338            return Ok(None);
339        }
340
341        info!(
342            "{} gating layers, {} trainable parameters, lr = {lr}, {epochs} epochs, batch size = {batch_size}",
343            layer_vars.len(),
344            target.amoe_base_model_trainable_params()
345        );
346
347        let mut optimizers = layer_vars
348            .into_iter()
349            .map(|vars| {
350                AdamW::new(
351                    vars,
352                    ParamsAdamW {
353                        lr,
354                        beta1: 0.9,
355                        beta2: 0.999,
356                        eps: 1e-8,
357                        weight_decay: 0.0,
358                    },
359                )
360            })
361            .collect::<candle_core::Result<Vec<_>>>()?;
362
363        let mut rng = rng();
364        let mut samples = inputs.into_inner();
365
366        // Create several dummy objects for the sequences. No custom logits processors.
367        let (dummy_sender, _) = tokio::sync::mpsc::channel(10000);
368        let dummy_sampler = Sampler::new(
369            None,
370            0,
371            tokenizer.clone(),
372            None,
373            None,
374            None,
375            -1,
376            0.0,
377            0.0,
378            vec![],
379        )
380        .map_err(candle_core::Error::msg)?;
381
382        let dummy_group = Arc::new(tokio::sync::Mutex::new(SequenceGroup::new(
383            1, false, false, None,
384        )));
385
386        let mut latest_loss = vec![0.0; optimizers.len()];
387        let mut all_losses = Vec::new();
388
389        for _ in
390            NiceProgressBar::<_, 'g'>(0..epochs, "Training gating layers", &MultiProgress::new())
391        {
392            samples.as_mut_slice().shuffle(&mut rng);
393            for batch in samples.chunks(batch_size) {
394                steps += 1;
395
396                // === PREPARE INPUTS ==
397                let mut seqs = Vec::new();
398                for AnyMoeTrainingInputRow {
399                    prompt,
400                    expert: _,
401                    image_urls,
402                } in batch
403                {
404                    let tokens = processor
405                        .process(
406                            &*target,
407                            vec![IndexMap::from([
408                                ("role".to_string(), Either::Left("user".to_string())),
409                                ("content".to_string(), Either::Left(prompt.clone())),
410                            ])],
411                            true,
412                            true,
413                            None,
414                            Vec::new(),
415                        )
416                        .map_err(candle_core::Error::msg)?;
417                    let images = image_urls.as_ref().map(|urls| {
418                        urls.iter()
419                            .map(|url| -> anyhow::Result<DynamicImage> {
420                                let bytes = if url.contains("http") {
421                                    // Read from http
422                                    match reqwest::blocking::get(url.clone()) {
423                                        Ok(http_resp) => http_resp.bytes()?.to_vec(),
424                                        Err(e) => anyhow::bail!(e),
425                                    }
426                                } else if let Ok(mut f) = File::open(url) {
427                                    // Read from local file
428                                    let metadata = fs::metadata(url)?;
429                                    #[allow(clippy::cast_possible_truncation)]
430                                    let mut buffer = vec![0; metadata.len() as usize];
431                                    f.read_exact(&mut buffer)?;
432                                    buffer
433                                } else {
434                                    // Decode with base64
435                                    general_purpose::STANDARD.decode(url)?
436                                };
437                                Ok(image::load_from_memory(&bytes)?)
438                            })
439                            .collect::<anyhow::Result<Vec<_>>>()
440                    });
441                    let images = match images {
442                        Some(Ok(x)) => Some(x),
443                        Some(Err(e)) => {
444                            return anyhow::Result::Err(candle_core::Error::Msg(e.to_string()))
445                        }
446                        None => None,
447                    };
448                    seqs.push(new_dummy_seq(
449                        tokens,
450                        dummy_sender.clone(),
451                        dummy_sampler.clone(),
452                        dummy_group.clone(),
453                        images,
454                        target.get_metadata().eos_tok.clone(),
455                    ));
456                }
457                let mut input_seqs = seqs.iter_mut().collect::<Vec<_>>();
458
459                // Clear KV cache in prep for training
460                target.set_none_cache(&mut input_seqs, true, true, false);
461
462                let inputs = inputs_processor
463                    .process_inputs(
464                        tokenizer.clone(),
465                        &mut input_seqs,
466                        true, // Always a prompt
467                        metadata.is_xlora,
468                        &device,
469                        metadata.no_kv_cache,
470                        None,
471                        false,
472                        input_processor_cfg.clone(),
473                        None, // TODO: get block tables/handle it for PagedAttention
474                        None, // TODO: prompt chunking doesn't work.
475                        None,
476                    )
477                    .nth(0)
478                    .unwrap();
479
480                // === PREPARE AND RUN MODEL ==
481
482                // Run the model, ignoring the logits
483                let _ = target.forward_inputs(inputs.unwrap().inputs, false)?;
484
485                // Clear the KV cache
486                target.set_none_cache(&mut input_seqs, true, true, false);
487
488                // === BACKWARD STEP ==
489                #[allow(clippy::cast_possible_truncation)]
490                let labels = Tensor::from_vec(
491                    batch
492                        .iter()
493                        .map(
494                            |AnyMoeTrainingInputRow {
495                                 prompt: _,
496                                 expert,
497                                 image_urls: _,
498                             }| *expert as u32,
499                        )
500                        .collect::<Vec<_>>(),
501                    (batch.len(),),
502                    &device,
503                )?;
504
505                let cached = target.amoe_take_cached_gating_outputs();
506                for (layer, (optimizer, output)) in optimizers.iter_mut().zip(cached).enumerate() {
507                    let loss = candle_nn::loss::cross_entropy(
508                        &output,
509                        &labels.to_device(output.device())?,
510                    )?;
511                    let gradstore = loss.backward()?;
512                    optimizer.step(&gradstore)?;
513                    latest_loss[layer] = loss.to_dtype(DType::F32)?.to_scalar::<f32>()?;
514                }
515                all_losses.push(latest_loss.clone());
516            }
517        }
518
519        target.amoe_finish_training(gate_model_id)?;
520        assert_eq!(target.amoe_base_model_trainable_params(), 0);
521
522        if let Some(loss_csv_path) = loss_csv_path {
523            let path = Path::new(&loss_csv_path);
524            if path
525                .extension()
526                .is_none_or(|e| e.to_string_lossy() != *"csv")
527            {
528                candle_core::bail!("`loss_csv_path` must have an extension `csv`.");
529            }
530
531            let mut writer = csv::Writer::from_path(path).map_err(candle_core::Error::msg)?;
532
533            let mut header = vec![format!("Step")];
534            header.extend((0..all_losses[0].len()).map(|i| format!("Gating layer {i}")));
535            writer
536                .write_record(&header)
537                .map_err(candle_core::Error::msg)?;
538
539            for (i, row) in all_losses.into_iter().enumerate() {
540                let mut new_row = vec![format!("Step {i}")];
541                new_row.extend(row.iter().map(|x| format!("{x:.4}")));
542                writer
543                    .write_record(&new_row)
544                    .map_err(candle_core::Error::msg)?;
545            }
546
547            writer.flush().map_err(candle_core::Error::msg)?;
548        }
549
550        Ok(Some(AnyMoeTrainingResult {
551            steps,
552            final_loss: latest_loss,
553        }))
554    }
555}
556
557/// Create a dummy sequence containing just the prompt. This is OK because we just want a sequence that
558/// has no information other than the input tokens (and maybe images).
559fn new_dummy_seq(
560    (tokens, prompt): (Vec<u32>, String),
561    dummy_sender: tokio::sync::mpsc::Sender<Response>,
562    dummy_sampler: Sampler,
563    dummy_group: Arc<tokio::sync::Mutex<SequenceGroup>>,
564    images: Option<Vec<DynamicImage>>,
565    eos_toks: Vec<u32>,
566) -> Sequence {
567    Sequence::new_waiting(
568        tokens,
569        prompt,
570        0,
571        0,
572        1,
573        dummy_sender,
574        dummy_sampler,
575        vec![],
576        vec![],
577        None,
578        false,
579        false,
580        dummy_group,
581        0,
582        0,
583        SequenceRecognizer::None,
584        None,
585        None,
586        images,
587        None, // TODO incorrect for PagedAttention
588        None,
589        None,
590        SeqStepType::PromptAndDecode,
591        None,
592        None,
593        false,
594        eos_toks,
595    )
596}