use candle_core::Device;
use mistralrs_core::*;
use mistralrs_core::{SearchCallback, Tool, ToolCallback};
use std::collections::HashMap;

use crate::{best_device, Model};
use std::sync::Arc;

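/// A tool callback paired with the [`Tool`] definition it implements, so both can
/// be registered under a single name via
/// [`GgufModelBuilder::with_tool_callback_and_tool`].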
#[derive(Clone)]
pub struct ToolCallbackWithTool {
    pub callback: Arc<ToolCallback>,
    pub tool: Tool,
}

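/// Builder for loading and configuring a GGUF model: the GGUF source and files,
/// tokenizer and chat template, device placement, KV-cache / PagedAttention
/// behavior, and search/tool callbacks. Calling [`Self::build`] performs the load.
///
/// A minimal usage sketch, assuming this type is consumed through the `mistralrs`
/// crate and that `anyhow` is available; the repository and file names below are
/// illustrative placeholders, not real defaults:
///
/// ```no_run
/// use mistralrs::GgufModelBuilder;
///
/// async fn load() -> anyhow::Result<()> {
///     let model = GgufModelBuilder::new(
///         "some-org/some-model-GGUF",     // hypothetical GGUF repository
///         vec!["some-model-q4_k_m.gguf"], // hypothetical quantized file
///     )
///     .with_logging()
///     .build()
///     .await?;
///     // `model` is now ready to serve requests.
///     let _ = model;
///     Ok(())
/// }
/// ```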
pub struct GgufModelBuilder {
    pub(crate) model_id: String,
    pub(crate) files: Vec<String>,
    pub(crate) tok_model_id: Option<String>,
    pub(crate) token_source: TokenSource,
    pub(crate) hf_revision: Option<String>,
    pub(crate) chat_template: Option<String>,
    pub(crate) jinja_explicit: Option<String>,
    pub(crate) tokenizer_json: Option<String>,
    pub(crate) device_mapping: Option<DeviceMapSetting>,
    pub(crate) search_bert_model: Option<BertEmbeddingModel>,
    pub(crate) search_callback: Option<Arc<SearchCallback>>,
    pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
    pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
    pub(crate) device: Option<Device>,

    pub(crate) force_cpu: bool,
    pub(crate) topology: Option<Topology>,
    pub(crate) throughput_logging: bool,

    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
    pub(crate) max_num_seqs: usize,
    pub(crate) no_kv_cache: bool,
    pub(crate) with_logging: bool,
    pub(crate) prefix_cache_n: Option<usize>,
}

impl GgufModelBuilder {
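    /// Create a builder for the GGUF model named by `model_id` (for example a
    /// Hugging Face repository ID), loading the listed GGUF `files` from it.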
    pub fn new(model_id: impl ToString, files: Vec<impl ToString>) -> Self {
        Self {
            model_id: model_id.to_string(),
            files: files.into_iter().map(|f| f.to_string()).collect::<Vec<_>>(),
            chat_template: None,
            tokenizer_json: None,
            force_cpu: false,
            token_source: TokenSource::CacheToken,
            hf_revision: None,
            paged_attn_cfg: None,
            max_num_seqs: 32,
            no_kv_cache: false,
            prefix_cache_n: Some(16),
            with_logging: false,
            topology: None,
            tok_model_id: None,
            device_mapping: None,
            jinja_explicit: None,
            throughput_logging: false,
            search_bert_model: None,
            search_callback: None,
            tool_callbacks: HashMap::new(),
            tool_callbacks_with_tools: HashMap::new(),
            device: None,
        }
    }

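    /// Enable search, using the given BERT model for embeddings.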
    pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
        self.search_bert_model = Some(search_bert_model);
        self
    }

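    /// Override the callback used to perform searches when search is enabled.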
    pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
        self.search_callback = Some(callback);
        self
    }

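    /// Register a callback to run when the tool with the given name is called.
    /// Unlike [`Self::with_tool_callback_and_tool`], no [`Tool`] definition is
    /// registered alongside the callback.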
    pub fn with_tool_callback(
        mut self,
        name: impl Into<String>,
        callback: Arc<ToolCallback>,
    ) -> Self {
        self.tool_callbacks.insert(name.into(), callback);
        self
    }

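    /// Register a callback together with its [`Tool`] definition under the given
    /// name; the definition is stored alongside the callback.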
    pub fn with_tool_callback_and_tool(
        mut self,
        name: impl Into<String>,
        callback: Arc<ToolCallback>,
        tool: Tool,
    ) -> Self {
        let name = name.into();
        self.tool_callbacks_with_tools
            .insert(name, ToolCallbackWithTool { callback, tool });
        self
    }

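    /// Enable throughput logging.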
    pub fn with_throughput_logging(mut self) -> Self {
        self.throughput_logging = true;
        self
    }

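    /// Set an explicit Jinja chat template file (`.jinja`) to use.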
    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

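    /// Source the tokenizer and chat template from this model ID instead of the
    /// GGUF model itself.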
    pub fn with_tok_model_id(mut self, tok_model_id: impl ToString) -> Self {
        self.tok_model_id = Some(tok_model_id.to_string());
        self
    }

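    /// Set the model topology for use during loading.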
    pub fn with_topology(mut self, topology: Topology) -> Self {
        self.topology = Some(topology);
        self
    }

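    /// Set the chat template to use.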
    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
        self.chat_template = Some(chat_template.to_string());
        self
    }

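    /// Path to a `tokenizer.json` file to use for the tokenizer.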
    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
        self.tokenizer_json = Some(tokenizer_json.to_string());
        self
    }

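    /// Force the model to run on the CPU.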
    pub fn with_force_cpu(mut self) -> Self {
        self.force_cpu = true;
        self
    }

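    /// Set the source of the Hugging Face token (defaults to the cached token).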
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

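    /// Set the Hugging Face revision (branch, tag, or commit) to load.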
    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
        self.hf_revision = Some(revision.to_string());
        self
    }

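    /// Enable PagedAttention, evaluating the given closure for its configuration.
    /// On platforms where PagedAttention is unsupported, the closure is not
    /// evaluated and the default cache is used instead.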
    pub fn with_paged_attn(
        mut self,
        paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
    ) -> anyhow::Result<Self> {
        if paged_attn_supported() {
            self.paged_attn_cfg = Some(paged_attn_cfg()?);
        } else {
            self.paged_attn_cfg = None;
        }
        Ok(self)
    }

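    /// Set the maximum number of sequences that may be scheduled at once
    /// (defaults to 32).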
    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
        self.max_num_seqs = max_num_seqs;
        self
    }

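    /// Disable the KV cache.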
    pub fn with_no_kv_cache(mut self) -> Self {
        self.no_kv_cache = true;
        self
    }

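    /// Set how many sequences the prefix cache may hold; `None` disables prefix
    /// caching (defaults to 16).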
    pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
        self.prefix_cache_n = n_seqs;
        self
    }

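    /// Enable logging.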
    pub fn with_logging(mut self) -> Self {
        self.with_logging = true;
        self
    }

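    /// Provide an explicit device mapping; otherwise an automatic mapping is used.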
    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
        self.device_mapping = Some(device_mapping);
        self
    }

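    /// Set the main device to load the model onto; otherwise the best available
    /// device is selected (respecting [`Self::with_force_cpu`]).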
    pub fn with_device(mut self, device: Device) -> Self {
        self.device = Some(device);
        self
    }

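    /// Load the model according to the configured options and return a [`Model`]
    /// ready to serve requests.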
    pub async fn build(self) -> anyhow::Result<Model> {
        let config = GGUFSpecificConfig {
            topology: self.topology,
        };

        if self.with_logging {
            initialize_logging();
        }

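        // Assemble the GGUF loader from the configured source, chat template,
        // and tokenizer options.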
        let loader = GGUFLoaderBuilder::new(
            self.chat_template,
            self.tok_model_id,
            self.model_id,
            self.files,
            config,
            self.no_kv_cache,
            self.jinja_explicit,
        )
        .build();

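        // Load the pipeline, falling back to the best available device and an
        // automatic device mapping when none were set explicitly.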
        let pipeline = loader.load_model_from_hf(
            self.hf_revision,
            self.token_source,
            &ModelDType::Auto,
            &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
            !self.with_logging,
            self.device_mapping
                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
            None,
            self.paged_attn_cfg,
        )?;

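        // Select the scheduler: PagedAttention metadata when a PagedAttention
        // cache was configured, otherwise the default fixed-size scheduler.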
        let scheduler_method = match self.paged_attn_cfg {
            Some(_) => {
                let config = pipeline
                    .lock()
                    .await
                    .get_metadata()
                    .cache_config
                    .as_ref()
                    .unwrap()
                    .clone();

                SchedulerConfig::PagedAttentionMeta {
                    max_num_seqs: self.max_num_seqs,
                    config,
                }
            }
            None => SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
            },
        };

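        // Wire up the runner: search and tool callbacks plus the KV-cache and
        // prefix-cache settings all carry over from the builder.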
        let mut runner = MistralRsBuilder::new(
            pipeline,
            scheduler_method,
            self.throughput_logging,
            self.search_bert_model,
        );
        if let Some(cb) = self.search_callback.clone() {
            runner = runner.with_search_callback(cb);
        }
        for (name, cb) in &self.tool_callbacks {
            runner = runner.with_tool_callback(name.clone(), cb.clone());
        }
        for (name, callback_with_tool) in &self.tool_callbacks_with_tools {
            runner = runner.with_tool_callback_and_tool(
                name.clone(),
                callback_with_tool.callback.clone(),
                callback_with_tool.tool.clone(),
            );
        }
        runner = runner
            .with_no_kv_cache(self.no_kv_cache)
            .with_no_prefix_cache(self.prefix_cache_n.is_none());

        if let Some(n) = self.prefix_cache_n {
            runner = runner.with_prefix_cache_n(n)
        }

        Ok(Model::new(runner.build().await))
    }
}