mistralrs/
vision_model.rs

1use candle_core::Device;
2use mistralrs_core::*;
3use mistralrs_core::{SearchCallback, Tool, ToolCallback};
4use std::collections::HashMap;
5use std::{
6    ops::{Deref, DerefMut},
7    path::PathBuf,
8    sync::Arc,
9};
10
11use crate::{best_device, Model};
12
13/// A tool callback with its associated Tool definition.
14#[derive(Clone)]
15pub struct ToolCallbackWithTool {
16    pub callback: Arc<ToolCallback>,
17    pub tool: Tool,
18}
19
20#[derive(Clone)]
21/// Configure a vision model with the various parameters for loading, running, and other inference behaviors.
22pub struct VisionModelBuilder {
23    // Loading model
24    pub(crate) model_id: String,
25    pub(crate) token_source: TokenSource,
26    pub(crate) hf_revision: Option<String>,
27    pub(crate) write_uqff: Option<PathBuf>,
28    pub(crate) from_uqff: Option<Vec<PathBuf>>,
29    pub(crate) calibration_file: Option<PathBuf>,
30    pub(crate) imatrix: Option<PathBuf>,
31    pub(crate) chat_template: Option<String>,
32    pub(crate) jinja_explicit: Option<String>,
33    pub(crate) tokenizer_json: Option<String>,
34    pub(crate) device_mapping: Option<DeviceMapSetting>,
35    pub(crate) max_edge: Option<u32>,
36    pub(crate) hf_cache_path: Option<PathBuf>,
37    pub(crate) search_bert_model: Option<BertEmbeddingModel>,
38    pub(crate) search_callback: Option<Arc<SearchCallback>>,
39    pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
40    pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
41    pub(crate) device: Option<Device>,
42    pub(crate) matformer_config_path: Option<PathBuf>,
43    pub(crate) matformer_slice_name: Option<String>,
44
45    // Model running
46    pub(crate) topology: Option<Topology>,
47    pub(crate) loader_type: Option<VisionLoaderType>,
48    pub(crate) dtype: ModelDType,
49    pub(crate) force_cpu: bool,
50    pub(crate) isq: Option<IsqType>,
51    pub(crate) throughput_logging: bool,
52
53    // Other things
54    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
55    pub(crate) max_num_seqs: usize,
56    pub(crate) with_logging: bool,
57    pub(crate) prefix_cache_n: Option<usize>,
58}
59
60impl VisionModelBuilder {
61    /// A few defaults are applied here:
62    /// - Token source is from the cache (.cache/huggingface/token)
63    /// - Maximum number of sequences running is 32
64    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
65    /// - By default, web searching compatible with the OpenAI `web_search_options` setting is disabled.
66    pub fn new(model_id: impl ToString) -> Self {
67        Self {
68            model_id: model_id.to_string(),
69            topology: None,
70            write_uqff: None,
71            from_uqff: None,
72            chat_template: None,
73            tokenizer_json: None,
74            max_edge: None,
75            loader_type: None,
76            dtype: ModelDType::Auto,
77            force_cpu: false,
78            token_source: TokenSource::CacheToken,
79            hf_revision: None,
80            isq: None,
81            max_num_seqs: 32,
82            with_logging: false,
83            device_mapping: None,
84            calibration_file: None,
85            imatrix: None,
86            jinja_explicit: None,
87            throughput_logging: false,
88            paged_attn_cfg: None,
89            hf_cache_path: None,
90            search_bert_model: None,
91            search_callback: None,
92            tool_callbacks: HashMap::new(),
93            tool_callbacks_with_tools: HashMap::new(),
94            device: None,
95            matformer_config_path: None,
96            matformer_slice_name: None,
97            prefix_cache_n: None,
98        }
99    }
100
101    /// Enable searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified or the default.
102    pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
103        self.search_bert_model = Some(search_bert_model);
104        self
105    }
106
107    /// Override the search function used when `web_search_options` is enabled.
108    pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
109        self.search_callback = Some(callback);
110        self
111    }
112
113    pub fn with_tool_callback(
114        mut self,
115        name: impl Into<String>,
116        callback: Arc<ToolCallback>,
117    ) -> Self {
118        self.tool_callbacks.insert(name.into(), callback);
119        self
120    }
121
122    /// Register a callback with an associated Tool definition that will be automatically
123    /// added to requests when tool callbacks are active.
124    pub fn with_tool_callback_and_tool(
125        mut self,
126        name: impl Into<String>,
127        callback: Arc<ToolCallback>,
128        tool: Tool,
129    ) -> Self {
130        let name = name.into();
131        self.tool_callbacks_with_tools
132            .insert(name, ToolCallbackWithTool { callback, tool });
133        self
134    }
135
136    /// Enable runner throughput logging.
137    pub fn with_throughput_logging(mut self) -> Self {
138        self.throughput_logging = true;
139        self
140    }
141
142    /// Explicit JINJA chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
143    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
144        self.jinja_explicit = Some(jinja_explicit);
145        self
146    }
147
148    /// Set the model topology for use during loading. If there is an overlap, the topology type is used over the ISQ type.
149    pub fn with_topology(mut self, topology: Topology) -> Self {
150        self.topology = Some(topology);
151        self
152    }
153
154    /// Literal Jinja chat template OR Path (ending in `.json`) to one.
155    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
156        self.chat_template = Some(chat_template.to_string());
157        self
158    }
159
160    /// Path to a discrete `tokenizer.json` file.
161    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
162        self.tokenizer_json = Some(tokenizer_json.to_string());
163        self
164    }
165
166    /// Manually set the model loader type. Otherwise, it will attempt to automatically
167    /// determine the loader type.
168    pub fn with_loader_type(mut self, loader_type: VisionLoaderType) -> Self {
169        self.loader_type = Some(loader_type);
170        self
171    }
172
173    /// Load the model in a certain dtype.
174    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
175        self.dtype = dtype;
176        self
177    }
178
179    /// Force usage of the CPU device. Do not use PagedAttention with this.
180    pub fn with_force_cpu(mut self) -> Self {
181        self.force_cpu = true;
182        self
183    }
184
185    /// Source of the Hugging Face token.
186    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
187        self.token_source = token_source;
188        self
189    }
190
191    /// Set the revision to use for a Hugging Face remote model.
192    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
193        self.hf_revision = Some(revision.to_string());
194        self
195    }
196
197    /// Use ISQ of a certain type. If there is an overlap, the topology type is used over the ISQ type.
198    pub fn with_isq(mut self, isq: IsqType) -> Self {
199        self.isq = Some(isq);
200        self
201    }
202
203    /// Utilise this calibration_file file during ISQ
204    pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
205        self.calibration_file = Some(path);
206        self
207    }
208
209    /// Enable PagedAttention. Configure PagedAttention with a [`PagedAttentionConfig`] object, which
210    /// can be created with sensible values with a [`PagedAttentionMetaBuilder`](crate::PagedAttentionMetaBuilder).
211    ///
212    /// If PagedAttention is not supported (query with [`paged_attn_supported`]), this will do nothing.
213    pub fn with_paged_attn(
214        mut self,
215        paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
216    ) -> anyhow::Result<Self> {
217        if paged_attn_supported() {
218            self.paged_attn_cfg = Some(paged_attn_cfg()?);
219        } else {
220            self.paged_attn_cfg = None;
221        }
222        Ok(self)
223    }
224
225    /// Set the maximum number of sequences which can be run at once.
226    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
227        self.max_num_seqs = max_num_seqs;
228        self
229    }
230
231    /// Set the number of sequences to hold in the prefix cache. Set to `None` to disable the prefix cacher.
232    pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
233        self.prefix_cache_n = n_seqs;
234        self
235    }
236
237    /// Enable logging.
238    pub fn with_logging(mut self) -> Self {
239        self.with_logging = true;
240        self
241    }
242
243    /// Provide metadata to initialize the device mapper.
244    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
245        self.device_mapping = Some(device_mapping);
246        self
247    }
248
249    #[deprecated(
250        note = "Use `UqffTextModelBuilder` to load a UQFF model instead of the generic `from_uqff`"
251    )]
252    /// Path to read a `.uqff` file from. Other necessary configuration files must be present at this location.
253    ///
254    /// For example, these include:
255    /// - `residual.safetensors`
256    /// - `tokenizer.json`
257    /// - `config.json`
258    /// - More depending on the model
259    pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
260        self.from_uqff = Some(path);
261        self
262    }
263
264    /// Automatically resize and pad images to this maximum edge length. Aspect ratio is preserved.
265    /// This is only supported on the Qwen2-VL and Idefics 2 models. Others handle this internally.
266    pub fn from_max_edge(mut self, max_edge: u32) -> Self {
267        self.max_edge = Some(max_edge);
268        self
269    }
270
271    /// Path to write a `.uqff` file to and serialize the other necessary files.
272    ///
273    /// The parent (part of the path excluding the filename) will determine where any other files
274    /// serialized are written to.
275    ///
276    /// For example, these include:
277    /// - `residual.safetensors`
278    /// - `tokenizer.json`
279    /// - `config.json`
280    /// - More depending on the model
281    pub fn write_uqff(mut self, path: PathBuf) -> Self {
282        self.write_uqff = Some(path);
283        self
284    }
285
286    /// Cache path for Hugging Face models downloaded locally
287    pub fn from_hf_cache_pathf(mut self, hf_cache_path: PathBuf) -> Self {
288        self.hf_cache_path = Some(hf_cache_path);
289        self
290    }
291
292    /// Set the main device to load this model onto. Automatic device mapping will be performed starting with this device.
293    pub fn with_device(mut self, device: Device) -> Self {
294        self.device = Some(device);
295        self
296    }
297
298    /// Path to a Matryoshka Transformer configuration CSV file.
299    pub fn with_matformer_config_path(mut self, path: PathBuf) -> Self {
300        self.matformer_config_path = Some(path);
301        self
302    }
303
304    /// Name of the slice to use from the Matryoshka Transformer configuration.
305    pub fn with_matformer_slice_name(mut self, name: String) -> Self {
306        self.matformer_slice_name = Some(name);
307        self
308    }
309
310    pub async fn build(self) -> anyhow::Result<Model> {
311        let config = VisionSpecificConfig {
312            topology: self.topology,
313            write_uqff: self.write_uqff,
314            from_uqff: self.from_uqff,
315            max_edge: self.max_edge,
316            calibration_file: self.calibration_file,
317            imatrix: self.imatrix,
318            hf_cache_path: self.hf_cache_path,
319            matformer_config_path: self.matformer_config_path,
320            matformer_slice_name: self.matformer_slice_name,
321        };
322
323        if self.with_logging {
324            initialize_logging();
325        }
326
327        let loader = VisionLoaderBuilder::new(
328            config,
329            self.chat_template,
330            self.tokenizer_json,
331            Some(self.model_id),
332            self.jinja_explicit,
333        )
334        .build(self.loader_type);
335
336        // Load, into a Pipeline
337        let pipeline = loader.load_model_from_hf(
338            self.hf_revision,
339            self.token_source,
340            &self.dtype,
341            &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
342            !self.with_logging,
343            self.device_mapping
344                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_vision())),
345            self.isq,
346            self.paged_attn_cfg,
347        )?;
348
349        let scheduler_method = match self.paged_attn_cfg {
350            Some(_) => {
351                let config = pipeline
352                    .lock()
353                    .await
354                    .get_metadata()
355                    .cache_config
356                    .as_ref()
357                    .cloned();
358
359                if let Some(config) = config {
360                    SchedulerConfig::PagedAttentionMeta {
361                        max_num_seqs: self.max_num_seqs,
362                        config,
363                    }
364                } else {
365                    SchedulerConfig::DefaultScheduler {
366                        method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
367                    }
368                }
369            }
370            None => SchedulerConfig::DefaultScheduler {
371                method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
372            },
373        };
374
375        let mut runner = MistralRsBuilder::new(
376            pipeline,
377            scheduler_method,
378            self.throughput_logging,
379            self.search_bert_model,
380        );
381        if let Some(cb) = self.search_callback.clone() {
382            runner = runner.with_search_callback(cb);
383        }
384        for (name, cb) in &self.tool_callbacks {
385            runner = runner.with_tool_callback(name.clone(), cb.clone());
386        }
387        for (name, callback_with_tool) in &self.tool_callbacks_with_tools {
388            runner = runner.with_tool_callback_and_tool(
389                name.clone(),
390                callback_with_tool.callback.clone(),
391                callback_with_tool.tool.clone(),
392            );
393        }
394        let mut runner = runner
395            .with_no_kv_cache(false)
396            .with_no_prefix_cache(self.prefix_cache_n.is_none());
397
398        if let Some(n) = self.prefix_cache_n {
399            runner = runner.with_prefix_cache_n(n)
400        }
401
402        Ok(Model::new(runner.build().await))
403    }
404}
405
406#[derive(Clone)]
407/// Configure a UQFF text model with the various parameters for loading, running, and other inference behaviors.
408/// This wraps and implements `DerefMut` for the VisionModelBuilder, so users should take care to not call UQFF-related methods.
409pub struct UqffVisionModelBuilder(VisionModelBuilder);
410
411impl UqffVisionModelBuilder {
412    /// A few defaults are applied here:
413    /// - Token source is from the cache (.cache/huggingface/token)
414    /// - Maximum number of sequences running is 32
415    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
416    pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
417        let mut inner = VisionModelBuilder::new(model_id);
418        inner.from_uqff = Some(uqff_file);
419        Self(inner)
420    }
421
422    pub async fn build(self) -> anyhow::Result<Model> {
423        self.0.build().await
424    }
425
426    /// This wraps the VisionModelBuilder, so users should take care to not call UQFF-related methods.
427    pub fn into_inner(self) -> VisionModelBuilder {
428        self.0
429    }
430}
431
432impl Deref for UqffVisionModelBuilder {
433    type Target = VisionModelBuilder;
434
435    fn deref(&self) -> &Self::Target {
436        &self.0
437    }
438}
439
440impl DerefMut for UqffVisionModelBuilder {
441    fn deref_mut(&mut self) -> &mut Self::Target {
442        &mut self.0
443    }
444}
445
446impl From<UqffVisionModelBuilder> for VisionModelBuilder {
447    fn from(value: UqffVisionModelBuilder) -> Self {
448        value.0
449    }
450}