mistralrs_core/pipeline/
amoe.rs

1use std::{
2    any::Any,
3    fs::{self, File},
4    io::Read,
5    path::Path,
6    sync::Arc,
7};
8
9use base64::{engine::general_purpose, Engine};
10use candle_core::{DType, Device, Tensor};
11use candle_nn::{AdamW, Optimizer, ParamsAdamW};
12use either::Either;
13use image::DynamicImage;
14use indexmap::IndexMap;
15use mistralrs_quant::IsqType;
16use rand::{rng, seq::SliceRandom};
17use rand_isaac::Isaac64Rng;
18use tracing::{info, warn};
19
20use crate::{
21    amoe::{AnyMoeConfig, AnyMoeTrainingInputRow, AnyMoeTrainingInputs, AnyMoeTrainingResult},
22    device_map::DeviceMapper,
23    get_mut_arcmutex,
24    prefix_cacher::PrefixCacheManagerV2,
25    sampler::Sampler,
26    sequence::{SeqStepType, Sequence, SequenceGroup, SequenceRecognizer},
27    utils::progress::{new_multi_progress, NiceProgressBar, ProgressScopeGuard},
28    DeviceMapSetting, Loader, ModelCategory, ModelKind, ModelPaths, PagedAttentionConfig, Pipeline,
29    Response, TokenSource, TryIntoDType,
30};
31
32use super::{
33    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin,
34    MetadataMixin, PreProcessingMixin,
35};
36
/// Loader that wraps another [`Loader`] and turns the model it loads into an AnyMoE pipeline.
///
/// Every field except `target` is handed to [`AnyMoePipeline::new`] when a model is loaded.
pub struct AnyMoeLoader {
    /// Loader for the underlying (target) model.
    pub target: Box<dyn Loader>,
    /// AnyMoE configuration, cloned into the pipeline at load time.
    pub config: AnyMoeConfig,
    /// Path given to `AnyMoeTrainingInputs::from_json` to read the training dataset.
    pub path: String,
    /// Forwarded to layer creation together with `mlp` as a `(prefix, mlp)` pair.
    /// NOTE(review): presumably the weight-name prefix of the model layers; confirm against the model impl.
    pub prefix: String,
    /// Forwarded to layer creation together with `prefix`.
    /// NOTE(review): presumably the name of the MLP submodule; confirm against the model impl.
    pub mlp: String,
    /// Model ids of the experts (logged as "Expert model ids" during pre-training).
    pub model_ids: Vec<String>,
    /// Layer indices forwarded to layer creation.
    /// NOTE(review): presumably the layers that receive AnyMoE experts; confirm.
    pub layers: Vec<usize>,
}
46
/// Pipeline that wraps a target pipeline and forwards all pipeline operations to it.
///
/// The AnyMoE layers are injected (and the gating layers optionally trained) when the value is
/// constructed through [`AnyMoePipeline::new`].
pub struct AnyMoePipeline {
    /// The wrapped pipeline; locked for every delegated call.
    target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// AnyMoE configuration used during layer creation and training.
    config: AnyMoeConfig,
}
51
52impl Loader for AnyMoeLoader {
53    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
54    fn load_model_from_hf(
55        &self,
56        revision: Option<String>,
57        token_source: TokenSource,
58        dtype: &dyn TryIntoDType,
59        device: &Device,
60        silent: bool,
61        mapper: DeviceMapSetting,
62        in_situ_quant: Option<IsqType>,
63        paged_attn_config: Option<PagedAttentionConfig>,
64    ) -> anyhow::Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
65        let _progress_guard = ProgressScopeGuard::new(silent);
66        let paged_attn_config = if paged_attn_config.is_none() {
67            warn!("AnyMoE does not currently support PagedAttention, running without");
68            None
69        } else {
70            paged_attn_config
71        };
72
73        let target = self.target.load_model_from_hf(
74            revision.clone(),
75            token_source.clone(),
76            dtype,
77            device,
78            silent,
79            mapper.clone(),
80            in_situ_quant,
81            paged_attn_config,
82        )?;
83        Ok(Arc::new(tokio::sync::Mutex::new(AnyMoePipeline::new(
84            target,
85            self.config.clone(),
86            AnyMoeTrainingInputs::from_json(&self.path)?,
87            self.prefix.clone(),
88            self.mlp.clone(),
89            self.model_ids.clone(),
90            token_source,
91            revision,
92            self.layers.clone(),
93            silent,
94        )?)))
95    }
96
97    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
98    fn load_model_from_path(
99        &self,
100        paths: &Box<dyn ModelPaths>,
101        dtype: &dyn TryIntoDType,
102        device: &Device,
103        silent: bool,
104        mapper: DeviceMapSetting,
105        in_situ_quant: Option<IsqType>,
106        paged_attn_config: Option<PagedAttentionConfig>,
107    ) -> anyhow::Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
108        let _progress_guard = ProgressScopeGuard::new(silent);
109        let paged_attn_config = if paged_attn_config.is_none() {
110            warn!("AnyMoE does not currently support PagedAttention, running without");
111            None
112        } else {
113            paged_attn_config
114        };
115
116        let target = self.target.load_model_from_path(
117            paths,
118            dtype,
119            device,
120            silent,
121            mapper.clone(),
122            in_situ_quant,
123            paged_attn_config,
124        )?;
125        Ok(Arc::new(tokio::sync::Mutex::new(AnyMoePipeline::new(
126            target,
127            self.config.clone(),
128            AnyMoeTrainingInputs::from_json(&self.path)?,
129            self.prefix.clone(),
130            self.mlp.clone(),
131            self.model_ids.clone(),
132            TokenSource::None,
133            None,
134            self.layers.clone(),
135            silent,
136        )?)))
137    }
138    fn get_id(&self) -> String {
139        format!("AnyMoE: tgt = `{}`", self.target.get_id(),)
140    }
141    fn get_kind(&self) -> ModelKind {
142        ModelKind::AnyMoe {
143            target: Box::new(self.target.get_kind()),
144        }
145    }
146}
147
148impl AnyMoePipeline {
149    #[allow(clippy::too_many_arguments)]
150    pub fn new(
151        target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
152        config: AnyMoeConfig,
153        inputs: AnyMoeTrainingInputs,
154        prefix: String,
155        mlp: String,
156        model_ids: Vec<String>,
157        token: TokenSource,
158        revision: Option<String>,
159        layers: Vec<usize>,
160        silent: bool,
161    ) -> anyhow::Result<Self> {
162        let this = Self { target, config };
163        info!("Loaded pretraining dataset of {} samples.", inputs.len());
164        match this.amoe_pre_train(
165            inputs,
166            (prefix, mlp),
167            model_ids,
168            token,
169            revision,
170            layers,
171            silent,
172        )? {
173            Some(AnyMoeTrainingResult { steps, final_loss }) => {
174                info!("Finished training in {steps} steps. Final losses per layer: {final_loss:?}")
175            }
176            None => {
177                info!("Not training gating layer, using trained gating layer specified in config")
178            }
179        }
180        Ok(this)
181    }
182}
183
184impl CacheManagerMixin for AnyMoePipeline {
185    fn cache(&self) -> &EitherCache {
186        unreachable!()
187    }
188    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
189        get_mut_arcmutex!(self.target).clone_in_cache(seqs)
190    }
191    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
192        get_mut_arcmutex!(self.target).clone_out_cache(seqs)
193    }
194    fn set_none_cache(
195        &self,
196        seqs: &mut [&mut Sequence],
197        reset_non_granular: bool,
198        modify_draft_cache: bool,
199        load_preallocated_cache: bool,
200    ) {
201        get_mut_arcmutex!(self.target).set_none_cache(
202            seqs,
203            reset_non_granular,
204            modify_draft_cache,
205            load_preallocated_cache,
206        )
207    }
208}
209
210impl IsqPipelineMixin for AnyMoePipeline {
211    fn re_isq_model(&mut self, dtype: IsqType) -> anyhow::Result<()> {
212        get_mut_arcmutex!(self.target).re_isq_model(dtype)
213    }
214}
215
216impl PreProcessingMixin for AnyMoePipeline {
217    fn get_chat_template(&self) -> Option<Arc<crate::ChatTemplate>> {
218        get_mut_arcmutex!(self.target).get_chat_template()
219    }
220    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
221        get_mut_arcmutex!(self.target).get_input_processor_config()
222    }
223    fn get_processor(&self) -> Arc<dyn super::Processor> {
224        get_mut_arcmutex!(self.target).get_processor()
225    }
226}
227
228impl MetadataMixin for AnyMoePipeline {
229    fn device(&self) -> Device {
230        get_mut_arcmutex!(self.target).device()
231    }
232    fn get_metadata(&self) -> Arc<super::GeneralMetadata> {
233        get_mut_arcmutex!(self.target).get_metadata()
234    }
235    fn name(&self) -> String {
236        get_mut_arcmutex!(self.target).name()
237    }
238    fn reset_non_granular_state(&self) {
239        get_mut_arcmutex!(self.target).reset_non_granular_state()
240    }
241    fn tokenizer(&self) -> Option<Arc<tokenizers::Tokenizer>> {
242        get_mut_arcmutex!(self.target).tokenizer()
243    }
244    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
245        None
246    }
247}
248
249#[async_trait::async_trait]
250impl Pipeline for AnyMoePipeline {
251    fn forward_inputs(
252        &mut self,
253        inputs: Box<dyn Any>,
254        _return_raw_logits: bool,
255    ) -> Result<ForwardInputsResult, candle_core::Error> {
256        get_mut_arcmutex!(self.target).forward_inputs(inputs, false)
257    }
258
259    async fn sample_causal_gen(
260        &self,
261        seqs: &mut [&mut Sequence],
262        logits: Vec<Tensor>,
263        prefix_cacher: &mut PrefixCacheManagerV2,
264        disable_eos_stop: bool,
265        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
266    ) -> Result<(), candle_core::Error> {
267        get_mut_arcmutex!(self.target)
268            .sample_causal_gen(seqs, logits, prefix_cacher, disable_eos_stop, rng)
269            .await
270    }
271
272    fn category(&self) -> ModelCategory {
273        get_mut_arcmutex!(self.target).category()
274    }
275}
276
277impl AnyMoePipelineMixin for AnyMoePipeline {
278    // Training result is None if inference
279    fn amoe_pre_train(
280        &self,
281        inputs: AnyMoeTrainingInputs,
282        (prefix, mlp): (String, String),
283        model_ids: Vec<String>,
284        token: TokenSource,
285        revision: Option<String>,
286        layers: Vec<usize>,
287        silent: bool,
288    ) -> anyhow::Result<Option<AnyMoeTrainingResult>, candle_core::Error> {
289        let mut target = get_mut_arcmutex!(self.target);
290        if !target.amoe_supported() {
291            candle_core::bail!("AnyMoE is not supported for this model.");
292        }
293
294        let device = target.device();
295        let processor = target.get_processor();
296        let inputs_processor = target.get_processor().inputs_processor();
297        let tokenizer = target.tokenizer();
298        let metadata = target.get_metadata().clone();
299        let input_processor_cfg = target.get_input_processor_config().clone();
300
301        let AnyMoeConfig {
302            hidden_size: _,
303            lr,
304            epochs,
305            batch_size,
306            expert_type,
307            gate_model_id,
308            training,
309            loss_csv_path,
310        } = self.config.clone();
311        let mut steps = 0;
312
313        info!("Expert type: {expert_type:?}");
314        info!("Expert model ids: {model_ids:?}");
315
316        // Inject the AnyMoE layers
317        target.amoe_create_layers(
318            model_ids,
319            &token,
320            revision,
321            &mlp.clone(),
322            self.config.clone(),
323            metadata.activation_dtype,
324            &device,
325            (prefix, mlp),
326            layers,
327            expert_type,
328            silent,
329            if !training {
330                gate_model_id.clone()
331            } else {
332                None
333            },
334        )?;
335        let layer_vars = target.amoe_layer_vars();
336
337        // If there are no trainable params, assume we got a gate model id so no training
338        if target.amoe_base_model_trainable_params() == 0 {
339            return Ok(None);
340        }
341
342        info!(
343            "{} gating layers, {} trainable parameters, lr = {lr}, {epochs} epochs, batch size = {batch_size}",
344            layer_vars.len(),
345            target.amoe_base_model_trainable_params()
346        );
347
348        let mut optimizers = layer_vars
349            .into_iter()
350            .map(|vars| {
351                AdamW::new(
352                    vars,
353                    ParamsAdamW {
354                        lr,
355                        beta1: 0.9,
356                        beta2: 0.999,
357                        eps: 1e-8,
358                        weight_decay: 0.0,
359                    },
360                )
361            })
362            .collect::<candle_core::Result<Vec<_>>>()?;
363
364        let mut rng = rng();
365        let mut samples = inputs.into_inner();
366
367        // Create several dummy objects for the sequences. No custom logits processors.
368        let (dummy_sender, _) = tokio::sync::mpsc::channel(10000);
369        let dummy_sampler = Sampler::new(
370            None,
371            0,
372            tokenizer.clone(),
373            None,
374            None,
375            None,
376            None,
377            -1,
378            0.0,
379            0.0,
380            vec![],
381        )
382        .map_err(candle_core::Error::msg)?;
383
384        let dummy_group = Arc::new(tokio::sync::Mutex::new(SequenceGroup::new(
385            1, false, false, None,
386        )));
387
388        let mut latest_loss = vec![0.0; optimizers.len()];
389        let mut all_losses = Vec::new();
390
391        for _ in
392            NiceProgressBar::<_, 'g'>(0..epochs, "Training gating layers", &new_multi_progress())
393        {
394            samples.as_mut_slice().shuffle(&mut rng);
395            for batch in samples.chunks(batch_size) {
396                steps += 1;
397
398                // === PREPARE INPUTS ==
399                let mut seqs = Vec::new();
400                for AnyMoeTrainingInputRow {
401                    prompt,
402                    expert: _,
403                    image_urls,
404                } in batch
405                {
406                    let tokens = processor
407                        .process(
408                            &*target,
409                            vec![IndexMap::from([
410                                ("role".to_string(), Either::Left("user".to_string())),
411                                ("content".to_string(), Either::Left(prompt.clone())),
412                            ])],
413                            true,
414                            true,
415                            None,
416                            Vec::new(),
417                        )
418                        .map_err(candle_core::Error::msg)?;
419                    let images = image_urls.as_ref().map(|urls| {
420                        urls.iter()
421                            .map(|url| -> anyhow::Result<DynamicImage> {
422                                let bytes = if url.contains("http") {
423                                    // Read from http
424                                    match reqwest::blocking::get(url.clone()) {
425                                        Ok(http_resp) => http_resp.bytes()?.to_vec(),
426                                        Err(e) => anyhow::bail!(e),
427                                    }
428                                } else if let Ok(mut f) = File::open(url) {
429                                    // Read from local file
430                                    let metadata = fs::metadata(url)?;
431                                    #[allow(clippy::cast_possible_truncation)]
432                                    let mut buffer = vec![0; metadata.len() as usize];
433                                    f.read_exact(&mut buffer)?;
434                                    buffer
435                                } else {
436                                    // Decode with base64
437                                    general_purpose::STANDARD.decode(url)?
438                                };
439                                Ok(image::load_from_memory(&bytes)?)
440                            })
441                            .collect::<anyhow::Result<Vec<_>>>()
442                    });
443                    let images = match images {
444                        Some(Ok(x)) => Some(x),
445                        Some(Err(e)) => {
446                            return anyhow::Result::Err(candle_core::Error::Msg(e.to_string()))
447                        }
448                        None => None,
449                    };
450                    seqs.push(new_dummy_seq(
451                        tokens,
452                        dummy_sender.clone(),
453                        dummy_sampler.clone(),
454                        dummy_group.clone(),
455                        images,
456                        target.get_metadata().eos_tok.clone(),
457                    ));
458                }
459                let mut input_seqs = seqs.iter_mut().collect::<Vec<_>>();
460
461                // Clear KV cache in prep for training
462                target.set_none_cache(&mut input_seqs, true, true, false);
463
464                let inputs = inputs_processor.process_inputs(
465                    tokenizer.clone(),
466                    &mut input_seqs,
467                    true, // Always a prompt
468                    metadata.is_xlora,
469                    &device,
470                    metadata.no_kv_cache,
471                    None,
472                    false,
473                    input_processor_cfg.clone(),
474                    None, // TODO: get block tables/handle it for PagedAttention
475                    None,
476                );
477
478                // === PREPARE AND RUN MODEL ==
479
480                // Run the model, ignoring the logits
481                let _ = target.forward_inputs(inputs.unwrap().inputs, false)?;
482
483                // Clear the KV cache
484                target.set_none_cache(&mut input_seqs, true, true, false);
485
486                // === BACKWARD STEP ==
487                #[allow(clippy::cast_possible_truncation)]
488                let labels = Tensor::from_vec(
489                    batch
490                        .iter()
491                        .map(
492                            |AnyMoeTrainingInputRow {
493                                 prompt: _,
494                                 expert,
495                                 image_urls: _,
496                             }| *expert as u32,
497                        )
498                        .collect::<Vec<_>>(),
499                    (batch.len(),),
500                    &device,
501                )?;
502
503                let cached = target.amoe_take_cached_gating_outputs();
504                for (layer, (optimizer, output)) in optimizers.iter_mut().zip(cached).enumerate() {
505                    let loss = candle_nn::loss::cross_entropy(
506                        &output,
507                        &labels.to_device(output.device())?,
508                    )?;
509                    let gradstore = loss.backward()?;
510                    optimizer.step(&gradstore)?;
511                    latest_loss[layer] = loss.to_dtype(DType::F32)?.to_scalar::<f32>()?;
512                }
513                all_losses.push(latest_loss.clone());
514            }
515        }
516
517        target.amoe_finish_training(gate_model_id)?;
518        assert_eq!(target.amoe_base_model_trainable_params(), 0);
519
520        if let Some(loss_csv_path) = loss_csv_path {
521            let path = Path::new(&loss_csv_path);
522            if path
523                .extension()
524                .is_none_or(|e| e.to_string_lossy() != *"csv")
525            {
526                candle_core::bail!("`loss_csv_path` must have an extension `csv`.");
527            }
528
529            let mut writer = csv::Writer::from_path(path).map_err(candle_core::Error::msg)?;
530
531            let mut header = vec![format!("Step")];
532            header.extend((0..all_losses[0].len()).map(|i| format!("Gating layer {i}")));
533            writer
534                .write_record(&header)
535                .map_err(candle_core::Error::msg)?;
536
537            for (i, row) in all_losses.into_iter().enumerate() {
538                let mut new_row = vec![format!("Step {i}")];
539                new_row.extend(row.iter().map(|x| format!("{x:.4}")));
540                writer
541                    .write_record(&new_row)
542                    .map_err(candle_core::Error::msg)?;
543            }
544
545            writer.flush().map_err(candle_core::Error::msg)?;
546        }
547
548        Ok(Some(AnyMoeTrainingResult {
549            steps,
550            final_loss: latest_loss,
551        }))
552    }
553}
554
/// Create a dummy sequence containing just the prompt. This is OK because we just want a sequence that
/// has no information other than the input tokens (and maybe images).
///
/// The first argument is the `(tokens, prompt)` pair; together with `dummy_sender`,
/// `dummy_sampler`, `dummy_group`, `images` and `eos_toks` it is passed straight through to
/// `Sequence::new_waiting`. Every other constructor argument is a fixed placeholder
/// (`0`, `1`, `false`, `None`, empty vectors).
/// NOTE(review): the meaning of each positional placeholder is defined by
/// `Sequence::new_waiting`, which is not visible here; confirm against its signature before
/// changing any of them.
fn new_dummy_seq(
    (tokens, prompt): (Vec<u32>, String),
    dummy_sender: tokio::sync::mpsc::Sender<Response>,
    dummy_sampler: Sampler,
    dummy_group: Arc<tokio::sync::Mutex<SequenceGroup>>,
    images: Option<Vec<DynamicImage>>,
    eos_toks: Vec<u32>,
) -> Sequence {
    Sequence::new_waiting(
        tokens,
        prompt,
        0,
        0,
        1,
        dummy_sender,
        dummy_sampler,
        vec![],
        vec![],
        None,
        false,
        false,
        dummy_group,
        0,
        0,
        SequenceRecognizer::None,
        None,
        None,
        images,
        None,
        None, // TODO incorrect for PagedAttention
        None,
        None,
        SeqStepType::PromptAndDecode,
        None,
        None,
        false,
        eos_toks,
    )
}
595}