1use candle_core::Device;
2use mistralrs_core::*;
3use mistralrs_core::{SearchCallback, Tool, ToolCallback};
4use std::collections::HashMap;
5use std::{
6 num::NonZeroUsize,
7 ops::{Deref, DerefMut},
8 path::PathBuf,
9 sync::Arc,
10};
11
12use crate::{best_device, Model};
13
/// A tool callback bundled with the [`Tool`] value it is registered with.
///
/// Instances are created by `TextModelBuilder::with_tool_callback_and_tool` and
/// handed to the runner builder when the model is built.
#[derive(Clone)]
pub struct ToolCallbackWithTool {
    /// Shared handle to the callback.
    pub callback: Arc<ToolCallback>,
    /// The tool value registered together with `callback`.
    pub tool: Tool,
}
20
/// Builder that collects the settings used to load a text model and start a
/// runner for it. Create one with `TextModelBuilder::new`, adjust it with the
/// `with_*` methods, then call `build`.
#[derive(Clone)]
pub struct TextModelBuilder {
    // Loading
    /// Model identifier, passed to the loader builder as `Some(model_id)`.
    pub(crate) model_id: String,
    /// Token source passed to `load_model_from_hf`.
    pub(crate) token_source: TokenSource,
    /// Optional revision passed to `load_model_from_hf`.
    pub(crate) hf_revision: Option<String>,
    /// Optional UQFF output path, forwarded in `NormalSpecificConfig`.
    pub(crate) write_uqff: Option<PathBuf>,
    /// Optional UQFF input files, forwarded in `NormalSpecificConfig`.
    pub(crate) from_uqff: Option<Vec<PathBuf>>,
    /// Optional imatrix path, forwarded in `NormalSpecificConfig`.
    pub(crate) imatrix: Option<PathBuf>,
    /// Optional calibration file path, forwarded in `NormalSpecificConfig`.
    pub(crate) calibration_file: Option<PathBuf>,
    /// Optional chat template, passed to the loader builder.
    pub(crate) chat_template: Option<String>,
    /// Optional explicit Jinja setting, passed to the loader builder.
    pub(crate) jinja_explicit: Option<String>,
    /// Optional tokenizer JSON setting, passed to the loader builder.
    pub(crate) tokenizer_json: Option<String>,
    /// Device mapping; when `None`, `build` uses
    /// `DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())`.
    pub(crate) device_mapping: Option<DeviceMapSetting>,
    /// Optional cache path, forwarded in `NormalSpecificConfig`.
    pub(crate) hf_cache_path: Option<PathBuf>,
    /// Optional embedding model handed to the runner builder (set by `with_search`).
    pub(crate) search_bert_model: Option<BertEmbeddingModel>,
    /// Optional search callback registered on the runner builder.
    pub(crate) search_callback: Option<Arc<SearchCallback>>,
    /// Tool callbacks keyed by name.
    pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
    /// Tool callbacks that also carry their `Tool` value, keyed by name.
    pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
    /// Optional MCP client configuration registered on the runner builder.
    pub(crate) mcp_client_config: Option<McpClientConfig>,
    /// Explicit device; when `None`, `build` falls back to `best_device(force_cpu)`.
    pub(crate) device: Option<Device>,
    /// Optional Matformer config path, forwarded in `NormalSpecificConfig`.
    pub(crate) matformer_config_path: Option<PathBuf>,
    /// Optional Matformer slice name, forwarded in `NormalSpecificConfig`.
    pub(crate) matformer_slice_name: Option<String>,

    // Model running
    /// Optional prompt chunk size, forwarded in `NormalSpecificConfig`.
    pub(crate) prompt_chunksize: Option<NonZeroUsize>,
    /// Optional topology, forwarded in `NormalSpecificConfig`.
    pub(crate) topology: Option<Topology>,
    /// ISQ organization, forwarded in `NormalSpecificConfig`.
    pub(crate) organization: IsqOrganization,
    /// Optional loader type passed to `NormalLoaderBuilder::build`.
    pub(crate) loader_type: Option<NormalLoaderType>,
    /// Model dtype passed by reference to `load_model_from_hf`.
    pub(crate) dtype: ModelDType,
    /// Passed to `best_device` when no explicit `device` is set.
    pub(crate) force_cpu: bool,
    /// Optional ISQ type passed to `load_model_from_hf`.
    pub(crate) isq: Option<IsqType>,
    /// Flag handed to the runner builder.
    pub(crate) throughput_logging: bool,

