1use candle_core::Device;
2use mistralrs_core::*;
3use mistralrs_core::{SearchCallback, Tool, ToolCallback};
4use std::collections::HashMap;
5use std::{
6 ops::{Deref, DerefMut},
7 path::PathBuf,
8 sync::Arc,
9};
10
11use crate::{best_device, Model};
12
13#[derive(Clone)]
15pub struct ToolCallbackWithTool {
16 pub callback: Arc<ToolCallback>,
17 pub tool: Tool,
18}
19
20#[derive(Clone)]
21pub struct TextModelBuilder {
23 pub(crate) model_id: String,
25 pub(crate) token_source: TokenSource,
26 pub(crate) hf_revision: Option<String>,
27 pub(crate) write_uqff: Option<PathBuf>,
28 pub(crate) from_uqff: Option<Vec<PathBuf>>,
29 pub(crate) imatrix: Option<PathBuf>,
30 pub(crate) calibration_file: Option<PathBuf>,
31 pub(crate) chat_template: Option<String>,
32 pub(crate) jinja_explicit: Option<String>,
33 pub(crate) tokenizer_json: Option<String>,
34 pub(crate) device_mapping: Option<DeviceMapSetting>,
35 pub(crate) hf_cache_path: Option<PathBuf>,
36 pub(crate) search_bert_model: Option<BertEmbeddingModel>,
37 pub(crate) search_callback: Option<Arc<SearchCallback>>,
38 pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
39 pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
40 pub(crate) mcp_client_config: Option<McpClientConfig>,
41 pub(crate) device: Option<Device>,
42 pub(crate) matformer_config_path: Option<PathBuf>,
43 pub(crate) matformer_slice_name: Option<String>,
44
45 pub(crate) topology: Option<Topology>,
47 pub(crate) organization: IsqOrganization,
48 pub(crate) loader_type: Option<NormalLoaderType>,
49 pub(crate) dtype: ModelDType,
50 pub(crate) force_cpu: bool,
51 pub(crate) isq: Option<IsqType>,
52 pub(crate) throughput_logging: bool,
53
54 pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
56 pub(crate) max_num_seqs: usize,
57 pub(crate) no_kv_cache: bool,
58 pub(crate) with_logging: bool,
59 pub(crate) prefix_cache_n: Option<usize>,
60}
61
62pub struct PagedAttentionMetaBuilder {
64 block_size: Option<usize>,
65 mem_cpu: usize,
66 mem_gpu: MemoryGpuConfig,
67 cache_type: PagedCacheType,
68}
69
70impl Default for PagedAttentionMetaBuilder {
71 fn default() -> Self {
72 Self {
73 block_size: None,
74 mem_cpu: 64,
75 mem_gpu: MemoryGpuConfig::ContextSize(4096),
76 cache_type: PagedCacheType::Auto,
77 }
78 }
79}
80
81impl PagedAttentionMetaBuilder {
82 pub fn with_block_size(mut self, block_size: usize) -> Self {
83 self.block_size = Some(block_size);
84 self
85 }
86
87 pub fn with_gpu_memory(mut self, mem_gpu: MemoryGpuConfig) -> Self {
88 self.mem_gpu = mem_gpu;
89 self
90 }
91
92 pub fn with_paged_cache_type(mut self, cache_type: PagedCacheType) -> Self {
93 self.cache_type = cache_type;
94 self
95 }
96
97 pub fn build(self) -> anyhow::Result<PagedAttentionConfig> {
98 PagedAttentionConfig::new(self.block_size, self.mem_cpu, self.mem_gpu, self.cache_type)
99 }
100}
101
102impl TextModelBuilder {
103 pub fn new(model_id: impl ToString) -> Self {
111 Self {
112 model_id: model_id.to_string(),
113 topology: None,
114 organization: IsqOrganization::Default,
115 write_uqff: None,
116 from_uqff: None,
117 chat_template: None,
118 tokenizer_json: None,
119 loader_type: None,
120 dtype: ModelDType::Auto,
121 force_cpu: false,
122 token_source: TokenSource::CacheToken,
123 hf_revision: None,
124 isq: None,
125 paged_attn_cfg: None,
126 max_num_seqs: 32,
127 no_kv_cache: false,
128 prefix_cache_n: Some(16),
129 with_logging: false,
130 device_mapping: None,
131 imatrix: None,
132 calibration_file: None,
133 jinja_explicit: None,
134 throughput_logging: false,
135 hf_cache_path: None,
136 search_bert_model: None,
137 search_callback: None,
138 tool_callbacks: HashMap::new(),
139 tool_callbacks_with_tools: HashMap::new(),
140 mcp_client_config: None,
141 device: None,
142 matformer_config_path: None,
143 matformer_slice_name: None,
144 }
145 }
146
147 pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
149 self.search_bert_model = Some(search_bert_model);
150 self
151 }
152
153 pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
155 self.search_callback = Some(callback);
156 self
157 }
158
159 pub fn with_tool_callback(
161 mut self,
162 name: impl Into<String>,
163 callback: Arc<ToolCallback>,
164 ) -> Self {
165 self.tool_callbacks.insert(name.into(), callback);
166 self
167 }
168
169 pub fn with_tool_callback_and_tool(
172 mut self,
173 name: impl Into<String>,
174 callback: Arc<ToolCallback>,
175 tool: Tool,
176 ) -> Self {
177 let name = name.into();
178 self.tool_callbacks_with_tools
179 .insert(name, ToolCallbackWithTool { callback, tool });
180 self
181 }
182
183 pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
186 self.mcp_client_config = Some(config);
187 self
188 }
189
190 pub fn with_throughput_logging(mut self) -> Self {
192 self.throughput_logging = true;
193 self
194 }
195
196 pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
198 self.jinja_explicit = Some(jinja_explicit);
199 self
200 }
201
202 pub fn with_topology(mut self, topology: Topology) -> Self {
204 self.topology = Some(topology);
205 self
206 }
207
208 pub fn with_mixture_qexperts_isq(mut self) -> Self {
210 self.organization = IsqOrganization::MoeExpertsOnly;
211 self
212 }
213
214 pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
216 self.chat_template = Some(chat_template.to_string());
217 self
218 }
219
220 pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
222 self.tokenizer_json = Some(tokenizer_json.to_string());
223 self
224 }
225
226 pub fn with_loader_type(mut self, loader_type: NormalLoaderType) -> Self {
229 self.loader_type = Some(loader_type);
230 self
231 }
232
233 pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
235 self.dtype = dtype;
236 self
237 }
238
239 pub fn with_force_cpu(mut self) -> Self {
241 self.force_cpu = true;
242 self
243 }
244
245 pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
247 self.token_source = token_source;
248 self
249 }
250
251 pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
253 self.hf_revision = Some(revision.to_string());
254 self
255 }
256
257 pub fn with_isq(mut self, isq: IsqType) -> Self {
259 self.isq = Some(isq);
260 self
261 }
262
263 pub fn with_imatrix(mut self, path: PathBuf) -> Self {
265 self.imatrix = Some(path);
266 self
267 }
268
269 pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
271 self.calibration_file = Some(path);
272 self
273 }
274
275 pub fn with_paged_attn(
280 mut self,
281 paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
282 ) -> anyhow::Result<Self> {
283 if paged_attn_supported() {
284 self.paged_attn_cfg = Some(paged_attn_cfg()?);
285 } else {
286 self.paged_attn_cfg = None;
287 }
288 Ok(self)
289 }
290
291 pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
293 self.max_num_seqs = max_num_seqs;
294 self
295 }
296
297 pub fn with_no_kv_cache(mut self) -> Self {
299 self.no_kv_cache = true;
300 self
301 }
302
303 pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
305 self.prefix_cache_n = n_seqs;
306 self
307 }
308
309 pub fn with_logging(mut self) -> Self {
311 self.with_logging = true;
312 self
313 }
314
315 pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
317 self.device_mapping = Some(device_mapping);
318 self
319 }
320
321 #[deprecated(
322 note = "Use `UqffTextModelBuilder` to load a UQFF model instead of the generic `from_uqff`"
323 )]
324 pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
332 self.from_uqff = Some(path);
333 self
334 }
335
336 pub fn write_uqff(mut self, path: PathBuf) -> Self {
347 self.write_uqff = Some(path);
348 self
349 }
350
351 pub fn from_hf_cache_pathf(mut self, hf_cache_path: PathBuf) -> Self {
353 self.hf_cache_path = Some(hf_cache_path);
354 self
355 }
356
357 pub fn with_device(mut self, device: Device) -> Self {
359 self.device = Some(device);
360 self
361 }
362
363 pub fn with_matformer_config_path(mut self, path: PathBuf) -> Self {
365 self.matformer_config_path = Some(path);
366 self
367 }
368
369 pub fn with_matformer_slice_name(mut self, name: String) -> Self {
371 self.matformer_slice_name = Some(name);
372 self
373 }
374
375 pub async fn build(self) -> anyhow::Result<Model> {
376 let config = NormalSpecificConfig {
377 topology: self.topology,
378 organization: self.organization,
379 write_uqff: self.write_uqff,
380 from_uqff: self.from_uqff,
381 imatrix: self.imatrix,
382 calibration_file: self.calibration_file,
383 hf_cache_path: self.hf_cache_path,
384 matformer_config_path: self.matformer_config_path,
385 matformer_slice_name: self.matformer_slice_name,
386 };
387
388 if self.with_logging {
389 initialize_logging();
390 }
391
392 let loader = NormalLoaderBuilder::new(
393 config,
394 self.chat_template,
395 self.tokenizer_json,
396 Some(self.model_id),
397 self.no_kv_cache,
398 self.jinja_explicit,
399 )
400 .build(self.loader_type)?;
401
402 let pipeline = loader.load_model_from_hf(
404 self.hf_revision,
405 self.token_source,
406 &self.dtype,
407 &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
408 !self.with_logging,
409 self.device_mapping
410 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
411 self.isq,
412 self.paged_attn_cfg,
413 )?;
414
415 let scheduler_method = match self.paged_attn_cfg {
416 Some(_) => {
417 let config = pipeline
418 .lock()
419 .await
420 .get_metadata()
421 .cache_config
422 .as_ref()
423 .cloned();
424
425 if let Some(config) = config {
426 SchedulerConfig::PagedAttentionMeta {
427 max_num_seqs: self.max_num_seqs,
428 config,
429 }
430 } else {
431 SchedulerConfig::DefaultScheduler {
432 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
433 }
434 }
435 }
436 None => SchedulerConfig::DefaultScheduler {
437 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
438 },
439 };
440
441 let mut runner = MistralRsBuilder::new(
442 pipeline,
443 scheduler_method,
444 self.throughput_logging,
445 self.search_bert_model,
446 );
447 if let Some(cb) = self.search_callback.clone() {
448 runner = runner.with_search_callback(cb);
449 }
450 for (name, cb) in &self.tool_callbacks {
451 runner = runner.with_tool_callback(name.clone(), cb.clone());
452 }
453 for (name, callback_with_tool) in &self.tool_callbacks_with_tools {
454 runner = runner.with_tool_callback_and_tool(
455 name.clone(),
456 callback_with_tool.callback.clone(),
457 callback_with_tool.tool.clone(),
458 );
459 }
460 if let Some(mcp_config) = self.mcp_client_config {
461 runner = runner.with_mcp_client(mcp_config);
462 }
463 runner = runner
464 .with_no_kv_cache(self.no_kv_cache)
465 .with_no_prefix_cache(self.prefix_cache_n.is_none());
466
467 if let Some(n) = self.prefix_cache_n {
468 runner = runner.with_prefix_cache_n(n)
469 }
470
471 Ok(Model::new(runner.build().await))
472 }
473}
474
475#[derive(Clone)]
476pub struct UqffTextModelBuilder(TextModelBuilder);
479
480impl UqffTextModelBuilder {
481 pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
488 let mut inner = TextModelBuilder::new(model_id);
489 inner.from_uqff = Some(uqff_file);
490 Self(inner)
491 }
492
493 pub async fn build(self) -> anyhow::Result<Model> {
494 self.0.build().await
495 }
496
497 pub fn into_inner(self) -> TextModelBuilder {
499 self.0
500 }
501}
502
503impl Deref for UqffTextModelBuilder {
504 type Target = TextModelBuilder;
505
506 fn deref(&self) -> &Self::Target {
507 &self.0
508 }
509}
510
511impl DerefMut for UqffTextModelBuilder {
512 fn deref_mut(&mut self) -> &mut Self::Target {
513 &mut self.0
514 }
515}
516
517impl From<UqffTextModelBuilder> for TextModelBuilder {
518 fn from(value: UqffTextModelBuilder) -> Self {
519 value.0
520 }
521}