1use mistralrs_core::{
2 initialize_logging, AnyMoeConfig, AnyMoeLoader, AutoDeviceMapParams, DefaultSchedulerMethod,
3 DeviceMapSetting, Loader, MistralRsBuilder, NormalLoaderBuilder, NormalSpecificConfig,
4 SchedulerConfig,
5};
6
7use crate::{best_device, Model, TextModelBuilder};
8
9pub struct AnyMoeModelBuilder {
10 base: TextModelBuilder,
11 config: AnyMoeConfig,
12 path: String,
13 prefix: String,
14 mlp: String,
15 model_ids: Vec<String>,
16 layers: Vec<usize>,
17}
18
19impl AnyMoeModelBuilder {
20 pub fn from_text_builder(
21 base: TextModelBuilder,
22 config: AnyMoeConfig,
23 path: impl ToString,
24 prefix: impl ToString,
25 mlp: impl ToString,
26 model_ids: Vec<impl ToString>,
27 layers: Vec<usize>,
28 ) -> Self {
29 Self {
30 base,
31 config,
32 path: path.to_string(),
33 prefix: prefix.to_string(),
34 mlp: mlp.to_string(),
35 model_ids: model_ids
36 .into_iter()
37 .map(|f| f.to_string())
38 .collect::<Vec<_>>(),
39 layers,
40 }
41 }
42
43 pub async fn build(self) -> anyhow::Result<Model> {
44 let config = NormalSpecificConfig {
45 use_flash_attn: self.base.use_flash_attn,
46 prompt_chunksize: self.base.prompt_chunksize,
47 topology: self.base.topology,
48 organization: self.base.organization,
49 write_uqff: self.base.write_uqff,
50 from_uqff: self.base.from_uqff,
51 imatrix: None,
52 calibration_file: None,
53 hf_cache_path: self.base.hf_cache_path,
54 };
55
56 if self.base.with_logging {
57 initialize_logging();
58 }
59
60 let loader = NormalLoaderBuilder::new(
61 config,
62 self.base.chat_template,
63 self.base.tokenizer_json,
64 Some(self.base.model_id),
65 self.base.no_kv_cache,
66 self.base.jinja_explicit,
67 )
68 .build(self.base.loader_type)?;
69
70 let loader: Box<dyn Loader> = Box::new(AnyMoeLoader {
71 target: loader,
72 config: self.config,
73 prefix: self.prefix,
74 mlp: self.mlp,
75 path: self.path,
76 model_ids: self.model_ids,
77 layers: self.layers,
78 });
79
80 let pipeline = loader.load_model_from_hf(
82 self.base.hf_revision,
83 self.base.token_source,
84 &self.base.dtype,
85 &best_device(self.base.force_cpu)?,
86 !self.base.with_logging,
87 self.base
88 .device_mapping
89 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
90 self.base.isq,
91 self.base.paged_attn_cfg,
92 )?;
93
94 let scheduler_method = match self.base.paged_attn_cfg {
95 Some(_) => {
96 let config = pipeline
97 .lock()
98 .await
99 .get_metadata()
100 .cache_config
101 .as_ref()
102 .unwrap()
103 .clone();
104
105 SchedulerConfig::PagedAttentionMeta {
106 max_num_seqs: self.base.max_num_seqs,
107 config,
108 }
109 }
110 None => SchedulerConfig::DefaultScheduler {
111 method: DefaultSchedulerMethod::Fixed(self.base.max_num_seqs.try_into()?),
112 },
113 };
114
115 let mut runner = MistralRsBuilder::new(
116 pipeline,
117 scheduler_method,
118 self.base.throughput_logging,
119 self.base.search_bert_model,
120 )
121 .with_no_kv_cache(self.base.no_kv_cache)
122 .with_no_prefix_cache(self.base.prefix_cache_n.is_none());
123
124 if let Some(n) = self.base.prefix_cache_n {
125 runner = runner.with_prefix_cache_n(n)
126 }
127
128 Ok(Model::new(runner.build()))
129 }
130}