mistralrs/
anymoe.rs

1use mistralrs_core::{
2    initialize_logging, AnyMoeConfig, AnyMoeLoader, AutoDeviceMapParams, DefaultSchedulerMethod,
3    DeviceMapSetting, Loader, MistralRsBuilder, NormalLoaderBuilder, NormalSpecificConfig,
4    SchedulerConfig,
5};
6
7use crate::{best_device, Model, TextModelBuilder};
8
9pub struct AnyMoeModelBuilder {
10    base: TextModelBuilder,
11    config: AnyMoeConfig,
12    path: String,
13    prefix: String,
14    mlp: String,
15    model_ids: Vec<String>,
16    layers: Vec<usize>,
17}
18
19impl AnyMoeModelBuilder {
20    pub fn from_text_builder(
21        base: TextModelBuilder,
22        config: AnyMoeConfig,
23        path: impl ToString,
24        prefix: impl ToString,
25        mlp: impl ToString,
26        model_ids: Vec<impl ToString>,
27        layers: Vec<usize>,
28    ) -> Self {
29        Self {
30            base,
31            config,
32            path: path.to_string(),
33            prefix: prefix.to_string(),
34            mlp: mlp.to_string(),
35            model_ids: model_ids
36                .into_iter()
37                .map(|f| f.to_string())
38                .collect::<Vec<_>>(),
39            layers,
40        }
41    }
42
43    pub async fn build(self) -> anyhow::Result<Model> {
44        let config = NormalSpecificConfig {
45            prompt_chunksize: self.base.prompt_chunksize,
46            topology: self.base.topology,
47            organization: self.base.organization,
48            write_uqff: self.base.write_uqff,
49            from_uqff: self.base.from_uqff,
50            imatrix: None,
51            calibration_file: None,
52            hf_cache_path: self.base.hf_cache_path,
53        };
54
55        if self.base.with_logging {
56            initialize_logging();
57        }
58
59        let loader = NormalLoaderBuilder::new(
60            config,
61            self.base.chat_template,
62            self.base.tokenizer_json,
63            Some(self.base.model_id),
64            self.base.no_kv_cache,
65            self.base.jinja_explicit,
66        )
67        .build(self.base.loader_type)?;
68
69        let loader: Box<dyn Loader> = Box::new(AnyMoeLoader {
70            target: loader,
71            config: self.config,
72            prefix: self.prefix,
73            mlp: self.mlp,
74            path: self.path,
75            model_ids: self.model_ids,
76            layers: self.layers,
77        });
78
79        // Load, into a Pipeline
80        let pipeline = loader.load_model_from_hf(
81            self.base.hf_revision,
82            self.base.token_source,
83            &self.base.dtype,
84            &best_device(self.base.force_cpu)?,
85            !self.base.with_logging,
86            self.base
87                .device_mapping
88                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
89            self.base.isq,
90            self.base.paged_attn_cfg,
91        )?;
92
93        let scheduler_method = match self.base.paged_attn_cfg {
94            Some(_) => {
95                let config = pipeline
96                    .lock()
97                    .await
98                    .get_metadata()
99                    .cache_config
100                    .as_ref()
101                    .unwrap()
102                    .clone();
103
104                SchedulerConfig::PagedAttentionMeta {
105                    max_num_seqs: self.base.max_num_seqs,
106                    config,
107                }
108            }
109            None => SchedulerConfig::DefaultScheduler {
110                method: DefaultSchedulerMethod::Fixed(self.base.max_num_seqs.try_into()?),
111            },
112        };
113
114        let mut runner = MistralRsBuilder::new(
115            pipeline,
116            scheduler_method,
117            self.base.throughput_logging,
118            self.base.search_bert_model,
119        )
120        .with_no_kv_cache(self.base.no_kv_cache)
121        .with_no_prefix_cache(self.base.prefix_cache_n.is_none());
122
123        if let Some(n) = self.base.prefix_cache_n {
124            runner = runner.with_prefix_cache_n(n)
125        }
126
127        Ok(Model::new(runner.build()))
128    }
129}