mistralrs/
text_model.rs

1use candle_core::Device;
2use mistralrs_core::*;
3use mistralrs_core::{SearchCallback, Tool, ToolCallback};
4use std::collections::HashMap;
5use std::{
6    num::NonZeroUsize,
7    ops::{Deref, DerefMut},
8    path::PathBuf,
9    sync::Arc,
10};
11
12use crate::{best_device, Model};
13
14/// A tool callback with its associated Tool definition.
15#[derive(Clone)]
16pub struct ToolCallbackWithTool {
17    pub callback: Arc<ToolCallback>,
18    pub tool: Tool,
19}
20
21#[derive(Clone)]
22/// Configure a text model with the various parameters for loading, running, and other inference behaviors.
23pub struct TextModelBuilder {
24    // Loading model
25    pub(crate) model_id: String,
26    pub(crate) token_source: TokenSource,
27    pub(crate) hf_revision: Option<String>,
28    pub(crate) write_uqff: Option<PathBuf>,
29    pub(crate) from_uqff: Option<Vec<PathBuf>>,
30    pub(crate) imatrix: Option<PathBuf>,
31    pub(crate) calibration_file: Option<PathBuf>,
32    pub(crate) chat_template: Option<String>,
33    pub(crate) jinja_explicit: Option<String>,
34    pub(crate) tokenizer_json: Option<String>,
35    pub(crate) device_mapping: Option<DeviceMapSetting>,
36    pub(crate) hf_cache_path: Option<PathBuf>,
37    pub(crate) search_bert_model: Option<BertEmbeddingModel>,
38    pub(crate) search_callback: Option<Arc<SearchCallback>>,
39    pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
40    pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
41    pub(crate) mcp_client_config: Option<McpClientConfig>,
42    pub(crate) device: Option<Device>,
43
44    // Model running
45    pub(crate) prompt_chunksize: Option<NonZeroUsize>,
46    pub(crate) topology: Option<Topology>,
47    pub(crate) organization: IsqOrganization,
48    pub(crate) loader_type: Option<NormalLoaderType>,
49    pub(crate) dtype: ModelDType,
50    pub(crate) force_cpu: bool,
51    pub(crate) isq: Option<IsqType>,
52    pub(crate) throughput_logging: bool,
53
54    // Other things
55    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
56    pub(crate) max_num_seqs: usize,
57    pub(crate) no_kv_cache: bool,
58    pub(crate) with_logging: bool,
59    pub(crate) prefix_cache_n: Option<usize>,
60}
61
62/// Builder for PagedAttention metadata.
63pub struct PagedAttentionMetaBuilder {
64    block_size: Option<usize>,
65    mem_cpu: usize,
66    mem_gpu: MemoryGpuConfig,
67    cache_type: PagedCacheType,
68}
69
70impl Default for PagedAttentionMetaBuilder {
71    fn default() -> Self {
72        Self {
73            block_size: None,
74            mem_cpu: 64,
75            mem_gpu: MemoryGpuConfig::ContextSize(4096),
76            cache_type: PagedCacheType::Auto,
77        }
78    }
79}
80
81impl PagedAttentionMetaBuilder {
82    pub fn with_block_size(mut self, block_size: usize) -> Self {
83        self.block_size = Some(block_size);
84        self
85    }
86
87    pub fn with_gpu_memory(mut self, mem_gpu: MemoryGpuConfig) -> Self {
88        self.mem_gpu = mem_gpu;
89        self
90    }
91
92    pub fn with_paged_cache_type(mut self, cache_type: PagedCacheType) -> Self {
93        self.cache_type = cache_type;
94        self
95    }
96
97    pub fn build(self) -> anyhow::Result<PagedAttentionConfig> {
98        PagedAttentionConfig::new(self.block_size, self.mem_cpu, self.mem_gpu, self.cache_type)
99    }
100}
101
102impl TextModelBuilder {
103    /// A few defaults are applied here:
104    /// - MoQE ISQ organization
105    /// - Token source is from the cache (.cache/huggingface/token)
106    /// - Maximum number of sequences running is 32
107    /// - Number of sequences to hold in prefix cache is 16.
108    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
109    /// - By default, web searching compatible with the OpenAI `web_search_options` setting is disabled.
110    pub fn new(model_id: impl ToString) -> Self {
111        Self {
112            model_id: model_id.to_string(),
113            prompt_chunksize: None,
114            topology: None,
115            organization: IsqOrganization::Default,
116            write_uqff: None,
117            from_uqff: None,
118            chat_template: None,
119            tokenizer_json: None,
120            loader_type: None,
121            dtype: ModelDType::Auto,
122            force_cpu: false,
123            token_source: TokenSource::CacheToken,
124            hf_revision: None,
125            isq: None,
126            paged_attn_cfg: None,
127            max_num_seqs: 32,
128            no_kv_cache: false,
129            prefix_cache_n: Some(16),
130            with_logging: false,
131            device_mapping: None,
132            imatrix: None,
133            calibration_file: None,
134            jinja_explicit: None,
135            throughput_logging: false,
136            hf_cache_path: None,
137            search_bert_model: None,
138            search_callback: None,
139            tool_callbacks: HashMap::new(),
140            tool_callbacks_with_tools: HashMap::new(),
141            mcp_client_config: None,
142            device: None,
143        }
144    }
145
146    /// Enable searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified or the default.
147    pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
148        self.search_bert_model = Some(search_bert_model);
149        self
150    }
151
152    /// Override the search function used when `web_search_options` is enabled.
153    pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
154        self.search_callback = Some(callback);
155        self
156    }
157
158    /// Register a callback for a specific tool name.
159    pub fn with_tool_callback(
160        mut self,
161        name: impl Into<String>,
162        callback: Arc<ToolCallback>,
163    ) -> Self {
164        self.tool_callbacks.insert(name.into(), callback);
165        self
166    }
167
168    /// Register a callback with an associated Tool definition that will be automatically
169    /// added to requests when tool callbacks are active.
170    pub fn with_tool_callback_and_tool(
171        mut self,
172        name: impl Into<String>,
173        callback: Arc<ToolCallback>,
174        tool: Tool,
175    ) -> Self {
176        let name = name.into();
177        self.tool_callbacks_with_tools
178            .insert(name, ToolCallbackWithTool { callback, tool });
179        self
180    }
181
182    /// Configure MCP client to connect to external MCP servers and automatically
183    /// register their tools for use in automatic tool calling.
184    pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
185        self.mcp_client_config = Some(config);
186        self
187    }
188
189    /// Enable runner throughput logging.
190    pub fn with_throughput_logging(mut self) -> Self {
191        self.throughput_logging = true;
192        self
193    }
194
195    /// Explicit JINJA chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
196    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
197        self.jinja_explicit = Some(jinja_explicit);
198        self
199    }
200
201    /// Set the prompt batchsize to use for inference.
202    pub fn with_prompt_chunksize(mut self, prompt_chunksize: NonZeroUsize) -> Self {
203        self.prompt_chunksize = Some(prompt_chunksize);
204        self
205    }
206
207    /// Set the model topology for use during loading. If there is an overlap, the topology type is used over the ISQ type.
208    pub fn with_topology(mut self, topology: Topology) -> Self {
209        self.topology = Some(topology);
210        self
211    }
212
213    /// Organize ISQ to enable MoQE (Mixture of Quantized Experts, <https://arxiv.org/abs/2310.02410>)
214    pub fn with_mixture_qexperts_isq(mut self) -> Self {
215        self.organization = IsqOrganization::MoeExpertsOnly;
216        self
217    }
218
219    /// Literal Jinja chat template OR Path (ending in `.json`) to one.
220    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
221        self.chat_template = Some(chat_template.to_string());
222        self
223    }
224
225    /// Path to a discrete `tokenizer.json` file.
226    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
227        self.tokenizer_json = Some(tokenizer_json.to_string());
228        self
229    }
230
231    /// Manually set the model loader type. Otherwise, it will attempt to automatically
232    /// determine the loader type.
233    pub fn with_loader_type(mut self, loader_type: NormalLoaderType) -> Self {
234        self.loader_type = Some(loader_type);
235        self
236    }
237
238    /// Load the model in a certain dtype.
239    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
240        self.dtype = dtype;
241        self
242    }
243
244    /// Force usage of the CPU device. Do not use PagedAttention with this.
245    pub fn with_force_cpu(mut self) -> Self {
246        self.force_cpu = true;
247        self
248    }
249
250    /// Source of the Hugging Face token.
251    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
252        self.token_source = token_source;
253        self
254    }
255
256    /// Set the revision to use for a Hugging Face remote model.
257    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
258        self.hf_revision = Some(revision.to_string());
259        self
260    }
261
262    /// Use ISQ of a certain type. If there is an overlap, the topology type is used over the ISQ type.
263    pub fn with_isq(mut self, isq: IsqType) -> Self {
264        self.isq = Some(isq);
265        self
266    }
267
268    /// Utilise this imatrix file during ISQ. Incompatible with specifying a calibration file.
269    pub fn with_imatrix(mut self, path: PathBuf) -> Self {
270        self.imatrix = Some(path);
271        self
272    }
273
274    /// Utilise this calibration file to collcet an imatrix. Incompatible with specifying a calibration file.
275    pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
276        self.calibration_file = Some(path);
277        self
278    }
279
280    /// Enable PagedAttention. Configure PagedAttention with a [`PagedAttentionConfig`] object, which
281    /// can be created with sensible values with a [`PagedAttentionMetaBuilder`].
282    ///
283    /// If PagedAttention is not supported (query with [`paged_attn_supported`]), this will do nothing.
284    pub fn with_paged_attn(
285        mut self,
286        paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
287    ) -> anyhow::Result<Self> {
288        if paged_attn_supported() {
289            self.paged_attn_cfg = Some(paged_attn_cfg()?);
290        } else {
291            self.paged_attn_cfg = None;
292        }
293        Ok(self)
294    }
295
296    /// Set the maximum number of sequences which can be run at once.
297    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
298        self.max_num_seqs = max_num_seqs;
299        self
300    }
301
302    /// Disable KV cache. Trade performance for memory usage.
303    pub fn with_no_kv_cache(mut self) -> Self {
304        self.no_kv_cache = true;
305        self
306    }
307
308    /// Set the number of sequences to hold in the prefix cache. Set to `None` to disable the prefix cacher.
309    pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
310        self.prefix_cache_n = n_seqs;
311        self
312    }
313
314    /// Enable logging.
315    pub fn with_logging(mut self) -> Self {
316        self.with_logging = true;
317        self
318    }
319
320    /// Provide metadata to initialize the device mapper.
321    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
322        self.device_mapping = Some(device_mapping);
323        self
324    }
325
326    /// Path to read a UQFF file from.
327    pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
328        self.from_uqff = Some(path);
329        self
330    }
331
332    /// Path to write a UQFF file to.
333    ///
334    /// The parent (part of the path excluding the filename) will determine where any other files
335    /// generated are written to. These can be used to load UQFF models standalone, and may include:
336    /// - `residual.safetensors`
337    /// - `tokenizer.json`
338    /// - `config.json`
339    /// - And others
340    pub fn write_uqff(mut self, path: PathBuf) -> Self {
341        self.write_uqff = Some(path);
342        self
343    }
344
345    /// Cache path for Hugging Face models downloaded locally
346    pub fn from_hf_cache_pathf(mut self, hf_cache_path: PathBuf) -> Self {
347        self.hf_cache_path = Some(hf_cache_path);
348        self
349    }
350
351    /// Set the main device to load this model onto. Automatic device mapping will be performed starting with this device.
352    pub fn with_device(mut self, device: Device) -> Self {
353        self.device = Some(device);
354        self
355    }
356
357    pub async fn build(self) -> anyhow::Result<Model> {
358        let config = NormalSpecificConfig {
359            prompt_chunksize: self.prompt_chunksize,
360            topology: self.topology,
361            organization: self.organization,
362            write_uqff: self.write_uqff,
363            from_uqff: self.from_uqff,
364            imatrix: self.imatrix,
365            calibration_file: self.calibration_file,
366            hf_cache_path: self.hf_cache_path,
367        };
368
369        if self.with_logging {
370            initialize_logging();
371        }
372
373        let loader = NormalLoaderBuilder::new(
374            config,
375            self.chat_template,
376            self.tokenizer_json,
377            Some(self.model_id),
378            self.no_kv_cache,
379            self.jinja_explicit,
380        )
381        .build(self.loader_type)?;
382
383        // Load, into a Pipeline
384        let pipeline = loader.load_model_from_hf(
385            self.hf_revision,
386            self.token_source,
387            &self.dtype,
388            &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
389            !self.with_logging,
390            self.device_mapping
391                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
392            self.isq,
393            self.paged_attn_cfg,
394        )?;
395
396        let scheduler_method = match self.paged_attn_cfg {
397            Some(_) => {
398                let config = pipeline
399                    .lock()
400                    .await
401                    .get_metadata()
402                    .cache_config
403                    .as_ref()
404                    .cloned();
405
406                if let Some(config) = config {
407                    SchedulerConfig::PagedAttentionMeta {
408                        max_num_seqs: self.max_num_seqs,
409                        config,
410                    }
411                } else {
412                    SchedulerConfig::DefaultScheduler {
413                        method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
414                    }
415                }
416            }
417            None => SchedulerConfig::DefaultScheduler {
418                method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
419            },
420        };
421
422        let mut runner = MistralRsBuilder::new(
423            pipeline,
424            scheduler_method,
425            self.throughput_logging,
426            self.search_bert_model,
427        );
428        if let Some(cb) = self.search_callback.clone() {
429            runner = runner.with_search_callback(cb);
430        }
431        for (name, cb) in &self.tool_callbacks {
432            runner = runner.with_tool_callback(name.clone(), cb.clone());
433        }
434        for (name, callback_with_tool) in &self.tool_callbacks_with_tools {
435            runner = runner.with_tool_callback_and_tool(
436                name.clone(),
437                callback_with_tool.callback.clone(),
438                callback_with_tool.tool.clone(),
439            );
440        }
441        if let Some(mcp_config) = self.mcp_client_config {
442            runner = runner.with_mcp_client(mcp_config);
443        }
444        runner = runner
445            .with_no_kv_cache(self.no_kv_cache)
446            .with_no_prefix_cache(self.prefix_cache_n.is_none());
447
448        if let Some(n) = self.prefix_cache_n {
449            runner = runner.with_prefix_cache_n(n)
450        }
451
452        Ok(Model::new(runner.build().await))
453    }
454}
455
456#[derive(Clone)]
457/// Configure a UQFF text model with the various parameters for loading, running, and other inference behaviors.
458/// This wraps and implements `DerefMut` for the TextModelBuilder, so users should take care to not call UQFF-related methods.
459pub struct UqffTextModelBuilder(TextModelBuilder);
460
461impl UqffTextModelBuilder {
462    /// A few defaults are applied here:
463    /// - MoQE ISQ organization
464    /// - Token source is from the cache (.cache/huggingface/token)
465    /// - Maximum number of sequences running is 32
466    /// - Number of sequences to hold in prefix cache is 16.
467    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
468    pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
469        let mut inner = TextModelBuilder::new(model_id);
470        inner = inner.from_uqff(uqff_file);
471        Self(inner)
472    }
473
474    pub async fn build(self) -> anyhow::Result<Model> {
475        self.0.build().await
476    }
477
478    /// This wraps the VisionModelBuilder, so users should take care to not call UQFF-related methods.
479    pub fn into_inner(self) -> TextModelBuilder {
480        self.0
481    }
482}
483
484impl Deref for UqffTextModelBuilder {
485    type Target = TextModelBuilder;
486
487    fn deref(&self) -> &Self::Target {
488        &self.0
489    }
490}
491
492impl DerefMut for UqffTextModelBuilder {
493    fn deref_mut(&mut self) -> &mut Self::Target {
494        &mut self.0
495    }
496}
497
498impl From<UqffTextModelBuilder> for TextModelBuilder {
499    fn from(value: UqffTextModelBuilder) -> Self {
500        value.0
501    }
502}