// mistralrs/text_model.rs

1use candle_core::Device;
2use mistralrs_core::*;
3use mistralrs_core::{SearchCallback, Tool, ToolCallback};
4use std::collections::HashMap;
5use std::{
6    ops::{Deref, DerefMut},
7    path::PathBuf,
8    sync::Arc,
9};
10
11use crate::model_builder_trait::{build_model_from_pipeline, build_text_pipeline};
12use crate::Model;
13
14/// A tool callback with its associated Tool definition.
15#[derive(Clone)]
16pub struct ToolCallbackWithTool {
17    pub callback: Arc<ToolCallback>,
18    pub tool: Tool,
19}
20
21#[derive(Clone)]
22/// Configure a text model with the various parameters for loading, running, and other inference behaviors.
23pub struct TextModelBuilder {
24    // Loading model
25    pub(crate) model_id: String,
26    pub(crate) token_source: TokenSource,
27    pub(crate) hf_revision: Option<String>,
28    pub(crate) write_uqff: Option<PathBuf>,
29    pub(crate) from_uqff: Option<Vec<PathBuf>>,
30    pub(crate) imatrix: Option<PathBuf>,
31    pub(crate) calibration_file: Option<PathBuf>,
32    pub(crate) chat_template: Option<String>,
33    pub(crate) jinja_explicit: Option<String>,
34    pub(crate) tokenizer_json: Option<String>,
35    pub(crate) device_mapping: Option<DeviceMapSetting>,
36    pub(crate) hf_cache_path: Option<PathBuf>,
37    pub(crate) search_embedding_model: Option<SearchEmbeddingModel>,
38    pub(crate) search_callback: Option<Arc<SearchCallback>>,
39    pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
40    pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
41    pub(crate) mcp_client_config: Option<McpClientConfig>,
42    pub(crate) device: Option<Device>,
43    pub(crate) matformer_config_path: Option<PathBuf>,
44    pub(crate) matformer_slice_name: Option<String>,
45
46    // Model running
47    pub(crate) topology: Option<Topology>,
48    pub(crate) topology_path: Option<String>,
49    pub(crate) organization: IsqOrganization,
50    pub(crate) loader_type: Option<NormalLoaderType>,
51    pub(crate) dtype: ModelDType,
52    pub(crate) force_cpu: bool,
53    pub(crate) isq: Option<IsqType>,
54    pub(crate) throughput_logging: bool,
55
56    // Other things
57    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
58    pub(crate) max_num_seqs: usize,
59    pub(crate) no_kv_cache: bool,
60    pub(crate) with_logging: bool,
61    pub(crate) prefix_cache_n: Option<usize>,
62}
63
64/// Builder for PagedAttention metadata.
65pub struct PagedAttentionMetaBuilder {
66    block_size: Option<usize>,
67    mem_gpu: MemoryGpuConfig,
68    cache_type: PagedCacheType,
69}
70
71impl Default for PagedAttentionMetaBuilder {
72    fn default() -> Self {
73        Self {
74            block_size: None,
75            mem_gpu: MemoryGpuConfig::ContextSize(4096),
76            cache_type: PagedCacheType::Auto,
77        }
78    }
79}
80
81impl PagedAttentionMetaBuilder {
82    pub fn with_block_size(mut self, block_size: usize) -> Self {
83        self.block_size = Some(block_size);
84        self
85    }
86
87    pub fn with_gpu_memory(mut self, mem_gpu: MemoryGpuConfig) -> Self {
88        self.mem_gpu = mem_gpu;
89        self
90    }
91
92    pub fn with_paged_cache_type(mut self, cache_type: PagedCacheType) -> Self {
93        self.cache_type = cache_type;
94        self
95    }
96
97    pub fn build(self) -> anyhow::Result<PagedAttentionConfig> {
98        PagedAttentionConfig::new(self.block_size, self.mem_gpu, self.cache_type)
99    }
100}
101
102impl TextModelBuilder {
103    /// A few defaults are applied here:
104    /// - MoQE ISQ organization
105    /// - Token source is from the cache (.cache/huggingface/token)
106    /// - Maximum number of sequences running is 32
107    /// - Number of sequences to hold in prefix cache is 16.
108    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
109    /// - By default, web searching compatible with the OpenAI `web_search_options` setting is disabled.
110    pub fn new(model_id: impl ToString) -> Self {
111        Self {
112            model_id: model_id.to_string(),
113            topology: None,
114            topology_path: None,
115            organization: IsqOrganization::Default,
116            write_uqff: None,
117            from_uqff: None,
118            chat_template: None,
119            tokenizer_json: None,
120            loader_type: None,
121            dtype: ModelDType::Auto,
122            force_cpu: false,
123            token_source: TokenSource::CacheToken,
124            hf_revision: None,
125            isq: None,
126            paged_attn_cfg: None,
127            max_num_seqs: 32,
128            no_kv_cache: false,
129            prefix_cache_n: Some(16),
130            with_logging: false,
131            device_mapping: None,
132            imatrix: None,
133            calibration_file: None,
134            jinja_explicit: None,
135            throughput_logging: false,
136            hf_cache_path: None,
137            search_embedding_model: None,
138            search_callback: None,
139            tool_callbacks: HashMap::new(),
140            tool_callbacks_with_tools: HashMap::new(),
141            mcp_client_config: None,
142            device: None,
143            matformer_config_path: None,
144            matformer_slice_name: None,
145        }
146    }
147
148    /// Enable searching compatible with the OpenAI `web_search_options` setting. This loads the selected search embedding reranker (EmbeddingGemma by default).
149    pub fn with_search(mut self, search_embedding_model: SearchEmbeddingModel) -> Self {
150        self.search_embedding_model = Some(search_embedding_model);
151        self
152    }
153
154    /// Override the search function used when `web_search_options` is enabled.
155    pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
156        self.search_callback = Some(callback);
157        self
158    }
159
160    /// Register a callback for a specific tool name.
161    pub fn with_tool_callback(
162        mut self,
163        name: impl Into<String>,
164        callback: Arc<ToolCallback>,
165    ) -> Self {
166        self.tool_callbacks.insert(name.into(), callback);
167        self
168    }
169
170    /// Register a callback with an associated Tool definition that will be automatically
171    /// added to requests when tool callbacks are active.
172    pub fn with_tool_callback_and_tool(
173        mut self,
174        name: impl Into<String>,
175        callback: Arc<ToolCallback>,
176        tool: Tool,
177    ) -> Self {
178        let name = name.into();
179        self.tool_callbacks_with_tools
180            .insert(name, ToolCallbackWithTool { callback, tool });
181        self
182    }
183
184    /// Configure MCP client to connect to external MCP servers and automatically
185    /// register their tools for use in automatic tool calling.
186    pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
187        self.mcp_client_config = Some(config);
188        self
189    }
190
191    /// Enable runner throughput logging.
192    pub fn with_throughput_logging(mut self) -> Self {
193        self.throughput_logging = true;
194        self
195    }
196
197    /// Explicit JINJA chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
198    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
199        self.jinja_explicit = Some(jinja_explicit);
200        self
201    }
202
203    /// Set the model topology for use during loading. If there is an overlap, the topology type is used over the ISQ type.
204    pub fn with_topology(mut self, topology: Topology) -> Self {
205        self.topology = Some(topology);
206        self
207    }
208
209    /// Set the model topology from a path. This preserves the path for unload/reload support.
210    /// If there is an overlap, the topology type is used over the ISQ type.
211    pub fn with_topology_from_path<P: AsRef<std::path::Path>>(
212        mut self,
213        path: P,
214    ) -> anyhow::Result<Self> {
215        let path_str = path.as_ref().to_string_lossy().to_string();
216        self.topology = Some(Topology::from_path(&path)?);
217        self.topology_path = Some(path_str);
218        Ok(self)
219    }
220
221    /// Organize ISQ to enable MoQE (Mixture of Quantized Experts, <https://arxiv.org/abs/2310.02410>)
222    pub fn with_mixture_qexperts_isq(mut self) -> Self {
223        self.organization = IsqOrganization::MoeExpertsOnly;
224        self
225    }
226
227    /// Literal Jinja chat template OR Path (ending in `.json`) to one.
228    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
229        self.chat_template = Some(chat_template.to_string());
230        self
231    }
232
233    /// Path to a discrete `tokenizer.json` file.
234    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
235        self.tokenizer_json = Some(tokenizer_json.to_string());
236        self
237    }
238
239    /// Manually set the model loader type. Otherwise, it will attempt to automatically
240    /// determine the loader type.
241    pub fn with_loader_type(mut self, loader_type: NormalLoaderType) -> Self {
242        self.loader_type = Some(loader_type);
243        self
244    }
245
246    /// Load the model in a certain dtype.
247    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
248        self.dtype = dtype;
249        self
250    }
251
252    /// Force usage of the CPU device. Do not use PagedAttention with this.
253    pub fn with_force_cpu(mut self) -> Self {
254        self.force_cpu = true;
255        self
256    }
257
258    /// Source of the Hugging Face token.
259    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
260        self.token_source = token_source;
261        self
262    }
263
264    /// Set the revision to use for a Hugging Face remote model.
265    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
266        self.hf_revision = Some(revision.to_string());
267        self
268    }
269
270    /// Use ISQ of a certain type. If there is an overlap, the topology type is used over the ISQ type.
271    pub fn with_isq(mut self, isq: IsqType) -> Self {
272        self.isq = Some(isq);
273        self
274    }
275
276    /// Utilise this imatrix file during ISQ. Incompatible with specifying a calibration file.
277    pub fn with_imatrix(mut self, path: PathBuf) -> Self {
278        self.imatrix = Some(path);
279        self
280    }
281
282    /// Utilise this calibration file to collcet an imatrix. Incompatible with specifying a calibration file.
283    pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
284        self.calibration_file = Some(path);
285        self
286    }
287
288    /// Enable PagedAttention. Configure PagedAttention with a [`PagedAttentionConfig`] object, which
289    /// can be created with sensible values with a [`PagedAttentionMetaBuilder`].
290    ///
291    /// If PagedAttention is not supported (query with [`paged_attn_supported`]), this will do nothing.
292    pub fn with_paged_attn(
293        mut self,
294        paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
295    ) -> anyhow::Result<Self> {
296        if paged_attn_supported() {
297            self.paged_attn_cfg = Some(paged_attn_cfg()?);
298        } else {
299            self.paged_attn_cfg = None;
300        }
301        Ok(self)
302    }
303
304    /// Set the maximum number of sequences which can be run at once.
305    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
306        self.max_num_seqs = max_num_seqs;
307        self
308    }
309
310    /// Disable KV cache. Trade performance for memory usage.
311    pub fn with_no_kv_cache(mut self) -> Self {
312        self.no_kv_cache = true;
313        self
314    }
315
316    /// Set the number of sequences to hold in the prefix cache. Set to `None` to disable the prefix cacher.
317    pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
318        self.prefix_cache_n = n_seqs;
319        self
320    }
321
322    /// Enable logging.
323    pub fn with_logging(mut self) -> Self {
324        self.with_logging = true;
325        self
326    }
327
328    /// Provide metadata to initialize the device mapper.
329    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
330        self.device_mapping = Some(device_mapping);
331        self
332    }
333
334    #[deprecated(
335        note = "Use `UqffTextModelBuilder` to load a UQFF model instead of the generic `from_uqff`"
336    )]
337    /// Path to read a `.uqff` file from. Other necessary configuration files must be present at this location.
338    ///
339    /// For example, these include:
340    /// - `residual.safetensors`
341    /// - `tokenizer.json`
342    /// - `config.json`
343    /// - More depending on the model
344    pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
345        self.from_uqff = Some(path);
346        self
347    }
348
349    /// Path to write a `.uqff` file to and serialize the other necessary files.
350    ///
351    /// The parent (part of the path excluding the filename) will determine where any other files
352    /// serialized are written to.
353    ///
354    /// For example, these include:
355    /// - `residual.safetensors`
356    /// - `tokenizer.json`
357    /// - `config.json`
358    /// - More depending on the model
359    pub fn write_uqff(mut self, path: PathBuf) -> Self {
360        self.write_uqff = Some(path);
361        self
362    }
363
364    /// Cache path for Hugging Face models downloaded locally
365    pub fn from_hf_cache_pathf(mut self, hf_cache_path: PathBuf) -> Self {
366        self.hf_cache_path = Some(hf_cache_path);
367        self
368    }
369
370    /// Set the main device to load this model onto. Automatic device mapping will be performed starting with this device.
371    pub fn with_device(mut self, device: Device) -> Self {
372        self.device = Some(device);
373        self
374    }
375
376    /// Path to a Matryoshka Transformer configuration CSV file.
377    pub fn with_matformer_config_path(mut self, path: PathBuf) -> Self {
378        self.matformer_config_path = Some(path);
379        self
380    }
381
382    /// Name of the slice to use from the Matryoshka Transformer configuration.
383    pub fn with_matformer_slice_name(mut self, name: String) -> Self {
384        self.matformer_slice_name = Some(name);
385        self
386    }
387
388    pub async fn build(self) -> anyhow::Result<Model> {
389        let (pipeline, scheduler_config, add_model_config) = build_text_pipeline(self).await?;
390        Ok(build_model_from_pipeline(pipeline, scheduler_config, add_model_config).await)
391    }
392}
393
394#[derive(Clone)]
395/// Configure a UQFF text model with the various parameters for loading, running, and other inference behaviors.
396/// This wraps and implements `DerefMut` for the TextModelBuilder, so users should take care to not call UQFF-related methods.
397pub struct UqffTextModelBuilder(TextModelBuilder);
398
399impl UqffTextModelBuilder {
400    /// A few defaults are applied here:
401    /// - MoQE ISQ organization
402    /// - Token source is from the cache (.cache/huggingface/token)
403    /// - Maximum number of sequences running is 32
404    /// - Number of sequences to hold in prefix cache is 16.
405    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
406    pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
407        let mut inner = TextModelBuilder::new(model_id);
408        inner.from_uqff = Some(uqff_file);
409        Self(inner)
410    }
411
412    pub async fn build(self) -> anyhow::Result<Model> {
413        self.0.build().await
414    }
415
416    /// This wraps the VisionModelBuilder, so users should take care to not call UQFF-related methods.
417    pub fn into_inner(self) -> TextModelBuilder {
418        self.0
419    }
420}
421
422impl Deref for UqffTextModelBuilder {
423    type Target = TextModelBuilder;
424
425    fn deref(&self) -> &Self::Target {
426        &self.0
427    }
428}
429
430impl DerefMut for UqffTextModelBuilder {
431    fn deref_mut(&mut self) -> &mut Self::Target {
432        &mut self.0
433    }
434}
435
436impl From<UqffTextModelBuilder> for TextModelBuilder {
437    fn from(value: UqffTextModelBuilder) -> Self {
438        value.0
439    }
440}