1use candle_core::Device;
2use mistralrs_core::*;
3use mistralrs_core::{SearchCallback, Tool, ToolCallback};
4use std::collections::HashMap;
5use std::{
6 num::NonZeroUsize,
7 ops::{Deref, DerefMut},
8 path::PathBuf,
9 sync::Arc,
10};
11
12use crate::{best_device, Model};
13
14#[derive(Clone)]
16pub struct ToolCallbackWithTool {
17 pub callback: Arc<ToolCallback>,
18 pub tool: Tool,
19}
20
21#[derive(Clone)]
22pub struct TextModelBuilder {
24 pub(crate) model_id: String,
26 pub(crate) token_source: TokenSource,
27 pub(crate) hf_revision: Option<String>,
28 pub(crate) write_uqff: Option<PathBuf>,
29 pub(crate) from_uqff: Option<Vec<PathBuf>>,
30 pub(crate) imatrix: Option<PathBuf>,
31 pub(crate) calibration_file: Option<PathBuf>,
32 pub(crate) chat_template: Option<String>,
33 pub(crate) jinja_explicit: Option<String>,
34 pub(crate) tokenizer_json: Option<String>,
35 pub(crate) device_mapping: Option<DeviceMapSetting>,
36 pub(crate) hf_cache_path: Option<PathBuf>,
37 pub(crate) search_bert_model: Option<BertEmbeddingModel>,
38 pub(crate) search_callback: Option<Arc<SearchCallback>>,
39 pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
40 pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
41 pub(crate) mcp_client_config: Option<McpClientConfig>,
42 pub(crate) device: Option<Device>,
43
44 pub(crate) prompt_chunksize: Option<NonZeroUsize>,
46 pub(crate) topology: Option<Topology>,
47 pub(crate) organization: IsqOrganization,
48 pub(crate) loader_type: Option<NormalLoaderType>,
49 pub(crate) dtype: ModelDType,
50 pub(crate) force_cpu: bool,
51 pub(crate) isq: Option<IsqType>,
52 pub(crate) throughput_logging: bool,
53
54 pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
56 pub(crate) max_num_seqs: usize,
57 pub(crate) no_kv_cache: bool,
58 pub(crate) with_logging: bool,
59 pub(crate) prefix_cache_n: Option<usize>,
60}
61
62pub struct PagedAttentionMetaBuilder {
64 block_size: Option<usize>,
65 mem_cpu: usize,
66 mem_gpu: MemoryGpuConfig,
67 cache_type: PagedCacheType,
68}
69
70impl Default for PagedAttentionMetaBuilder {
71 fn default() -> Self {
72 Self {
73 block_size: None,
74 mem_cpu: 64,
75 mem_gpu: MemoryGpuConfig::ContextSize(4096),
76 cache_type: PagedCacheType::Auto,
77 }
78 }
79}
80
81impl PagedAttentionMetaBuilder {
82 pub fn with_block_size(mut self, block_size: usize) -> Self {
83 self.block_size = Some(block_size);
84 self
85 }
86
87 pub fn with_gpu_memory(mut self, mem_gpu: MemoryGpuConfig) -> Self {
88 self.mem_gpu = mem_gpu;
89 self
90 }
91
92 pub fn with_paged_cache_type(mut self, cache_type: PagedCacheType) -> Self {
93 self.cache_type = cache_type;
94 self
95 }
96
97 pub fn build(self) -> anyhow::Result<PagedAttentionConfig> {
98 PagedAttentionConfig::new(self.block_size, self.mem_cpu, self.mem_gpu, self.cache_type)
99 }
100}
101
102impl TextModelBuilder {
103 pub fn new(model_id: impl ToString) -> Self {
111 Self {
112 model_id: model_id.to_string(),
113 prompt_chunksize: None,
114 topology: None,
115 organization: IsqOrganization::Default,
116 write_uqff: None,
117 from_uqff: None,
118 chat_template: None,
119 tokenizer_json: None,
120 loader_type: None,
121 dtype: ModelDType::Auto,
122 force_cpu: false,
123 token_source: TokenSource::CacheToken,
124 hf_revision: None,
125 isq: None,
126 paged_attn_cfg: None,
127 max_num_seqs: 32,
128 no_kv_cache: false,
129 prefix_cache_n: Some(16),
130 with_logging: false,
131 device_mapping: None,
132 imatrix: None,
133 calibration_file: None,
134 jinja_explicit: None,
135 throughput_logging: false,
136 hf_cache_path: None,
137 search_bert_model: None,
138 search_callback: None,
139 tool_callbacks: HashMap::new(),
140 tool_callbacks_with_tools: HashMap::new(),
141 mcp_client_config: None,
142 device: None,
143 }
144 }
145
146 pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
148 self.search_bert_model = Some(search_bert_model);
149 self
150 }
151
152 pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
154 self.search_callback = Some(callback);
155 self
156 }
157
158 pub fn with_tool_callback(
160 mut self,
161 name: impl Into<String>,
162 callback: Arc<ToolCallback>,
163 ) -> Self {
164 self.tool_callbacks.insert(name.into(), callback);
165 self
166 }
167
168 pub fn with_tool_callback_and_tool(
171 mut self,
172 name: impl Into<String>,
173 callback: Arc<ToolCallback>,
174 tool: Tool,
175 ) -> Self {
176 let name = name.into();
177 self.tool_callbacks_with_tools
178 .insert(name, ToolCallbackWithTool { callback, tool });
179 self
180 }
181
182 pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
185 self.mcp_client_config = Some(config);
186 self
187 }
188
189 pub fn with_throughput_logging(mut self) -> Self {
191 self.throughput_logging = true;
192 self
193 }
194
195 pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
197 self.jinja_explicit = Some(jinja_explicit);
198 self
199 }
200
201 pub fn with_prompt_chunksize(mut self, prompt_chunksize: NonZeroUsize) -> Self {
203 self.prompt_chunksize = Some(prompt_chunksize);
204 self
205 }
206
207 pub fn with_topology(mut self, topology: Topology) -> Self {
209 self.topology = Some(topology);
210 self
211 }
212
213 pub fn with_mixture_qexperts_isq(mut self) -> Self {
215 self.organization = IsqOrganization::MoeExpertsOnly;
216 self
217 }
218
219 pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
221 self.chat_template = Some(chat_template.to_string());
222 self
223 }
224
225 pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
227 self.tokenizer_json = Some(tokenizer_json.to_string());
228 self
229 }
230
231 pub fn with_loader_type(mut self, loader_type: NormalLoaderType) -> Self {
234 self.loader_type = Some(loader_type);
235 self
236 }
237
238 pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
240 self.dtype = dtype;
241 self
242 }
243
244 pub fn with_force_cpu(mut self) -> Self {
246 self.force_cpu = true;
247 self
248 }
249
250 pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
252 self.token_source = token_source;
253 self
254 }
255
256 pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
258 self.hf_revision = Some(revision.to_string());
259 self
260 }
261
262 pub fn with_isq(mut self, isq: IsqType) -> Self {
264 self.isq = Some(isq);
265 self
266 }
267
268 pub fn with_imatrix(mut self, path: PathBuf) -> Self {
270 self.imatrix = Some(path);
271 self
272 }
273
274 pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
276 self.calibration_file = Some(path);
277 self
278 }
279
280 pub fn with_paged_attn(
285 mut self,
286 paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
287 ) -> anyhow::Result<Self> {
288 if paged_attn_supported() {
289 self.paged_attn_cfg = Some(paged_attn_cfg()?);
290 } else {
291 self.paged_attn_cfg = None;
292 }
293 Ok(self)
294 }
295
296 pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
298 self.max_num_seqs = max_num_seqs;
299 self
300 }
301
302 pub fn with_no_kv_cache(mut self) -> Self {
304 self.no_kv_cache = true;
305 self
306 }
307
308 pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
310 self.prefix_cache_n = n_seqs;
311 self
312 }
313
314 pub fn with_logging(mut self) -> Self {
316 self.with_logging = true;
317 self
318 }
319
320 pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
322 self.device_mapping = Some(device_mapping);
323 self
324 }
325
326 pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
328 self.from_uqff = Some(path);
329 self
330 }
331
332 pub fn write_uqff(mut self, path: PathBuf) -> Self {
341 self.write_uqff = Some(path);
342 self
343 }
344
345 pub fn from_hf_cache_pathf(mut self, hf_cache_path: PathBuf) -> Self {
347 self.hf_cache_path = Some(hf_cache_path);
348 self
349 }
350
351 pub fn with_device(mut self, device: Device) -> Self {
353 self.device = Some(device);
354 self
355 }
356
357 pub async fn build(self) -> anyhow::Result<Model> {
358 let config = NormalSpecificConfig {
359 prompt_chunksize: self.prompt_chunksize,
360 topology: self.topology,
361 organization: self.organization,
362 write_uqff: self.write_uqff,
363 from_uqff: self.from_uqff,
364 imatrix: self.imatrix,
365 calibration_file: self.calibration_file,
366 hf_cache_path: self.hf_cache_path,
367 };
368
369 if self.with_logging {
370 initialize_logging();
371 }
372
373 let loader = NormalLoaderBuilder::new(
374 config,
375 self.chat_template,
376 self.tokenizer_json,
377 Some(self.model_id),
378 self.no_kv_cache,
379 self.jinja_explicit,
380 )
381 .build(self.loader_type)?;
382
383 let pipeline = loader.load_model_from_hf(
385 self.hf_revision,
386 self.token_source,
387 &self.dtype,
388 &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
389 !self.with_logging,
390 self.device_mapping
391 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
392 self.isq,
393 self.paged_attn_cfg,
394 )?;
395
396 let scheduler_method = match self.paged_attn_cfg {
397 Some(_) => {
398 let config = pipeline
399 .lock()
400 .await
401 .get_metadata()
402 .cache_config
403 .as_ref()
404 .cloned();
405
406 if let Some(config) = config {
407 SchedulerConfig::PagedAttentionMeta {
408 max_num_seqs: self.max_num_seqs,
409 config,
410 }
411 } else {
412 SchedulerConfig::DefaultScheduler {
413 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
414 }
415 }
416 }
417 None => SchedulerConfig::DefaultScheduler {
418 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
419 },
420 };
421
422 let mut runner = MistralRsBuilder::new(
423 pipeline,
424 scheduler_method,
425 self.throughput_logging,
426 self.search_bert_model,
427 );
428 if let Some(cb) = self.search_callback.clone() {
429 runner = runner.with_search_callback(cb);
430 }
431 for (name, cb) in &self.tool_callbacks {
432 runner = runner.with_tool_callback(name.clone(), cb.clone());
433 }
434 for (name, callback_with_tool) in &self.tool_callbacks_with_tools {
435 runner = runner.with_tool_callback_and_tool(
436 name.clone(),
437 callback_with_tool.callback.clone(),
438 callback_with_tool.tool.clone(),
439 );
440 }
441 if let Some(mcp_config) = self.mcp_client_config {
442 runner = runner.with_mcp_client(mcp_config);
443 }
444 runner = runner
445 .with_no_kv_cache(self.no_kv_cache)
446 .with_no_prefix_cache(self.prefix_cache_n.is_none());
447
448 if let Some(n) = self.prefix_cache_n {
449 runner = runner.with_prefix_cache_n(n)
450 }
451
452 Ok(Model::new(runner.build().await))
453 }
454}
455
456#[derive(Clone)]
457pub struct UqffTextModelBuilder(TextModelBuilder);
460
461impl UqffTextModelBuilder {
462 pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
469 let mut inner = TextModelBuilder::new(model_id);
470 inner = inner.from_uqff(uqff_file);
471 Self(inner)
472 }
473
474 pub async fn build(self) -> anyhow::Result<Model> {
475 self.0.build().await
476 }
477
478 pub fn into_inner(self) -> TextModelBuilder {
480 self.0
481 }
482}
483
484impl Deref for UqffTextModelBuilder {
485 type Target = TextModelBuilder;
486
487 fn deref(&self) -> &Self::Target {
488 &self.0
489 }
490}
491
492impl DerefMut for UqffTextModelBuilder {
493 fn deref_mut(&mut self) -> &mut Self::Target {
494 &mut self.0
495 }
496}
497
498impl From<UqffTextModelBuilder> for TextModelBuilder {
499 fn from(value: UqffTextModelBuilder) -> Self {
500 value.0
501 }
502}