1use mistralrs_core::{
2 initialize_logging, AnyMoeConfig, AnyMoeLoader, AutoDeviceMapParams, DefaultSchedulerMethod,
3 DeviceMapSetting, Loader, MistralRsBuilder, NormalLoaderBuilder, NormalSpecificConfig,
4 SchedulerConfig,
5};
6
7use crate::{best_device, Model, TextModelBuilder};
8
/// Builder for a text model that is loaded through an [`AnyMoeLoader`] wrapping a normal
/// text-model loader.
///
/// Created with [`AnyMoeModelBuilder::from_text_builder`] and consumed by
/// [`AnyMoeModelBuilder::build`].
pub struct AnyMoeModelBuilder {
    /// Text-model settings used to configure the wrapped loader, the load call and the runner.
    base: TextModelBuilder,
    /// Forwarded as `AnyMoeLoader::config`.
    config: AnyMoeConfig,
    /// Forwarded as `AnyMoeLoader::path`.
    path: String,
    /// Forwarded as `AnyMoeLoader::prefix`.
    prefix: String,
    /// Forwarded as `AnyMoeLoader::mlp`.
    mlp: String,
    /// Forwarded as `AnyMoeLoader::model_ids`.
    model_ids: Vec<String>,
    /// Forwarded as `AnyMoeLoader::layers`.
    layers: Vec<usize>,
}
18
19impl AnyMoeModelBuilder {
20 pub fn from_text_builder(
21 base: TextModelBuilder,
22 config: AnyMoeConfig,
23 path: impl ToString,
24 prefix: impl ToString,
25 mlp: impl ToString,
26 model_ids: Vec<impl ToString>,
27 layers: Vec<usize>,
28 ) -> Self {
29 Self {
30 base,
31 config,
32 path: path.to_string(),
33 prefix: prefix.to_string(),
34 mlp: mlp.to_string(),
35 model_ids: model_ids
36 .into_iter()
37 .map(|f| f.to_string())
38 .collect::<Vec<_>>(),
39 layers,
40 }
41 }
42
43 pub async fn build(self) -> anyhow::Result<Model> {
44 let config = NormalSpecificConfig {
45 prompt_chunksize: self.base.prompt_chunksize,
46 topology: self.base.topology,
47 organization: self.base.organization,
48 write_uqff: self.base.write_uqff,
49 from_uqff: self.base.from_uqff,
50 imatrix: None,
51 calibration_file: None,
52 hf_cache_path: self.base.hf_cache_path,
53 matformer_config_path: None,
54 matformer_slice_name: None,
55 };
56
57 if self.base.with_logging {
58 initialize_logging();
59 }
60
61 let loader = NormalLoaderBuilder::new(
62 config,
63 self.base.chat_template,
64 self.base.tokenizer_json,
65 Some(self.base.model_id),
66 self.base.no_kv_cache,
67 self.base.jinja_explicit,
68 )
69 .build(self.base.loader_type)?;
70
71 let loader: Box<dyn Loader> = Box::new(AnyMoeLoader {
72 target: loader,
73 config: self.config,
74 prefix: self.prefix,
75 mlp: self.mlp,
76 path: self.path,
77 model_ids: self.model_ids,
78 layers: self.layers,
79 });
80
81 let pipeline = loader.load_model_from_hf(
83 self.base.hf_revision,
84 self.base.token_source,
85 &self.base.dtype,
86 &best_device(self.base.force_cpu)?,
87 !self.base.with_logging,
88 self.base
89 .device_mapping
90 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
91 self.base.isq,
92 self.base.paged_attn_cfg,
93 )?;
94
95 let scheduler_method = match self.base.paged_attn_cfg {
96 Some(_) => {
97 let config = pipeline
98 .lock()
99 .await
100 .get_metadata()
101 .cache_config
102 .as_ref()
103 .unwrap()
104 .clone();
105
106 SchedulerConfig::PagedAttentionMeta {
107 max_num_seqs: self.base.max_num_seqs,
108 config,
109 }
110 }
111 None => SchedulerConfig::DefaultScheduler {
112 method: DefaultSchedulerMethod::Fixed(self.base.max_num_seqs.try_into()?),
113 },
114 };
115
116 let mut runner = MistralRsBuilder::new(
117 pipeline,
118 scheduler_method,
119 self.base.throughput_logging,
120 self.base.search_bert_model,
121 );
122 if let Some(cb) = self.base.search_callback.clone() {
123 runner = runner.with_search_callback(cb);
124 }
125 for (name, cb) in &self.base.tool_callbacks {
126 runner = runner.with_tool_callback(name.clone(), cb.clone());
127 }
128 runner = runner
129 .with_no_kv_cache(self.base.no_kv_cache)
130 .with_no_prefix_cache(self.base.prefix_cache_n.is_none());
131
132 if let Some(n) = self.base.prefix_cache_n {
133 runner = runner.with_prefix_cache_n(n)
134 }
135
136 Ok(Model::new(runner.build().await))
137 }
138}