// mistralrs/text_model.rs

1use candle_core::Device;
2use mistralrs_core::*;
3use mistralrs_core::{SearchCallback, Tool, ToolCallback};
4use std::collections::HashMap;
5use std::{
6    num::NonZeroUsize,
7    ops::{Deref, DerefMut},
8    path::PathBuf,
9    sync::Arc,
10};
11
12use crate::{best_device, Model};
13
/// A tool callback with its associated Tool definition.
#[derive(Clone)]
pub struct ToolCallbackWithTool {
    /// Function invoked when the model calls the tool registered under this name.
    pub callback: Arc<ToolCallback>,
    /// Tool definition automatically added to requests when tool callbacks are active.
    pub tool: Tool,
}
20
#[derive(Clone)]
/// Configure a text model with the various parameters for loading, running, and other inference behaviors.
pub struct TextModelBuilder {
    // Loading model
    pub(crate) model_id: String,
    pub(crate) token_source: TokenSource,
    pub(crate) hf_revision: Option<String>,
    // UQFF (pre-quantized) serialization target / source paths.
    pub(crate) write_uqff: Option<PathBuf>,
    pub(crate) from_uqff: Option<Vec<PathBuf>>,
    // `imatrix` and `calibration_file` are mutually exclusive ISQ inputs (see their setters).
    pub(crate) imatrix: Option<PathBuf>,
    pub(crate) calibration_file: Option<PathBuf>,
    pub(crate) chat_template: Option<String>,
    // Explicit JINJA template; overrides all other chat templates when set.
    pub(crate) jinja_explicit: Option<String>,
    pub(crate) tokenizer_json: Option<String>,
    // `None` => automatic device mapping (`AutoDeviceMapParams::default_text()`) at build time.
    pub(crate) device_mapping: Option<DeviceMapSetting>,
    pub(crate) hf_cache_path: Option<PathBuf>,
    // Web search (OpenAI `web_search_options`-compatible); disabled while `None`.
    pub(crate) search_bert_model: Option<BertEmbeddingModel>,
    pub(crate) search_callback: Option<Arc<SearchCallback>>,
    // Tool callbacks keyed by tool name; the `_with_tools` map also carries the Tool definition.
    pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
    pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
    pub(crate) mcp_client_config: Option<McpClientConfig>,
    // Main device; `None` => `best_device(force_cpu)` is consulted at build time.
    pub(crate) device: Option<Device>,
    pub(crate) matformer_config_path: Option<PathBuf>,
    pub(crate) matformer_slice_name: Option<String>,

    // Model running
    pub(crate) prompt_chunksize: Option<NonZeroUsize>,
    // On overlap, topology settings take precedence over `isq`.
    pub(crate) topology: Option<Topology>,
    pub(crate) organization: IsqOrganization,
    pub(crate) loader_type: Option<NormalLoaderType>,
    pub(crate) dtype: ModelDType,
    pub(crate) force_cpu: bool,
    pub(crate) isq: Option<IsqType>,
    pub(crate) throughput_logging: bool,

