mistralrs/
anymoe.rs

1use mistralrs_core::{
2    initialize_logging, AnyMoeConfig, AnyMoeLoader, AutoDeviceMapParams, DefaultSchedulerMethod,
3    DeviceMapSetting, Loader, MistralRsBuilder, NormalLoaderBuilder, NormalSpecificConfig,
4    SchedulerConfig,
5};
6
7use crate::{best_device, Model, TextModelBuilder};
8
/// Builder that combines a [`TextModelBuilder`] with the extra settings needed to
/// construct an `AnyMoeLoader`.
///
/// Create one with `from_text_builder` and consume it with `build`, which wraps the
/// base model's loader in an `AnyMoeLoader` and returns a ready-to-use `Model`.
pub struct AnyMoeModelBuilder {
    /// Settings of the underlying text model (model id, dtype, device mapping, ISQ,
    /// scheduler limits, callbacks, ...); its fields are read by `build`.
    base: TextModelBuilder,
    /// AnyMoE configuration, handed unchanged to `AnyMoeLoader::config`.
    config: AnyMoeConfig,
    /// Handed unchanged to `AnyMoeLoader::path`.
    /// NOTE(review): presumably the path of the AnyMoE training data — confirm against `AnyMoeLoader`.
    path: String,
    /// Handed unchanged to `AnyMoeLoader::prefix`.
    /// NOTE(review): presumably the weight-name prefix of the layers to patch — confirm.
    prefix: String,
    /// Handed unchanged to `AnyMoeLoader::mlp`.
    /// NOTE(review): presumably the name of the MLP sub-module inside each layer — confirm.
    mlp: String,
    /// Handed unchanged to `AnyMoeLoader::model_ids`.
    /// NOTE(review): presumably the ids of the expert models — confirm.
    model_ids: Vec<String>,
    /// Handed unchanged to `AnyMoeLoader::layers`.
    /// NOTE(review): presumably the indices of the layers AnyMoE is applied to — confirm.
    layers: Vec<usize>,
}
18
19impl AnyMoeModelBuilder {
20    pub fn from_text_builder(
21        base: TextModelBuilder,
22        config: AnyMoeConfig,
23        path: impl ToString,
24        prefix: impl ToString,
25        mlp: impl ToString,
26        model_ids: Vec<impl ToString>,
27        layers: Vec<usize>,
28    ) -> Self {
29        Self {
30            base,
31            config,
32            path: path.to_string(),
33            prefix: prefix.to_string(),
34            mlp: mlp.to_string(),
35            model_ids: model_ids
36                .into_iter()
37                .map(|f| f.to_string())
38                .collect::<Vec<_>>(),
39            layers,
40        }
41    }
42
43    pub async fn build(self) -> anyhow::Result<Model> {
44        let config = NormalSpecificConfig {
45            prompt_chunksize: self.base.prompt_chunksize,
46            topology: self.base.topology,
47            organization: self.base.organization,
48            write_uqff: self.base.write_uqff,
49            from_uqff: self.base.from_uqff,
50            imatrix: None,
51            calibration_file: None,
52            hf_cache_path: self.base.hf_cache_path,
53            matformer_config_path: None,
54            matformer_slice_name: None,
55        };
56
57        if self.base.with_logging {
58            initialize_logging();
59        }
60
61        let loader = NormalLoaderBuilder::new(
62            config,
63            self.base.chat_template,
64            self.base.tokenizer_json,
65            Some(self.base.model_id),
66            self.base.no_kv_cache,
67            self.base.jinja_explicit,
68        )
69        .build(self.base.loader_type)?;
70
71        let loader: Box<dyn Loader> = Box::new(AnyMoeLoader {
72            target: loader,
73            config: self.config,
74            prefix: self.prefix,
75            mlp: self.mlp,
76            path: self.path,
77            model_ids: self.model_ids,
78            layers: self.layers,
79        });
80
81        // Load, into a Pipeline
82        let pipeline = loader.load_model_from_hf(
83            self.base.hf_revision,
84            self.base.token_source,
85            &self.base.dtype,
86            &best_device(self.base.force_cpu)?,
87            !self.base.with_logging,
88            self.base
89                .device_mapping
90                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
91            self.base.isq,
92            self.base.paged_attn_cfg,
93        )?;
94
95        let scheduler_method = match self.base.paged_attn_cfg {
96            Some(_) => {
97                let config = pipeline
98                    .lock()
99                    .await
100                    .get_metadata()
101                    .cache_config
102                    .as_ref()
103                    .unwrap()
104                    .clone();
105
106                SchedulerConfig::PagedAttentionMeta {
107                    max_num_seqs: self.base.max_num_seqs,
108                    config,
109                }
110            }
111            None => SchedulerConfig::DefaultScheduler {
112                method: DefaultSchedulerMethod::Fixed(self.base.max_num_seqs.try_into()?),
113            },
114        };
115
116        let mut runner = MistralRsBuilder::new(
117            pipeline,
118            scheduler_method,
119            self.base.throughput_logging,
120            self.base.search_bert_model,
121        );
122        if let Some(cb) = self.base.search_callback.clone() {
123            runner = runner.with_search_callback(cb);
124        }
125        for (name, cb) in &self.base.tool_callbacks {
126            runner = runner.with_tool_callback(name.clone(), cb.clone());
127        }
128        runner = runner
129            .with_no_kv_cache(self.base.no_kv_cache)
130            .with_no_prefix_cache(self.base.prefix_cache_n.is_none());
131
132        if let Some(n) = self.base.prefix_cache_n {
133            runner = runner.with_prefix_cache_n(n)
134        }
135
136        Ok(Model::new(runner.build().await))
137    }
138}