mistralrs/
anymoe.rs

1use mistralrs_core::{
2    initialize_logging, AnyMoeConfig, AnyMoeLoader, AutoDeviceMapParams, DefaultSchedulerMethod,
3    DeviceMapSetting, Loader, MistralRsBuilder, NormalLoaderBuilder, NormalSpecificConfig,
4    SchedulerConfig,
5};
6
7use crate::{best_device, Model, TextModelBuilder};
8
9pub struct AnyMoeModelBuilder {
10    base: TextModelBuilder,
11    config: AnyMoeConfig,
12    path: String,
13    prefix: String,
14    mlp: String,
15    model_ids: Vec<String>,
16    layers: Vec<usize>,
17}
18
19impl AnyMoeModelBuilder {
20    pub fn from_text_builder(
21        base: TextModelBuilder,
22        config: AnyMoeConfig,
23        path: impl ToString,
24        prefix: impl ToString,
25        mlp: impl ToString,
26        model_ids: Vec<impl ToString>,
27        layers: Vec<usize>,
28    ) -> Self {
29        Self {
30            base,
31            config,
32            path: path.to_string(),
33            prefix: prefix.to_string(),
34            mlp: mlp.to_string(),
35            model_ids: model_ids
36                .into_iter()
37                .map(|f| f.to_string())
38                .collect::<Vec<_>>(),
39            layers,
40        }
41    }
42
43    pub async fn build(self) -> anyhow::Result<Model> {
44        let config = NormalSpecificConfig {
45            use_flash_attn: self.base.use_flash_attn,
46            prompt_chunksize: self.base.prompt_chunksize,
47            topology: self.base.topology,
48            organization: self.base.organization,
49            write_uqff: self.base.write_uqff,
50            from_uqff: self.base.from_uqff,
51            imatrix: None,
52            calibration_file: None,
53            hf_cache_path: self.base.hf_cache_path,
54        };
55
56        if self.base.with_logging {
57            initialize_logging();
58        }
59
60        let loader = NormalLoaderBuilder::new(
61            config,
62            self.base.chat_template,
63            self.base.tokenizer_json,
64            Some(self.base.model_id),
65            self.base.no_kv_cache,
66            self.base.jinja_explicit,
67        )
68        .build(self.base.loader_type)?;
69
70        let loader: Box<dyn Loader> = Box::new(AnyMoeLoader {
71            target: loader,
72            config: self.config,
73            prefix: self.prefix,
74            mlp: self.mlp,
75            path: self.path,
76            model_ids: self.model_ids,
77            layers: self.layers,
78        });
79
80        // Load, into a Pipeline
81        let pipeline = loader.load_model_from_hf(
82            self.base.hf_revision,
83            self.base.token_source,
84            &self.base.dtype,
85            &best_device(self.base.force_cpu)?,
86            !self.base.with_logging,
87            self.base
88                .device_mapping
89                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
90            self.base.isq,
91            self.base.paged_attn_cfg,
92        )?;
93
94        let scheduler_method = match self.base.paged_attn_cfg {
95            Some(_) => {
96                let config = pipeline
97                    .lock()
98                    .await
99                    .get_metadata()
100                    .cache_config
101                    .as_ref()
102                    .unwrap()
103                    .clone();
104
105                SchedulerConfig::PagedAttentionMeta {
106                    max_num_seqs: self.base.max_num_seqs,
107                    config,
108                }
109            }
110            None => SchedulerConfig::DefaultScheduler {
111                method: DefaultSchedulerMethod::Fixed(self.base.max_num_seqs.try_into()?),
112            },
113        };
114
115        let mut runner = MistralRsBuilder::new(
116            pipeline,
117            scheduler_method,
118            self.base.throughput_logging,
119            self.base.search_bert_model,
120        )
121        .with_no_kv_cache(self.base.no_kv_cache)
122        .with_no_prefix_cache(self.base.prefix_cache_n.is_none());
123
124        if let Some(n) = self.base.prefix_cache_n {
125            runner = runner.with_prefix_cache_n(n)
126        }
127
128        Ok(Model::new(runner.build()))
129    }
130}