mistralrs/
text_model.rs

1use candle_core::Device;
2use mistralrs_core::*;
3use mistralrs_core::{SearchCallback, Tool, ToolCallback};
4use std::collections::HashMap;
5use std::{
6    ops::{Deref, DerefMut},
7    path::PathBuf,
8    sync::Arc,
9};
10
11use crate::{best_device, Model};
12
/// A tool callback with its associated Tool definition.
///
/// Stored by [`TextModelBuilder::with_tool_callback_and_tool`] and handed to the runner
/// as a (callback, tool) pair when the model is built.
#[derive(Clone)]
pub struct ToolCallbackWithTool {
    /// The callback registered for the tool.
    pub callback: Arc<ToolCallback>,
    /// The tool definition registered together with `callback`.
    pub tool: Tool,
}
19
/// Configure a text model with the various parameters for loading, running, and other inference behaviors.
///
/// Construct with [`TextModelBuilder::new`], adjust with the chained setter methods, and
/// finish with [`TextModelBuilder::build`].
#[derive(Clone)]
pub struct TextModelBuilder {
    // Loading model
    /// Model ID passed to the loader; set from the argument of `new`.
    pub(crate) model_id: String,
    /// Source of the Hugging Face token.
    pub(crate) token_source: TokenSource,
    /// Revision to use for a Hugging Face remote model, if any.
    pub(crate) hf_revision: Option<String>,
    /// Path to write a `.uqff` file to, if any.
    pub(crate) write_uqff: Option<PathBuf>,
    /// Paths to read `.uqff` files from, if any.
    pub(crate) from_uqff: Option<Vec<PathBuf>>,
    /// imatrix file to utilise during ISQ, if any.
    pub(crate) imatrix: Option<PathBuf>,
    /// Calibration file used to collect an imatrix, if any.
    pub(crate) calibration_file: Option<PathBuf>,
    /// Literal Jinja chat template, or a path to one.
    pub(crate) chat_template: Option<String>,
    /// Explicit JINJA chat template file (.jinja); overrides all other chat templates when set.
    pub(crate) jinja_explicit: Option<String>,
    /// Path to a discrete `tokenizer.json` file.
    pub(crate) tokenizer_json: Option<String>,
    /// Device mapper metadata; `build` falls back to automatic text-model mapping when `None`.
    pub(crate) device_mapping: Option<DeviceMapSetting>,
    /// Cache path for Hugging Face models downloaded locally.
    pub(crate) hf_cache_path: Option<PathBuf>,
    /// BERT model used for web search; set via `with_search`, `None` by default.
    pub(crate) search_bert_model: Option<BertEmbeddingModel>,
    /// Override for the search function used when `web_search_options` is enabled.
    pub(crate) search_callback: Option<Arc<SearchCallback>>,
    /// Callbacks keyed by tool name.
    pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
    /// Callbacks paired with their tool definitions, keyed by tool name.
    pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
    /// MCP client configuration, if an MCP client should be attached to the runner.
    pub(crate) mcp_client_config: Option<McpClientConfig>,
    /// Main device to load the model onto; `build` picks one via `best_device` when `None`.
    pub(crate) device: Option<Device>,
    /// Path to a Matryoshka Transformer configuration CSV file.
    pub(crate) matformer_config_path: Option<PathBuf>,
    /// Name of the slice to use from the Matryoshka Transformer configuration.
    pub(crate) matformer_slice_name: Option<String>,

    // Model running
    /// Model topology for use during loading.
    pub(crate) topology: Option<Topology>,
    /// ISQ organization; `new` sets `IsqOrganization::Default`.
    pub(crate) organization: IsqOrganization,
    /// Manually selected loader type; the loader is built with `None` when unset.
    pub(crate) loader_type: Option<NormalLoaderType>,
    /// Dtype to load the model in; `new` sets `ModelDType::Auto`.
    pub(crate) dtype: ModelDType,
    /// Passed to `best_device` when no explicit `device` is set.
    pub(crate) force_cpu: bool,
    /// ISQ type to apply, if any.
    pub(crate) isq: Option<IsqType>,
    /// Whether runner throughput logging is enabled.
    pub(crate) throughput_logging: bool,

