1use mistralrs_core::{
2 initialize_logging, AnyMoeConfig, AnyMoeLoader, AutoDeviceMapParams, DefaultSchedulerMethod,
3 DeviceMapSetting, Loader, MistralRsBuilder, NormalLoaderBuilder, NormalSpecificConfig,
4 SchedulerConfig,
5};
6
7use crate::{best_device, Model, TextModelBuilder};
8
9pub struct AnyMoeModelBuilder {
10 base: TextModelBuilder,
11 config: AnyMoeConfig,
12 path: String,
13 prefix: String,
14 mlp: String,
15 model_ids: Vec<String>,
16 layers: Vec<usize>,
17}
18
19impl AnyMoeModelBuilder {
20 pub fn from_text_builder(
21 base: TextModelBuilder,
22 config: AnyMoeConfig,
23 path: impl ToString,
24 prefix: impl ToString,
25 mlp: impl ToString,
26 model_ids: Vec<impl ToString>,
27 layers: Vec<usize>,
28 ) -> Self {
29 Self {
30 base,
31 config,
32 path: path.to_string(),
33 prefix: prefix.to_string(),
34 mlp: mlp.to_string(),
35 model_ids: model_ids
36 .into_iter()
37 .map(|f| f.to_string())
38 .collect::<Vec<_>>(),
39 layers,
40 }
41 }
42
43 pub async fn build(self) -> anyhow::Result<Model> {
44 let config = NormalSpecificConfig {
45 prompt_chunksize: self.base.prompt_chunksize,
46 topology: self.base.topology,
47 organization: self.base.organization,
48 write_uqff: self.base.write_uqff,
49 from_uqff: self.base.from_uqff,
50 imatrix: None,
51 calibration_file: None,
52 hf_cache_path: self.base.hf_cache_path,
53 };
54
55 if self.base.with_logging {
56 initialize_logging();
57 }
58
59 let loader = NormalLoaderBuilder::new(
60 config,
61 self.base.chat_template,
62 self.base.tokenizer_json,
63 Some(self.base.model_id),
64 self.base.no_kv_cache,
65 self.base.jinja_explicit,
66 )
67 .build(self.base.loader_type)?;
68
69 let loader: Box<dyn Loader> = Box::new(AnyMoeLoader {
70 target: loader,
71 config: self.config,
72 prefix: self.prefix,
73 mlp: self.mlp,
74 path: self.path,
75 model_ids: self.model_ids,
76 layers: self.layers,
77 });
78
79 let pipeline = loader.load_model_from_hf(
81 self.base.hf_revision,
82 self.base.token_source,
83 &self.base.dtype,
84 &best_device(self.base.force_cpu)?,
85 !self.base.with_logging,
86 self.base
87 .device_mapping
88 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
89 self.base.isq,
90 self.base.paged_attn_cfg,
91 )?;
92
93 let scheduler_method = match self.base.paged_attn_cfg {
94 Some(_) => {
95 let config = pipeline
96 .lock()
97 .await
98 .get_metadata()
99 .cache_config
100 .as_ref()
101 .unwrap()
102 .clone();
103
104 SchedulerConfig::PagedAttentionMeta {
105 max_num_seqs: self.base.max_num_seqs,
106 config,
107 }
108 }
109 None => SchedulerConfig::DefaultScheduler {
110 method: DefaultSchedulerMethod::Fixed(self.base.max_num_seqs.try_into()?),
111 },
112 };
113
114 let mut runner = MistralRsBuilder::new(
115 pipeline,
116 scheduler_method,
117 self.base.throughput_logging,
118 self.base.search_bert_model,
119 )
120 .with_no_kv_cache(self.base.no_kv_cache)
121 .with_no_prefix_cache(self.base.prefix_cache_n.is_none());
122
123 if let Some(n) = self.base.prefix_cache_n {
124 runner = runner.with_prefix_cache_n(n)
125 }
126
127 Ok(Model::new(runner.build()))
128 }
129}