    // Other things
    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
    pub(crate) max_num_seqs: usize,
    pub(crate) no_kv_cache: bool,
    pub(crate) with_logging: bool,
    // `None` disables the prefix cacher entirely (see `with_prefix_cache_n`).
    pub(crate) prefix_cache_n: Option<usize>,
}
63
/// Builder for PagedAttention metadata.
///
/// Collects settings and produces a [`PagedAttentionConfig`] via `build`.
pub struct PagedAttentionMetaBuilder {
    // Forwarded as-is to `PagedAttentionConfig::new`; `None` by default.
    block_size: Option<usize>,
    // CPU memory budget; defaults to 64 (units defined by `PagedAttentionConfig` — confirm).
    mem_cpu: usize,
    // GPU memory strategy; defaults to a 4096-token context size.
    mem_gpu: MemoryGpuConfig,
    // Paged KV-cache element type; defaults to `Auto`.
    cache_type: PagedCacheType,
}
71
impl Default for PagedAttentionMetaBuilder {
    /// Sensible starting values: no explicit block size, 64 units of CPU memory,
    /// GPU cache sized for a 4096-token context, automatic cache type.
    fn default() -> Self {
        Self {
            block_size: None,
            mem_cpu: 64,
            mem_gpu: MemoryGpuConfig::ContextSize(4096),
            cache_type: PagedCacheType::Auto,
        }
    }
}
82
83impl PagedAttentionMetaBuilder {
84    pub fn with_block_size(mut self, block_size: usize) -> Self {
85        self.block_size = Some(block_size);
86        self
87    }
88
89    pub fn with_gpu_memory(mut self, mem_gpu: MemoryGpuConfig) -> Self {
90        self.mem_gpu = mem_gpu;
91        self
92    }
93
94    pub fn with_paged_cache_type(mut self, cache_type: PagedCacheType) -> Self {
95        self.cache_type = cache_type;
96        self
97    }
98
99    pub fn build(self) -> anyhow::Result<PagedAttentionConfig> {
100        PagedAttentionConfig::new(self.block_size, self.mem_cpu, self.mem_gpu, self.cache_type)
101    }
102}
103
impl TextModelBuilder {
    /// Create a builder for the given model ID. A few defaults are applied here:
    /// - Default ISQ organization (enable MoQE via [`Self::with_mixture_qexperts_isq`])
    /// - Token source is from the cache (.cache/huggingface/token)
    /// - Maximum number of sequences running is 32
    /// - Number of sequences to hold in prefix cache is 16.
    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
    /// - By default, web searching compatible with the OpenAI `web_search_options` setting is disabled.
    pub fn new(model_id: impl ToString) -> Self {
        // Initializer order mirrors the struct declaration for easy auditing.
        Self {
            model_id: model_id.to_string(),
            token_source: TokenSource::CacheToken,
            hf_revision: None,
            write_uqff: None,
            from_uqff: None,
            imatrix: None,
            calibration_file: None,
            chat_template: None,
            jinja_explicit: None,
            tokenizer_json: None,
            device_mapping: None,
            hf_cache_path: None,
            search_bert_model: None,
            search_callback: None,
            tool_callbacks: HashMap::new(),
            tool_callbacks_with_tools: HashMap::new(),
            mcp_client_config: None,
            device: None,
            matformer_config_path: None,
            matformer_slice_name: None,
            prompt_chunksize: None,
            topology: None,
            organization: IsqOrganization::Default,
            loader_type: None,
            dtype: ModelDType::Auto,
            force_cpu: false,
            isq: None,
            throughput_logging: false,
            paged_attn_cfg: None,
            max_num_seqs: 32,
            no_kv_cache: false,
            with_logging: false,
            prefix_cache_n: Some(16),
        }
    }
149
    /// Enable searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified or the default.
    pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
        self.search_bert_model = Some(search_bert_model);
        self
    }

    /// Override the search function used when `web_search_options` is enabled.
    pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
        self.search_callback = Some(callback);
        self
    }

    /// Register a callback for a specific tool name.
    ///
    /// Calling this again with the same name replaces the previous callback.
    pub fn with_tool_callback(
        mut self,
        name: impl Into<String>,
        callback: Arc<ToolCallback>,
    ) -> Self {
        self.tool_callbacks.insert(name.into(), callback);
        self
    }

    /// Register a callback with an associated Tool definition that will be automatically
    /// added to requests when tool callbacks are active.
    ///
    /// Calling this again with the same name replaces the previous entry.
    pub fn with_tool_callback_and_tool(
        mut self,
        name: impl Into<String>,
        callback: Arc<ToolCallback>,
        tool: Tool,
    ) -> Self {
        let name = name.into();
        self.tool_callbacks_with_tools
            .insert(name, ToolCallbackWithTool { callback, tool });
        self
    }

    /// Configure MCP client to connect to external MCP servers and automatically
    /// register their tools for use in automatic tool calling.
    pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
        self.mcp_client_config = Some(config);
        self
    }

    /// Enable runner throughput logging.
    pub fn with_throughput_logging(mut self) -> Self {
        self.throughput_logging = true;
        self
    }
198
    /// Explicit JINJA chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    /// Set the prompt chunk size (prompt batch size) to use for inference.
    pub fn with_prompt_chunksize(mut self, prompt_chunksize: NonZeroUsize) -> Self {
        self.prompt_chunksize = Some(prompt_chunksize);
        self
    }

    /// Set the model topology for use during loading. If there is an overlap, the topology type is used over the ISQ type.
    pub fn with_topology(mut self, topology: Topology) -> Self {
        self.topology = Some(topology);
        self
    }

    /// Organize ISQ to enable MoQE (Mixture of Quantized Experts, <https://arxiv.org/abs/2310.02410>)
    pub fn with_mixture_qexperts_isq(mut self) -> Self {
        self.organization = IsqOrganization::MoeExpertsOnly;
        self
    }

    /// Literal Jinja chat template OR Path (ending in `.json`) to one.
    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
        self.chat_template = Some(chat_template.to_string());
        self
    }
