1use candle_core::Device;
2use mistralrs_core::*;
3use mistralrs_core::{SearchCallback, Tool, ToolCallback};
4use std::collections::HashMap;
5use std::num::NonZeroUsize;
6
7use crate::{best_device, Model};
8use std::sync::Arc;
9
10#[derive(Clone)]
12pub struct ToolCallbackWithTool {
13 pub callback: Arc<ToolCallback>,
14 pub tool: Tool,
15}
16
17pub struct GgufModelBuilder {
19 pub(crate) model_id: String,
21 pub(crate) files: Vec<String>,
22 pub(crate) tok_model_id: Option<String>,
23 pub(crate) token_source: TokenSource,
24 pub(crate) hf_revision: Option<String>,
25 pub(crate) chat_template: Option<String>,
26 pub(crate) jinja_explicit: Option<String>,
27 pub(crate) tokenizer_json: Option<String>,
28 pub(crate) device_mapping: Option<DeviceMapSetting>,
29 pub(crate) search_bert_model: Option<BertEmbeddingModel>,
30 pub(crate) search_callback: Option<Arc<SearchCallback>>,
31 pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
32 pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
33 pub(crate) device: Option<Device>,
34
35 pub(crate) prompt_chunksize: Option<NonZeroUsize>,
37 pub(crate) force_cpu: bool,
38 pub(crate) topology: Option<Topology>,
39 pub(crate) throughput_logging: bool,
40
41 pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
43 pub(crate) max_num_seqs: usize,
44 pub(crate) no_kv_cache: bool,
45 pub(crate) with_logging: bool,
46 pub(crate) prefix_cache_n: Option<usize>,
47}
48
49impl GgufModelBuilder {
50 pub fn new(model_id: impl ToString, files: Vec<impl ToString>) -> Self {
57 Self {
58 model_id: model_id.to_string(),
59 files: files.into_iter().map(|f| f.to_string()).collect::<Vec<_>>(),
60 prompt_chunksize: None,
61 chat_template: None,
62 tokenizer_json: None,
63 force_cpu: false,
64 token_source: TokenSource::CacheToken,
65 hf_revision: None,
66 paged_attn_cfg: None,
67 max_num_seqs: 32,
68 no_kv_cache: false,
69 prefix_cache_n: Some(16),
70 with_logging: false,
71 topology: None,
72 tok_model_id: None,
73 device_mapping: None,
74 jinja_explicit: None,
75 throughput_logging: false,
76 search_bert_model: None,
77 search_callback: None,
78 tool_callbacks: HashMap::new(),
79 tool_callbacks_with_tools: HashMap::new(),
80 device: None,
81 }
82 }
83
84 pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
86 self.search_bert_model = Some(search_bert_model);
87 self
88 }
89
90 pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
92 self.search_callback = Some(callback);
93 self
94 }
95
96 pub fn with_tool_callback(
97 mut self,
98 name: impl Into<String>,
99 callback: Arc<ToolCallback>,
100 ) -> Self {
101 self.tool_callbacks.insert(name.into(), callback);
102 self
103 }
104
105 pub fn with_tool_callback_and_tool(
108 mut self,
109 name: impl Into<String>,
110 callback: Arc<ToolCallback>,
111 tool: Tool,
112 ) -> Self {
113 let name = name.into();
114 self.tool_callbacks_with_tools
115 .insert(name, ToolCallbackWithTool { callback, tool });
116 self
117 }
118
119 pub fn with_throughput_logging(mut self) -> Self {
121 self.throughput_logging = true;
122 self
123 }
124
125 pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
127 self.jinja_explicit = Some(jinja_explicit);
128 self
129 }
130
131 pub fn with_tok_model_id(mut self, tok_model_id: impl ToString) -> Self {
133 self.tok_model_id = Some(tok_model_id.to_string());
134 self
135 }
136
137 pub fn with_prompt_chunksize(mut self, prompt_chunksize: NonZeroUsize) -> Self {
139 self.prompt_chunksize = Some(prompt_chunksize);
140 self
141 }
142
143 pub fn with_topology(mut self, topology: Topology) -> Self {
145 self.topology = Some(topology);
146 self
147 }
148
149 pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
151 self.chat_template = Some(chat_template.to_string());
152 self
153 }
154
155 pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
157 self.tokenizer_json = Some(tokenizer_json.to_string());
158 self
159 }
160
161 pub fn with_force_cpu(mut self) -> Self {
163 self.force_cpu = true;
164 self
165 }
166
167 pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
169 self.token_source = token_source;
170 self
171 }
172
173 pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
175 self.hf_revision = Some(revision.to_string());
176 self
177 }
178
179 pub fn with_paged_attn(
186 mut self,
187 paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
188 ) -> anyhow::Result<Self> {
189 if paged_attn_supported() {
190 self.paged_attn_cfg = Some(paged_attn_cfg()?);
191 } else {
192 self.paged_attn_cfg = None;
193 }
194 Ok(self)
195 }
196
197 pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
199 self.max_num_seqs = max_num_seqs;
200 self
201 }
202
203 pub fn with_no_kv_cache(mut self) -> Self {
205 self.no_kv_cache = true;
206 self
207 }
208
209 pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
211 self.prefix_cache_n = n_seqs;
212 self
213 }
214
215 pub fn with_logging(mut self) -> Self {
217 self.with_logging = true;
218 self
219 }
220
221 pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
223 self.device_mapping = Some(device_mapping);
224 self
225 }
226
227 pub fn with_device(mut self, device: Device) -> Self {
229 self.device = Some(device);
230 self
231 }
232
233 pub async fn build(self) -> anyhow::Result<Model> {
234 let config = GGUFSpecificConfig {
235 prompt_chunksize: self.prompt_chunksize,
236 topology: self.topology,
237 };
238
239 if self.with_logging {
240 initialize_logging();
241 }
242
243 let loader = GGUFLoaderBuilder::new(
244 self.chat_template,
245 self.tok_model_id,
246 self.model_id,
247 self.files,
248 config,
249 self.no_kv_cache,
250 self.jinja_explicit,
251 )
252 .build();
253
254 let pipeline = loader.load_model_from_hf(
256 self.hf_revision,
257 self.token_source,
258 &ModelDType::Auto,
259 &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
260 !self.with_logging,
261 self.device_mapping
262 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
263 None,
264 self.paged_attn_cfg,
265 )?;
266
267 let scheduler_method = match self.paged_attn_cfg {
268 Some(_) => {
269 let config = pipeline
270 .lock()
271 .await
272 .get_metadata()
273 .cache_config
274 .as_ref()
275 .unwrap()
276 .clone();
277
278 SchedulerConfig::PagedAttentionMeta {
279 max_num_seqs: self.max_num_seqs,
280 config,
281 }
282 }
283 None => SchedulerConfig::DefaultScheduler {
284 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
285 },
286 };
287
288 let mut runner = MistralRsBuilder::new(
289 pipeline,
290 scheduler_method,
291 self.throughput_logging,
292 self.search_bert_model,
293 );
294 if let Some(cb) = self.search_callback.clone() {
295 runner = runner.with_search_callback(cb);
296 }
297 for (name, cb) in &self.tool_callbacks {
298 runner = runner.with_tool_callback(name.clone(), cb.clone());
299 }
300 for (name, callback_with_tool) in &self.tool_callbacks_with_tools {
301 runner = runner.with_tool_callback_and_tool(
302 name.clone(),
303 callback_with_tool.callback.clone(),
304 callback_with_tool.tool.clone(),
305 );
306 }
307 runner = runner
308 .with_no_kv_cache(self.no_kv_cache)
309 .with_no_prefix_cache(self.prefix_cache_n.is_none());
310
311 if let Some(n) = self.prefix_cache_n {
312 runner = runner.with_prefix_cache_n(n)
313 }
314
315 Ok(Model::new(runner.build().await))
316 }
317}