    // Other things
    /// PagedAttention configuration; `None` means PagedAttention is not requested.
    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
    /// Maximum number of sequences which can be run at once.
    pub(crate) max_num_seqs: usize,
    /// Whether the KV cache is disabled.
    pub(crate) no_kv_cache: bool,
    /// Whether `build` should call `initialize_logging`.
    pub(crate) with_logging: bool,
    /// Number of sequences to hold in the prefix cache; `None` disables the prefix cacher.
    pub(crate) prefix_cache_n: Option<usize>,
}
61
/// Builder for PagedAttention metadata.
///
/// Collects the three arguments that `build` forwards to `PagedAttentionConfig::new`.
pub struct PagedAttentionMetaBuilder {
    /// Optional block size; `None` unless `with_block_size` is called.
    block_size: Option<usize>,
    /// GPU memory setting; the `Default` impl uses `MemoryGpuConfig::ContextSize(4096)`.
    mem_gpu: MemoryGpuConfig,
    /// Paged cache type; the `Default` impl uses `PagedCacheType::Auto`.
    cache_type: PagedCacheType,
}
68
69impl Default for PagedAttentionMetaBuilder {
70    fn default() -> Self {
71        Self {
72            block_size: None,
73            mem_gpu: MemoryGpuConfig::ContextSize(4096),
74            cache_type: PagedCacheType::Auto,
75        }
76    }
77}
78
79impl PagedAttentionMetaBuilder {
80    pub fn with_block_size(mut self, block_size: usize) -> Self {
81        self.block_size = Some(block_size);
82        self
83    }
84
85    pub fn with_gpu_memory(mut self, mem_gpu: MemoryGpuConfig) -> Self {
86        self.mem_gpu = mem_gpu;
87        self
88    }
89
90    pub fn with_paged_cache_type(mut self, cache_type: PagedCacheType) -> Self {
91        self.cache_type = cache_type;
92        self
93    }
94
95    pub fn build(self) -> anyhow::Result<PagedAttentionConfig> {
96        PagedAttentionConfig::new(self.block_size, self.mem_gpu, self.cache_type)
97    }
98}
99
100impl TextModelBuilder {
101    /// A few defaults are applied here:
102    /// - MoQE ISQ organization
103    /// - Token source is from the cache (.cache/huggingface/token)
104    /// - Maximum number of sequences running is 32
105    /// - Number of sequences to hold in prefix cache is 16.
106    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
107    /// - By default, web searching compatible with the OpenAI `web_search_options` setting is disabled.
108    pub fn new(model_id: impl ToString) -> Self {
109        Self {
110            model_id: model_id.to_string(),
111            topology: None,
112            organization: IsqOrganization::Default,
113            write_uqff: None,
114            from_uqff: None,
115            chat_template: None,
116            tokenizer_json: None,
117            loader_type: None,
118            dtype: ModelDType::Auto,
119            force_cpu: false,
120            token_source: TokenSource::CacheToken,
121            hf_revision: None,
122            isq: None,
123            paged_attn_cfg: None,
124            max_num_seqs: 32,
125            no_kv_cache: false,
126            prefix_cache_n: Some(16),
127            with_logging: false,
128            device_mapping: None,
129            imatrix: None,
130            calibration_file: None,
131            jinja_explicit: None,
132            throughput_logging: false,
133            hf_cache_path: None,
134            search_bert_model: None,
135            search_callback: None,
136            tool_callbacks: HashMap::new(),
137            tool_callbacks_with_tools: HashMap::new(),
138            mcp_client_config: None,
139            device: None,
140            matformer_config_path: None,
141            matformer_slice_name: None,
142        }
143    }
144
145    /// Enable searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified or the default.
146    pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
147        self.search_bert_model = Some(search_bert_model);
148        self
149    }
150
151    /// Override the search function used when `web_search_options` is enabled.
152    pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
153        self.search_callback = Some(callback);
154        self
155    }
156
157    /// Register a callback for a specific tool name.
158    pub fn with_tool_callback(
159        mut self,
160        name: impl Into<String>,
161        callback: Arc<ToolCallback>,
162    ) -> Self {
163        self.tool_callbacks.insert(name.into(), callback);
164        self
165    }
166
167    /// Register a callback with an associated Tool definition that will be automatically
168    /// added to requests when tool callbacks are active.
169    pub fn with_tool_callback_and_tool(
170        mut self,
171        name: impl Into<String>,
172        callback: Arc<ToolCallback>,
173        tool: Tool,
174    ) -> Self {
175        let name = name.into();
176        self.tool_callbacks_with_tools
177            .insert(name, ToolCallbackWithTool { callback, tool });
178        self
179    }
180
181    /// Configure MCP client to connect to external MCP servers and automatically
182    /// register their tools for use in automatic tool calling.
183    pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
184        self.mcp_client_config = Some(config);
185        self
186    }
187
188    /// Enable runner throughput logging.
189    pub fn with_throughput_logging(mut self) -> Self {
190        self.throughput_logging = true;
191        self
192    }
193
194    /// Explicit JINJA chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
195    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
196        self.jinja_explicit = Some(jinja_explicit);
197        self
198    }
199
200    /// Set the model topology for use during loading. If there is an overlap, the topology type is used over the ISQ type.
201    pub fn with_topology(mut self, topology: Topology) -> Self {
202        self.topology = Some(topology);
203        self
204    }
205
206    /// Organize ISQ to enable MoQE (Mixture of Quantized Experts, <https://arxiv.org/abs/2310.02410>)
207    pub fn with_mixture_qexperts_isq(mut self) -> Self {
208        self.organization = IsqOrganization::MoeExpertsOnly;
209        self
210    }
211
212    /// Literal Jinja chat template OR Path (ending in `.json`) to one.
213    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
214        self.chat_template = Some(chat_template.to_string());
215        self
216    }
217
218    /// Path to a discrete `tokenizer.json` file.
219    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
220        self.tokenizer_json = Some(tokenizer_json.to_string());
221        self
222    }
223
224    /// Manually set the model loader type. Otherwise, it will attempt to automatically
225    /// determine the loader type.
226    pub fn with_loader_type(mut self, loader_type: NormalLoaderType) -> Self {
227        self.loader_type = Some(loader_type);
228        self
229    }
230
231    /// Load the model in a certain dtype.
232    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
233        self.dtype = dtype;
234        self
235    }
236
237    /// Force usage of the CPU device. Do not use PagedAttention with this.
238    pub fn with_force_cpu(mut self) -> Self {
239        self.force_cpu = true;
240        self
241    }
242
243    /// Source of the Hugging Face token.
244    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
245        self.token_source = token_source;
246        self
247    }
248
249    /// Set the revision to use for a Hugging Face remote model.
250    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
251        self.hf_revision = Some(revision.to_string());
252        self
253    }
254
255    /// Use ISQ of a certain type. If there is an overlap, the topology type is used over the ISQ type.
256    pub fn with_isq(mut self, isq: IsqType) -> Self {
257        self.isq = Some(isq);
258        self
259    }
260
261    /// Utilise this imatrix file during ISQ. Incompatible with specifying a calibration file.
262    pub fn with_imatrix(mut self, path: PathBuf) -> Self {
263        self.imatrix = Some(path);
264        self
265    }
266
267    /// Utilise this calibration file to collcet an imatrix. Incompatible with specifying a calibration file.
268    pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
269        self.calibration_file = Some(path);
270        self
271    }
272
273    /// Enable PagedAttention. Configure PagedAttention with a [`PagedAttentionConfig`] object, which
274    /// can be created with sensible values with a [`PagedAttentionMetaBuilder`].
275    ///
276    /// If PagedAttention is not supported (query with [`paged_attn_supported`]), this will do nothing.
277    pub fn with_paged_attn(
278        mut self,
279        paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
280    ) -> anyhow::Result<Self> {
281        if paged_attn_supported() {
282            self.paged_attn_cfg = Some(paged_attn_cfg()?);
283        } else {
284            self.paged_attn_cfg = None;
285        }
286        Ok(self)
287    }
288
289    /// Set the maximum number of sequences which can be run at once.
290    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
291        self.max_num_seqs = max_num_seqs;
292        self
293    }
294
295    /// Disable KV cache. Trade performance for memory usage.
296    pub fn with_no_kv_cache(mut self) -> Self {
297        self.no_kv_cache = true;
298        self
299    }
300
301    /// Set the number of sequences to hold in the prefix cache. Set to `None` to disable the prefix cacher.
302    pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
303        self.prefix_cache_n = n_seqs;
304        self
305    }
306
307    /// Enable logging.
308    pub fn with_logging(mut self) -> Self {
309        self.with_logging = true;
310        self
311    }
312
313    /// Provide metadata to initialize the device mapper.
314    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
315        self.device_mapping = Some(device_mapping);
316        self
317    }
318
319    #[deprecated(
320        note = "Use `UqffTextModelBuilder` to load a UQFF model instead of the generic `from_uqff`"
321    )]
322    /// Path to read a `.uqff` file from. Other necessary configuration files must be present at this location.
323    ///
324    /// For example, these include:
325    /// - `residual.safetensors`
326    /// - `tokenizer.json`
327    /// - `config.json`
328    /// - More depending on the model
329    pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
330        self.from_uqff = Some(path);
331        self
332    }
333
334    /// Path to write a `.uqff` file to and serialize the other necessary files.
335    ///
336    /// The parent (part of the path excluding the filename) will determine where any other files
337    /// serialized are written to.
338    ///
339    /// For example, these include:
340    /// - `residual.safetensors`
341    /// - `tokenizer.json`
342    /// - `config.json`
343    /// - More depending on the model
344    pub fn write_uqff(mut self, path: PathBuf) -> Self {
345        self.write_uqff = Some(path);
346        self
347    }
348
349    /// Cache path for Hugging Face models downloaded locally
350    pub fn from_hf_cache_pathf(mut self, hf_cache_path: PathBuf) -> Self {
351        self.hf_cache_path = Some(hf_cache_path);
352        self
353    }
354
355    /// Set the main device to load this model onto. Automatic device mapping will be performed starting with this device.
356    pub fn with_device(mut self, device: Device) -> Self {
357        self.device = Some(device);
358        self
359    }
360
361    /// Path to a Matryoshka Transformer configuration CSV file.
362    pub fn with_matformer_config_path(mut self, path: PathBuf) -> Self {
363        self.matformer_config_path = Some(path);
364        self
365    }
366
367    /// Name of the slice to use from the Matryoshka Transformer configuration.
368    pub fn with_matformer_slice_name(mut self, name: String) -> Self {
369        self.matformer_slice_name = Some(name);
370        self
371    }
372
373    pub async fn build(self) -> anyhow::Result<Model> {
374        let config = NormalSpecificConfig {
375            topology: self.topology,
376            organization: self.organization,
377            write_uqff: self.write_uqff,
378            from_uqff: self.from_uqff,
379            imatrix: self.imatrix,
380            calibration_file: self.calibration_file,
381            hf_cache_path: self.hf_cache_path,
382            matformer_config_path: self.matformer_config_path,
383            matformer_slice_name: self.matformer_slice_name,
384        };
385
386        if self.with_logging {
387            initialize_logging();
388        }
389
390        let loader = NormalLoaderBuilder::new(
391            config,
392            self.chat_template,
393            self.tokenizer_json,
394            Some(self.model_id),
395            self.no_kv_cache,
396            self.jinja_explicit,
397        )
398        .build(self.loader_type)?;
399
400        // Load, into a Pipeline
401        let pipeline = loader.load_model_from_hf(
402            self.hf_revision,
403            self.token_source,
404            &self.dtype,
405            &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
406            !self.with_logging,
407            self.device_mapping
408                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
409            self.isq,
410            self.paged_attn_cfg,
411        )?;
412
413        let scheduler_method = match self.paged_attn_cfg {
414            Some(_) => {
415                let config = pipeline
416                    .lock()
417                    .await
418                    .get_metadata()
419                    .cache_config
420                    .as_ref()
421                    .cloned();
422
423                if let Some(config) = config {
424                    SchedulerConfig::PagedAttentionMeta {
425                        max_num_seqs: self.max_num_seqs,
426                        config,
427                    }
428                } else {
429                    SchedulerConfig::DefaultScheduler {
430                        method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
431                    }
432                }
433            }
434            None => SchedulerConfig::DefaultScheduler {
435                method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
436            },
437        };
438
439        let mut runner = MistralRsBuilder::new(
440            pipeline,
441            scheduler_method,
442            self.throughput_logging,
443            self.search_bert_model,
444        );
445        if let Some(cb) = self.search_callback.clone() {
446            runner = runner.with_search_callback(cb);
447        }
448        for (name, cb) in &self.tool_callbacks {
449            runner = runner.with_tool_callback(name.clone(), cb.clone());
450        }
451        for (name, callback_with_tool) in &self.tool_callbacks_with_tools {
452            runner = runner.with_tool_callback_and_tool(
453                name.clone(),
454                callback_with_tool.callback.clone(),
455                callback_with_tool.tool.clone(),
456            );
457        }
458        if let Some(mcp_config) = self.mcp_client_config {
459            runner = runner.with_mcp_client(mcp_config);
460        }
461        runner = runner
462            .with_no_kv_cache(self.no_kv_cache)
463            .with_no_prefix_cache(self.prefix_cache_n.is_none());
464
465        if let Some(n) = self.prefix_cache_n {
466            runner = runner.with_prefix_cache_n(n)
467        }
468
469        Ok(Model::new(runner.build().await))
470    }
471}
472
/// Configure a UQFF text model with the various parameters for loading, running, and other inference behaviors.
/// This wraps and implements `DerefMut` for the TextModelBuilder, so users should take care to not call UQFF-related methods.
///
/// The wrapped builder can be recovered with `into_inner` or through the `From` conversion.
#[derive(Clone)]
pub struct UqffTextModelBuilder(TextModelBuilder);
477
478impl UqffTextModelBuilder {
479    /// A few defaults are applied here:
480    /// - MoQE ISQ organization
481    /// - Token source is from the cache (.cache/huggingface/token)
482    /// - Maximum number of sequences running is 32
483    /// - Number of sequences to hold in prefix cache is 16.
484    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
485    pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
486        let mut inner = TextModelBuilder::new(model_id);
487        inner.from_uqff = Some(uqff_file);
488        Self(inner)
489    }
490
491    pub async fn build(self) -> anyhow::Result<Model> {
492        self.0.build().await
493    }
494
495    /// This wraps the VisionModelBuilder, so users should take care to not call UQFF-related methods.
496    pub fn into_inner(self) -> TextModelBuilder {
497        self.0
498    }
499}
500
501impl Deref for UqffTextModelBuilder {
502    type Target = TextModelBuilder;
503
504    fn deref(&self) -> &Self::Target {
505        &self.0
506    }
507}
508
509impl DerefMut for UqffTextModelBuilder {
510    fn deref_mut(&mut self) -> &mut Self::Target {
511        &mut self.0
512    }
513}
514
515impl From<UqffTextModelBuilder> for TextModelBuilder {
516    fn from(value: UqffTextModelBuilder) -> Self {
517        value.0
518    }
519}