    // Other things
    /// Optional paged-attention config; when set, `build` tries to use the
    /// paged-attention scheduler.
    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
    /// Sequence limit used for whichever scheduler `build` selects.
    pub(crate) max_num_seqs: usize,
    /// Passed to both the loader builder and the runner builder.
    pub(crate) no_kv_cache: bool,
    /// When true, `build` calls `initialize_logging()`.
    pub(crate) with_logging: bool,
    /// Prefix cache size; `None` makes `build` disable the prefix cache.
    pub(crate) prefix_cache_n: Option<usize>,
}
63
/// Builder for a [`PagedAttentionConfig`].
///
/// The four fields are passed, in declaration order, to
/// `PagedAttentionConfig::new` by `build`.
pub struct PagedAttentionMetaBuilder {
    /// Optional block size; `None` is passed through to `PagedAttentionConfig::new`.
    block_size: Option<usize>,
    /// CPU memory setting (units not visible here — TODO confirm against
    /// `PagedAttentionConfig::new`). No setter is provided in this file.
    mem_cpu: usize,
    /// GPU memory setting.
    mem_gpu: MemoryGpuConfig,
    /// Cache type setting.
    cache_type: PagedCacheType,
}
71
72impl Default for PagedAttentionMetaBuilder {
73 fn default() -> Self {
74 Self {
75 block_size: None,
76 mem_cpu: 64,
77 mem_gpu: MemoryGpuConfig::ContextSize(4096),
78 cache_type: PagedCacheType::Auto,
79 }
80 }
81}
82
83impl PagedAttentionMetaBuilder {
84 pub fn with_block_size(mut self, block_size: usize) -> Self {
85 self.block_size = Some(block_size);
86 self
87 }
88
89 pub fn with_gpu_memory(mut self, mem_gpu: MemoryGpuConfig) -> Self {
90 self.mem_gpu = mem_gpu;
91 self
92 }
93
94 pub fn with_paged_cache_type(mut self, cache_type: PagedCacheType) -> Self {
95 self.cache_type = cache_type;
96 self
97 }
98
99 pub fn build(self) -> anyhow::Result<PagedAttentionConfig> {
100 PagedAttentionConfig::new(self.block_size, self.mem_cpu, self.mem_gpu, self.cache_type)
101 }
102}
103
104impl TextModelBuilder {
105 pub fn new(model_id: impl ToString) -> Self {
113 Self {
114 model_id: model_id.to_string(),
115 prompt_chunksize: None,
116 topology: None,
117 organization: IsqOrganization::Default,
118 write_uqff: None,
119 from_uqff: None,
120 chat_template: None,
121 tokenizer_json: None,
122 loader_type: None,
123 dtype: ModelDType::Auto,
124 force_cpu: false,
125 token_source: TokenSource::CacheToken,
126 hf_revision: None,
127 isq: None,
128 paged_attn_cfg: None,
129 max_num_seqs: 32,
130 no_kv_cache: false,
131 prefix_cache_n: Some(16),
132 with_logging: false,
133 device_mapping: None,
134 imatrix: None,
135 calibration_file: None,
136 jinja_explicit: None,
137 throughput_logging: false,
138 hf_cache_path: None,
139 search_bert_model: None,
140 search_callback: None,
141 tool_callbacks: HashMap::new(),
142 tool_callbacks_with_tools: HashMap::new(),
143 mcp_client_config: None,
144 device: None,
145 matformer_config_path: None,
146 matformer_slice_name: None,
147 }
148 }
149
150 pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
152 self.search_bert_model = Some(search_bert_model);
153 self
154 }
155
156 pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
158 self.search_callback = Some(callback);
159 self
160 }
161
162 pub fn with_tool_callback(
164 mut self,
165 name: impl Into<String>,
166 callback: Arc<ToolCallback>,
167 ) -> Self {
168 self.tool_callbacks.insert(name.into(), callback);
169 self
170 }
171
172 pub fn with_tool_callback_and_tool(
175 mut self,
176 name: impl Into<String>,
177 callback: Arc<ToolCallback>,
178 tool: Tool,
179 ) -> Self {
180 let name = name.into();
181 self.tool_callbacks_with_tools
182 .insert(name, ToolCallbackWithTool { callback, tool });
183 self
184 }
185
186 pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
189 self.mcp_client_config = Some(config);
190 self
191 }
192
193 pub fn with_throughput_logging(mut self) -> Self {
195 self.throughput_logging = true;
196 self
197 }
198
199 pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
201 self.jinja_explicit = Some(jinja_explicit);
202 self
203 }
204
205 pub fn with_prompt_chunksize(mut self, prompt_chunksize: NonZeroUsize) -> Self {
207 self.prompt_chunksize = Some(prompt_chunksize);
208 self
209 }
210
211 pub fn with_topology(mut self, topology: Topology) -> Self {
213 self.topology = Some(topology);
214 self
215 }
216
217 pub fn with_mixture_qexperts_isq(mut self) -> Self {
219 self.organization = IsqOrganization::MoeExpertsOnly;
220 self
221 }
222
223 pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
225 self.chat_template = Some(chat_template.to_string());
226 self
227 }
228
229 pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
231 self.tokenizer_json = Some(tokenizer_json.to_string());
232 self
233 }
234
235 pub fn with_loader_type(mut self, loader_type: NormalLoaderType) -> Self {
238 self.loader_type = Some(loader_type);
239 self
240 }
241
242 pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
244 self.dtype = dtype;
245 self
246 }
247
248 pub fn with_force_cpu(mut self) -> Self {
250 self.force_cpu = true;
251 self
252 }
253
254 pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
256 self.token_source = token_source;
257 self
258 }
259
260 pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
262 self.hf_revision = Some(revision.to_string());
263 self
264 }
265
266 pub fn with_isq(mut self, isq: IsqType) -> Self {
268 self.isq = Some(isq);
269 self
270 }
271
272 pub fn with_imatrix(mut self, path: PathBuf) -> Self {
274 self.imatrix = Some(path);
275 self
276 }
277
278 pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
280 self.calibration_file = Some(path);
281 self
282 }
283
284 pub fn with_paged_attn(
289 mut self,
290 paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
291 ) -> anyhow::Result<Self> {
292 if paged_attn_supported() {
293 self.paged_attn_cfg = Some(paged_attn_cfg()?);
294 } else {
295 self.paged_attn_cfg = None;
296 }
297 Ok(self)
298 }
299
300 pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
302 self.max_num_seqs = max_num_seqs;
303 self
304 }
305
306 pub fn with_no_kv_cache(mut self) -> Self {
308 self.no_kv_cache = true;
309 self
310 }
311
312 pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
314 self.prefix_cache_n = n_seqs;
315 self
316 }
317
318 pub fn with_logging(mut self) -> Self {
320 self.with_logging = true;
321 self
322 }
323
324 pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
326 self.device_mapping = Some(device_mapping);
327 self
328 }
329
330 #[deprecated(
331 note = "Use `UqffTextModelBuilder` to load a UQFF model instead of the generic `from_uqff`"
332 )]
333 pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
341 self.from_uqff = Some(path);
342 self
343 }
344
345 pub fn write_uqff(mut self, path: PathBuf) -> Self {
356 self.write_uqff = Some(path);
357 self
358 }
359
360 pub fn from_hf_cache_pathf(mut self, hf_cache_path: PathBuf) -> Self {
362 self.hf_cache_path = Some(hf_cache_path);
363 self
364 }
365
366 pub fn with_device(mut self, device: Device) -> Self {
368 self.device = Some(device);
369 self
370 }
371
372 pub fn with_matformer_config_path(mut self, path: PathBuf) -> Self {
374 self.matformer_config_path = Some(path);
375 self
376 }
377
378 pub fn with_matformer_slice_name(mut self, name: String) -> Self {
380 self.matformer_slice_name = Some(name);
381 self
382 }
383
384 pub async fn build(self) -> anyhow::Result<Model> {
385 let config = NormalSpecificConfig {
386 prompt_chunksize: self.prompt_chunksize,
387 topology: self.topology,
388 organization: self.organization,
389 write_uqff: self.write_uqff,
390 from_uqff: self.from_uqff,
391 imatrix: self.imatrix,
392 calibration_file: self.calibration_file,
393 hf_cache_path: self.hf_cache_path,
394 matformer_config_path: self.matformer_config_path,
395 matformer_slice_name: self.matformer_slice_name,
396 };
397
398 if self.with_logging {
399 initialize_logging();
400 }
401
402 let loader = NormalLoaderBuilder::new(
403 config,
404 self.chat_template,
405 self.tokenizer_json,
406 Some(self.model_id),
407 self.no_kv_cache,
408 self.jinja_explicit,
409 )
410 .build(self.loader_type)?;
411
412 let pipeline = loader.load_model_from_hf(
414 self.hf_revision,
415 self.token_source,
416 &self.dtype,
417 &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
418 !self.with_logging,
419 self.device_mapping
420 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
421 self.isq,
422 self.paged_attn_cfg,
423 )?;
424
425 let scheduler_method = match self.paged_attn_cfg {
426 Some(_) => {
427 let config = pipeline
428 .lock()
429 .await
430 .get_metadata()
431 .cache_config
432 .as_ref()
433 .cloned();
434
435 if let Some(config) = config {
436 SchedulerConfig::PagedAttentionMeta {
437 max_num_seqs: self.max_num_seqs,
438 config,
439 }
440 } else {
441 SchedulerConfig::DefaultScheduler {
442 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
443 }
444 }
445 }
446 None => SchedulerConfig::DefaultScheduler {
447 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
448 },
449 };
450
451 let mut runner = MistralRsBuilder::new(
452 pipeline,
453 scheduler_method,
454 self.throughput_logging,
455 self.search_bert_model,
456 );
457 if let Some(cb) = self.search_callback.clone() {
458 runner = runner.with_search_callback(cb);
459 }
460 for (name, cb) in &self.tool_callbacks {
461 runner = runner.with_tool_callback(name.clone(), cb.clone());
462 }
463 for (name, callback_with_tool) in &self.tool_callbacks_with_tools {
464 runner = runner.with_tool_callback_and_tool(
465 name.clone(),
466 callback_with_tool.callback.clone(),
467 callback_with_tool.tool.clone(),
468 );
469 }
470 if let Some(mcp_config) = self.mcp_client_config {
471 runner = runner.with_mcp_client(mcp_config);
472 }
473 runner = runner
474 .with_no_kv_cache(self.no_kv_cache)
475 .with_no_prefix_cache(self.prefix_cache_n.is_none());
476
477 if let Some(n) = self.prefix_cache_n {
478 runner = runner.with_prefix_cache_n(n)
479 }
480
481 Ok(Model::new(runner.build().await))
482 }
483}
484
/// Newtype around [`TextModelBuilder`] for loading a model from UQFF files.
///
/// `UqffTextModelBuilder::new` fills in the wrapped builder's `from_uqff`
/// field; the `Deref`/`DerefMut` impls in this file give access to the wrapped
/// builder.
#[derive(Clone)]
pub struct UqffTextModelBuilder(TextModelBuilder);
489
490impl UqffTextModelBuilder {
491 pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
498 let mut inner = TextModelBuilder::new(model_id);
499 inner.from_uqff = Some(uqff_file);
500 Self(inner)
501 }
502
503 pub async fn build(self) -> anyhow::Result<Model> {
504 self.0.build().await
505 }
506
507 pub fn into_inner(self) -> TextModelBuilder {
509 self.0
510 }
511}
512
513impl Deref for UqffTextModelBuilder {
514 type Target = TextModelBuilder;
515
516 fn deref(&self) -> &Self::Target {
517 &self.0
518 }
519}
520
521impl DerefMut for UqffTextModelBuilder {
522 fn deref_mut(&mut self) -> &mut Self::Target {
523 &mut self.0
524 }
525}
526
527impl From<UqffTextModelBuilder> for TextModelBuilder {
528 fn from(value: UqffTextModelBuilder) -> Self {
529 value.0
530 }
531}