1use candle_core::Device;
2use mistralrs_core::*;
3use mistralrs_core::{SearchCallback, Tool, ToolCallback};
4use std::collections::HashMap;
5use std::{
6 ops::{Deref, DerefMut},
7 path::PathBuf,
8 sync::Arc,
9};
10
11use crate::{best_device, Model};
12
13#[derive(Clone)]
15pub struct ToolCallbackWithTool {
16 pub callback: Arc<ToolCallback>,
17 pub tool: Tool,
18}
19
20#[derive(Clone)]
21pub struct TextModelBuilder {
23 pub(crate) model_id: String,
25 pub(crate) token_source: TokenSource,
26 pub(crate) hf_revision: Option<String>,
27 pub(crate) write_uqff: Option<PathBuf>,
28 pub(crate) from_uqff: Option<Vec<PathBuf>>,
29 pub(crate) imatrix: Option<PathBuf>,
30 pub(crate) calibration_file: Option<PathBuf>,
31 pub(crate) chat_template: Option<String>,
32 pub(crate) jinja_explicit: Option<String>,
33 pub(crate) tokenizer_json: Option<String>,
34 pub(crate) device_mapping: Option<DeviceMapSetting>,
35 pub(crate) hf_cache_path: Option<PathBuf>,
36 pub(crate) search_bert_model: Option<BertEmbeddingModel>,
37 pub(crate) search_callback: Option<Arc<SearchCallback>>,
38 pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
39 pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
40 pub(crate) mcp_client_config: Option<McpClientConfig>,
41 pub(crate) device: Option<Device>,
42 pub(crate) matformer_config_path: Option<PathBuf>,
43 pub(crate) matformer_slice_name: Option<String>,
44
45 pub(crate) topology: Option<Topology>,
47 pub(crate) organization: IsqOrganization,
48 pub(crate) loader_type: Option<NormalLoaderType>,
49 pub(crate) dtype: ModelDType,
50 pub(crate) force_cpu: bool,
51 pub(crate) isq: Option<IsqType>,
52 pub(crate) throughput_logging: bool,
53
54 pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
56 pub(crate) max_num_seqs: usize,
57 pub(crate) no_kv_cache: bool,
58 pub(crate) with_logging: bool,
59 pub(crate) prefix_cache_n: Option<usize>,
60}
61
62pub struct PagedAttentionMetaBuilder {
64 block_size: Option<usize>,
65 mem_gpu: MemoryGpuConfig,
66 cache_type: PagedCacheType,
67}
68
69impl Default for PagedAttentionMetaBuilder {
70 fn default() -> Self {
71 Self {
72 block_size: None,
73 mem_gpu: MemoryGpuConfig::ContextSize(4096),
74 cache_type: PagedCacheType::Auto,
75 }
76 }
77}
78
79impl PagedAttentionMetaBuilder {
80 pub fn with_block_size(mut self, block_size: usize) -> Self {
81 self.block_size = Some(block_size);
82 self
83 }
84
85 pub fn with_gpu_memory(mut self, mem_gpu: MemoryGpuConfig) -> Self {
86 self.mem_gpu = mem_gpu;
87 self
88 }
89
90 pub fn with_paged_cache_type(mut self, cache_type: PagedCacheType) -> Self {
91 self.cache_type = cache_type;
92 self
93 }
94
95 pub fn build(self) -> anyhow::Result<PagedAttentionConfig> {
96 PagedAttentionConfig::new(self.block_size, self.mem_gpu, self.cache_type)
97 }
98}
99
100impl TextModelBuilder {
101 pub fn new(model_id: impl ToString) -> Self {
109 Self {
110 model_id: model_id.to_string(),
111 topology: None,
112 organization: IsqOrganization::Default,
113 write_uqff: None,
114 from_uqff: None,
115 chat_template: None,
116 tokenizer_json: None,
117 loader_type: None,
118 dtype: ModelDType::Auto,
119 force_cpu: false,
120 token_source: TokenSource::CacheToken,
121 hf_revision: None,
122 isq: None,
123 paged_attn_cfg: None,
124 max_num_seqs: 32,
125 no_kv_cache: false,
126 prefix_cache_n: Some(16),
127 with_logging: false,
128 device_mapping: None,
129 imatrix: None,
130 calibration_file: None,
131 jinja_explicit: None,
132 throughput_logging: false,
133 hf_cache_path: None,
134 search_bert_model: None,
135 search_callback: None,
136 tool_callbacks: HashMap::new(),
137 tool_callbacks_with_tools: HashMap::new(),
138 mcp_client_config: None,
139 device: None,
140 matformer_config_path: None,
141 matformer_slice_name: None,
142 }
143 }
144
145 pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
147 self.search_bert_model = Some(search_bert_model);
148 self
149 }
150
151 pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
153 self.search_callback = Some(callback);
154 self
155 }
156
157 pub fn with_tool_callback(
159 mut self,
160 name: impl Into<String>,
161 callback: Arc<ToolCallback>,
162 ) -> Self {
163 self.tool_callbacks.insert(name.into(), callback);
164 self
165 }
166
167 pub fn with_tool_callback_and_tool(
170 mut self,
171 name: impl Into<String>,
172 callback: Arc<ToolCallback>,
173 tool: Tool,
174 ) -> Self {
175 let name = name.into();
176 self.tool_callbacks_with_tools
177 .insert(name, ToolCallbackWithTool { callback, tool });
178 self
179 }
180
181 pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
184 self.mcp_client_config = Some(config);
185 self
186 }
187
188 pub fn with_throughput_logging(mut self) -> Self {
190 self.throughput_logging = true;
191 self
192 }
193
194 pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
196 self.jinja_explicit = Some(jinja_explicit);
197 self
198 }
199
200 pub fn with_topology(mut self, topology: Topology) -> Self {
202 self.topology = Some(topology);
203 self
204 }
205
206 pub fn with_mixture_qexperts_isq(mut self) -> Self {
208 self.organization = IsqOrganization::MoeExpertsOnly;
209 self
210 }
211
212 pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
214 self.chat_template = Some(chat_template.to_string());
215 self
216 }
217
218 pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
220 self.tokenizer_json = Some(tokenizer_json.to_string());
221 self
222 }
223
224 pub fn with_loader_type(mut self, loader_type: NormalLoaderType) -> Self {
227 self.loader_type = Some(loader_type);
228 self
229 }
230
231 pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
233 self.dtype = dtype;
234 self
235 }
236
237 pub fn with_force_cpu(mut self) -> Self {
239 self.force_cpu = true;
240 self
241 }
242
243 pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
245 self.token_source = token_source;
246 self
247 }
248
249 pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
251 self.hf_revision = Some(revision.to_string());
252 self
253 }
254
255 pub fn with_isq(mut self, isq: IsqType) -> Self {
257 self.isq = Some(isq);
258 self
259 }
260
261 pub fn with_imatrix(mut self, path: PathBuf) -> Self {
263 self.imatrix = Some(path);
264 self
265 }
266
267 pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
269 self.calibration_file = Some(path);
270 self
271 }
272
273 pub fn with_paged_attn(
278 mut self,
279 paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
280 ) -> anyhow::Result<Self> {
281 if paged_attn_supported() {
282 self.paged_attn_cfg = Some(paged_attn_cfg()?);
283 } else {
284 self.paged_attn_cfg = None;
285 }
286 Ok(self)
287 }
288
289 pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
291 self.max_num_seqs = max_num_seqs;
292 self
293 }
294
295 pub fn with_no_kv_cache(mut self) -> Self {
297 self.no_kv_cache = true;
298 self
299 }
300
301 pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
303 self.prefix_cache_n = n_seqs;
304 self
305 }
306
307 pub fn with_logging(mut self) -> Self {
309 self.with_logging = true;
310 self
311 }
312
313 pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
315 self.device_mapping = Some(device_mapping);
316 self
317 }
318
319 #[deprecated(
320 note = "Use `UqffTextModelBuilder` to load a UQFF model instead of the generic `from_uqff`"
321 )]
322 pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
330 self.from_uqff = Some(path);
331 self
332 }
333
334 pub fn write_uqff(mut self, path: PathBuf) -> Self {
345 self.write_uqff = Some(path);
346 self
347 }
348
349 pub fn from_hf_cache_pathf(mut self, hf_cache_path: PathBuf) -> Self {
351 self.hf_cache_path = Some(hf_cache_path);
352 self
353 }
354
355 pub fn with_device(mut self, device: Device) -> Self {
357 self.device = Some(device);
358 self
359 }
360
361 pub fn with_matformer_config_path(mut self, path: PathBuf) -> Self {
363 self.matformer_config_path = Some(path);
364 self
365 }
366
367 pub fn with_matformer_slice_name(mut self, name: String) -> Self {
369 self.matformer_slice_name = Some(name);
370 self
371 }
372
373 pub async fn build(self) -> anyhow::Result<Model> {
374 let config = NormalSpecificConfig {
375 topology: self.topology,
376 organization: self.organization,
377 write_uqff: self.write_uqff,
378 from_uqff: self.from_uqff,
379 imatrix: self.imatrix,
380 calibration_file: self.calibration_file,
381 hf_cache_path: self.hf_cache_path,
382 matformer_config_path: self.matformer_config_path,
383 matformer_slice_name: self.matformer_slice_name,
384 };
385
386 if self.with_logging {
387 initialize_logging();
388 }
389
390 let loader = NormalLoaderBuilder::new(
391 config,
392 self.chat_template,
393 self.tokenizer_json,
394 Some(self.model_id),
395 self.no_kv_cache,
396 self.jinja_explicit,
397 )
398 .build(self.loader_type)?;
399
400 let pipeline = loader.load_model_from_hf(
402 self.hf_revision,
403 self.token_source,
404 &self.dtype,
405 &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
406 !self.with_logging,
407 self.device_mapping
408 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
409 self.isq,
410 self.paged_attn_cfg,
411 )?;
412
413 let scheduler_method = match self.paged_attn_cfg {
414 Some(_) => {
415 let config = pipeline
416 .lock()
417 .await
418 .get_metadata()
419 .cache_config
420 .as_ref()
421 .cloned();
422
423 if let Some(config) = config {
424 SchedulerConfig::PagedAttentionMeta {
425 max_num_seqs: self.max_num_seqs,
426 config,
427 }
428 } else {
429 SchedulerConfig::DefaultScheduler {
430 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
431 }
432 }
433 }
434 None => SchedulerConfig::DefaultScheduler {
435 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
436 },
437 };
438
439 let mut runner = MistralRsBuilder::new(
440 pipeline,
441 scheduler_method,
442 self.throughput_logging,
443 self.search_bert_model,
444 );
445 if let Some(cb) = self.search_callback.clone() {
446 runner = runner.with_search_callback(cb);
447 }
448 for (name, cb) in &self.tool_callbacks {
449 runner = runner.with_tool_callback(name.clone(), cb.clone());
450 }
451 for (name, callback_with_tool) in &self.tool_callbacks_with_tools {
452 runner = runner.with_tool_callback_and_tool(
453 name.clone(),
454 callback_with_tool.callback.clone(),
455 callback_with_tool.tool.clone(),
456 );
457 }
458 if let Some(mcp_config) = self.mcp_client_config {
459 runner = runner.with_mcp_client(mcp_config);
460 }
461 runner = runner
462 .with_no_kv_cache(self.no_kv_cache)
463 .with_no_prefix_cache(self.prefix_cache_n.is_none());
464
465 if let Some(n) = self.prefix_cache_n {
466 runner = runner.with_prefix_cache_n(n)
467 }
468
469 Ok(Model::new(runner.build().await))
470 }
471}
472
473#[derive(Clone)]
474pub struct UqffTextModelBuilder(TextModelBuilder);
477
478impl UqffTextModelBuilder {
479 pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
486 let mut inner = TextModelBuilder::new(model_id);
487 inner.from_uqff = Some(uqff_file);
488 Self(inner)
489 }
490
491 pub async fn build(self) -> anyhow::Result<Model> {
492 self.0.build().await
493 }
494
495 pub fn into_inner(self) -> TextModelBuilder {
497 self.0
498 }
499}
500
501impl Deref for UqffTextModelBuilder {
502 type Target = TextModelBuilder;
503
504 fn deref(&self) -> &Self::Target {
505 &self.0
506 }
507}
508
509impl DerefMut for UqffTextModelBuilder {
510 fn deref_mut(&mut self) -> &mut Self::Target {
511 &mut self.0
512 }
513}
514
515impl From<UqffTextModelBuilder> for TextModelBuilder {
516 fn from(value: UqffTextModelBuilder) -> Self {
517 value.0
518 }
519}