228
    /// Path to a discrete `tokenizer.json` file.
    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
        self.tokenizer_json = Some(tokenizer_json.to_string());
        self
    }

    /// Manually set the model loader type. Otherwise, it will attempt to automatically
    /// determine the loader type.
    pub fn with_loader_type(mut self, loader_type: NormalLoaderType) -> Self {
        self.loader_type = Some(loader_type);
        self
    }

    /// Load the model in a certain dtype.
    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
        self.dtype = dtype;
        self
    }

    /// Force usage of the CPU device. Do not use PagedAttention with this.
    pub fn with_force_cpu(mut self) -> Self {
        self.force_cpu = true;
        self
    }

    /// Source of the Hugging Face token.
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    /// Set the revision to use for a Hugging Face remote model.
    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
        self.hf_revision = Some(revision.to_string());
        self
    }

    /// Use ISQ of a certain type. If there is an overlap, the topology type is used over the ISQ type.
    pub fn with_isq(mut self, isq: IsqType) -> Self {
        self.isq = Some(isq);
        self
    }

    /// Utilise this imatrix file during ISQ. Incompatible with specifying a calibration file.
    pub fn with_imatrix(mut self, path: PathBuf) -> Self {
        self.imatrix = Some(path);
        self
    }

    /// Utilise this calibration file to collect an imatrix. Incompatible with specifying an imatrix file.
    pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
        self.calibration_file = Some(path);
        self
    }
283
284    /// Enable PagedAttention. Configure PagedAttention with a [`PagedAttentionConfig`] object, which
285    /// can be created with sensible values with a [`PagedAttentionMetaBuilder`].
286    ///
287    /// If PagedAttention is not supported (query with [`paged_attn_supported`]), this will do nothing.
288    pub fn with_paged_attn(
289        mut self,
290        paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
291    ) -> anyhow::Result<Self> {
292        if paged_attn_supported() {
293            self.paged_attn_cfg = Some(paged_attn_cfg()?);
294        } else {
295            self.paged_attn_cfg = None;
296        }
297        Ok(self)
298    }
299
    /// Set the maximum number of sequences which can be run at once.
    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
        self.max_num_seqs = max_num_seqs;
        self
    }

    /// Disable KV cache. Trade performance for memory usage.
    pub fn with_no_kv_cache(mut self) -> Self {
        self.no_kv_cache = true;
        self
    }

    /// Set the number of sequences to hold in the prefix cache. Set to `None` to disable the prefix cacher.
    pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
        self.prefix_cache_n = n_seqs;
        self
    }

    /// Enable logging.
    pub fn with_logging(mut self) -> Self {
        self.with_logging = true;
        self
    }

    /// Provide metadata to initialize the device mapper.
    /// When not set, automatic device mapping (`AutoDeviceMapParams::default_text()`) is used at build time.
    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
        self.device_mapping = Some(device_mapping);
        self
    }
329
    #[deprecated(
        note = "Use `UqffTextModelBuilder` to load a UQFF model instead of the generic `from_uqff`"
    )]
    /// Path to read a `.uqff` file from. Other necessary configuration files must be present at this location.
    ///
    /// For example, these include:
    /// - `residual.safetensors`
    /// - `tokenizer.json`
    /// - `config.json`
    /// - More depending on the model
    pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
        self.from_uqff = Some(path);
        self
    }

    /// Path to write a `.uqff` file to and serialize the other necessary files.
    ///
    /// The parent (part of the path excluding the filename) will determine where any other files
    /// serialized are written to.
    ///
    /// For example, these include:
    /// - `residual.safetensors`
    /// - `tokenizer.json`
    /// - `config.json`
    /// - More depending on the model
    pub fn write_uqff(mut self, path: PathBuf) -> Self {
        self.write_uqff = Some(path);
        self
    }
359
360    /// Cache path for Hugging Face models downloaded locally
361    pub fn from_hf_cache_pathf(mut self, hf_cache_path: PathBuf) -> Self {
362        self.hf_cache_path = Some(hf_cache_path);
363        self
364    }
365
    /// Set the main device to load this model onto. Automatic device mapping will be performed starting with this device.
    /// If unset, `best_device(force_cpu)` selects one at build time.
    pub fn with_device(mut self, device: Device) -> Self {
        self.device = Some(device);
        self
    }

    /// Path to a Matryoshka Transformer configuration CSV file.
    pub fn with_matformer_config_path(mut self, path: PathBuf) -> Self {
        self.matformer_config_path = Some(path);
        self
    }

    /// Name of the slice to use from the Matryoshka Transformer configuration.
    pub fn with_matformer_slice_name(mut self, name: String) -> Self {
        self.matformer_slice_name = Some(name);
        self
    }
383
384    pub async fn build(self) -> anyhow::Result<Model> {
385        let config = NormalSpecificConfig {
386            prompt_chunksize: self.prompt_chunksize,
387            topology: self.topology,
388            organization: self.organization,
389            write_uqff: self.write_uqff,
390            from_uqff: self.from_uqff,
391            imatrix: self.imatrix,
392            calibration_file: self.calibration_file,
393            hf_cache_path: self.hf_cache_path,
394            matformer_config_path: self.matformer_config_path,
395            matformer_slice_name: self.matformer_slice_name,
396        };
397
398        if self.with_logging {
399            initialize_logging();
400        }
401
402        let loader = NormalLoaderBuilder::new(
403            config,
404            self.chat_template,
405            self.tokenizer_json,
406            Some(self.model_id),
407            self.no_kv_cache,
408            self.jinja_explicit,
409        )
410        .build(self.loader_type)?;
411
412        // Load, into a Pipeline
413        let pipeline = loader.load_model_from_hf(
414            self.hf_revision,
415            self.token_source,
416            &self.dtype,
417            &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
418            !self.with_logging,
419            self.device_mapping
420                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
421            self.isq,
422            self.paged_attn_cfg,
423        )?;
424
425        let scheduler_method = match self.paged_attn_cfg {
426            Some(_) => {
427                let config = pipeline
428                    .lock()
429                    .await
430                    .get_metadata()
431                    .cache_config
432                    .as_ref()
433                    .cloned();
434
435                if let Some(config) = config {
436                    SchedulerConfig::PagedAttentionMeta {
437                        max_num_seqs: self.max_num_seqs,
438                        config,
439                    }
440                } else {
441                    SchedulerConfig::DefaultScheduler {
442                        method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
443                    }
444                }
445            }
446            None => SchedulerConfig::DefaultScheduler {
447                method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
448            },
449        };
450
451        let mut runner = MistralRsBuilder::new(
452            pipeline,
453            scheduler_method,
454            self.throughput_logging,
455            self.search_bert_model,
456        );
457        if let Some(cb) = self.search_callback.clone() {
458            runner = runner.with_search_callback(cb);
459        }
460        for (name, cb) in &self.tool_callbacks {
461            runner = runner.with_tool_callback(name.clone(), cb.clone());
462        }
463        for (name, callback_with_tool) in &self.tool_callbacks_with_tools {
464            runner = runner.with_tool_callback_and_tool(
465                name.clone(),
466                callback_with_tool.callback.clone(),
467                callback_with_tool.tool.clone(),
468            );
469        }
470        if let Some(mcp_config) = self.mcp_client_config {
471            runner = runner.with_mcp_client(mcp_config);
472        }
473        runner = runner
474            .with_no_kv_cache(self.no_kv_cache)
475            .with_no_prefix_cache(self.prefix_cache_n.is_none());
476
477        if let Some(n) = self.prefix_cache_n {
478            runner = runner.with_prefix_cache_n(n)
479        }
480
481        Ok(Model::new(runner.build().await))
482    }
483}
484
#[derive(Clone)]
/// Configure a UQFF text model with the various parameters for loading, running, and other inference behaviors.
/// This wraps and implements `DerefMut` for the TextModelBuilder, so users should take care to not call UQFF-related methods.
pub struct UqffTextModelBuilder(TextModelBuilder);
489
impl UqffTextModelBuilder {
    /// Create a builder that loads the given UQFF file(s). A few defaults are applied here:
    /// - Default ISQ organization (enable MoQE via `with_mixture_qexperts_isq`)
    /// - Token source is from the cache (.cache/huggingface/token)
    /// - Maximum number of sequences running is 32
    /// - Number of sequences to hold in prefix cache is 16.
    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
    pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
        let mut inner = TextModelBuilder::new(model_id);
        inner.from_uqff = Some(uqff_file);
        Self(inner)
    }

    /// Load the model; see [`TextModelBuilder::build`].
    pub async fn build(self) -> anyhow::Result<Model> {
        self.0.build().await
    }

    /// Unwrap into the inner [`TextModelBuilder`]; users should take care to not call UQFF-related methods on it.
    pub fn into_inner(self) -> TextModelBuilder {
        self.0
    }
}
512
impl Deref for UqffTextModelBuilder {
    type Target = TextModelBuilder;

    // Expose all `TextModelBuilder` methods directly on the UQFF wrapper.
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
520
impl DerefMut for UqffTextModelBuilder {
    // Allow calling `TextModelBuilder` setters through the UQFF wrapper.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
526
impl From<UqffTextModelBuilder> for TextModelBuilder {
    // Equivalent to `into_inner`; enables `.into()` conversions.
    fn from(value: UqffTextModelBuilder) -> Self {
        value.0
    }
}