mistralrs/
vision_model.rs

1use candle_core::Device;
2use mistralrs_core::*;
3use mistralrs_core::{SearchCallback, Tool, ToolCallback};
4use std::collections::HashMap;
5use std::{
6    num::NonZeroUsize,
7    ops::{Deref, DerefMut},
8    path::PathBuf,
9    sync::Arc,
10};
11
12use crate::{best_device, Model};
13
14/// A tool callback with its associated Tool definition.
15#[derive(Clone)]
16pub struct ToolCallbackWithTool {
17    pub callback: Arc<ToolCallback>,
18    pub tool: Tool,
19}
20
21#[derive(Clone)]
22/// Configure a vision model with the various parameters for loading, running, and other inference behaviors.
23pub struct VisionModelBuilder {
24    // Loading model
25    pub(crate) model_id: String,
26    pub(crate) token_source: TokenSource,
27    pub(crate) hf_revision: Option<String>,
28    pub(crate) write_uqff: Option<PathBuf>,
29    pub(crate) from_uqff: Option<Vec<PathBuf>>,
30    pub(crate) calibration_file: Option<PathBuf>,
31    pub(crate) imatrix: Option<PathBuf>,
32    pub(crate) chat_template: Option<String>,
33    pub(crate) jinja_explicit: Option<String>,
34    pub(crate) tokenizer_json: Option<String>,
35    pub(crate) device_mapping: Option<DeviceMapSetting>,
36    pub(crate) max_edge: Option<u32>,
37    pub(crate) hf_cache_path: Option<PathBuf>,
38    pub(crate) search_bert_model: Option<BertEmbeddingModel>,
39    pub(crate) search_callback: Option<Arc<SearchCallback>>,
40    pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
41    pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
42    pub(crate) device: Option<Device>,
43    pub(crate) matformer_config_path: Option<PathBuf>,
44    pub(crate) matformer_slice_name: Option<String>,
45
46    // Model running
47    pub(crate) prompt_chunksize: Option<NonZeroUsize>,
48    pub(crate) topology: Option<Topology>,
49    pub(crate) loader_type: Option<VisionLoaderType>,
50    pub(crate) dtype: ModelDType,
51    pub(crate) force_cpu: bool,
52    pub(crate) isq: Option<IsqType>,
53    pub(crate) throughput_logging: bool,
54
55    // Other things
56    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
57    pub(crate) max_num_seqs: usize,
58    pub(crate) with_logging: bool,
59}
60
61impl VisionModelBuilder {
62    /// A few defaults are applied here:
63    /// - Token source is from the cache (.cache/huggingface/token)
64    /// - Maximum number of sequences running is 32
65    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
66    /// - By default, web searching compatible with the OpenAI `web_search_options` setting is disabled.
67    pub fn new(model_id: impl ToString) -> Self {
68        Self {
69            model_id: model_id.to_string(),
70            topology: None,
71            write_uqff: None,
72            from_uqff: None,
73            prompt_chunksize: None,
74            chat_template: None,
75            tokenizer_json: None,
76            max_edge: None,
77            loader_type: None,
78            dtype: ModelDType::Auto,
79            force_cpu: false,
80            token_source: TokenSource::CacheToken,
81            hf_revision: None,
82            isq: None,
83            max_num_seqs: 32,
84            with_logging: false,
85            device_mapping: None,
86            calibration_file: None,
87            imatrix: None,
88            jinja_explicit: None,
89            throughput_logging: false,
90            paged_attn_cfg: None,
91            hf_cache_path: None,
92            search_bert_model: None,
93            search_callback: None,
94            tool_callbacks: HashMap::new(),
95            tool_callbacks_with_tools: HashMap::new(),
96            device: None,
97            matformer_config_path: None,
98            matformer_slice_name: None,
99        }
100    }
101
102    /// Enable searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified or the default.
103    pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
104        self.search_bert_model = Some(search_bert_model);
105        self
106    }
107
108    /// Override the search function used when `web_search_options` is enabled.
109    pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
110        self.search_callback = Some(callback);
111        self
112    }
113
114    pub fn with_tool_callback(
115        mut self,
116        name: impl Into<String>,
117        callback: Arc<ToolCallback>,
118    ) -> Self {
119        self.tool_callbacks.insert(name.into(), callback);
120        self
121    }
122
123    /// Register a callback with an associated Tool definition that will be automatically
124    /// added to requests when tool callbacks are active.
125    pub fn with_tool_callback_and_tool(
126        mut self,
127        name: impl Into<String>,
128        callback: Arc<ToolCallback>,
129        tool: Tool,
130    ) -> Self {
131        let name = name.into();
132        self.tool_callbacks_with_tools
133            .insert(name, ToolCallbackWithTool { callback, tool });
134        self
135    }
136
137    /// Enable runner throughput logging.
138    pub fn with_throughput_logging(mut self) -> Self {
139        self.throughput_logging = true;
140        self
141    }
142
143    /// Explicit JINJA chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
144    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
145        self.jinja_explicit = Some(jinja_explicit);
146        self
147    }
148
149    /// Set the prompt batchsize to use for inference.
150    pub fn with_prompt_chunksize(mut self, prompt_chunksize: NonZeroUsize) -> Self {
151        self.prompt_chunksize = Some(prompt_chunksize);
152        self
153    }
154
155    /// Set the model topology for use during loading. If there is an overlap, the topology type is used over the ISQ type.
156    pub fn with_topology(mut self, topology: Topology) -> Self {
157        self.topology = Some(topology);
158        self
159    }
160
161    /// Literal Jinja chat template OR Path (ending in `.json`) to one.
162    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
163        self.chat_template = Some(chat_template.to_string());
164        self
165    }
166
167    /// Path to a discrete `tokenizer.json` file.
168    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
169        self.tokenizer_json = Some(tokenizer_json.to_string());
170        self
171    }
172
173    /// Manually set the model loader type. Otherwise, it will attempt to automatically
174    /// determine the loader type.
175    pub fn with_loader_type(mut self, loader_type: VisionLoaderType) -> Self {
176        self.loader_type = Some(loader_type);
177        self
178    }
179
180    /// Load the model in a certain dtype.
181    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
182        self.dtype = dtype;
183        self
184    }
185
186    /// Force usage of the CPU device. Do not use PagedAttention with this.
187    pub fn with_force_cpu(mut self) -> Self {
188        self.force_cpu = true;
189        self
190    }
191
192    /// Source of the Hugging Face token.
193    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
194        self.token_source = token_source;
195        self
196    }
197
198    /// Set the revision to use for a Hugging Face remote model.
199    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
200        self.hf_revision = Some(revision.to_string());
201        self
202    }
203
204    /// Use ISQ of a certain type. If there is an overlap, the topology type is used over the ISQ type.
205    pub fn with_isq(mut self, isq: IsqType) -> Self {
206        self.isq = Some(isq);
207        self
208    }
209
210    /// Utilise this calibration_file file during ISQ
211    pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
212        self.calibration_file = Some(path);
213        self
214    }
215
216    /// Enable PagedAttention. Configure PagedAttention with a [`PagedAttentionConfig`] object, which
217    /// can be created with sensible values with a [`PagedAttentionMetaBuilder`](crate::PagedAttentionMetaBuilder).
218    ///
219    /// If PagedAttention is not supported (query with [`paged_attn_supported`]), this will do nothing.
220    pub fn with_paged_attn(
221        mut self,
222        paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
223    ) -> anyhow::Result<Self> {
224        if paged_attn_supported() {
225            self.paged_attn_cfg = Some(paged_attn_cfg()?);
226        } else {
227            self.paged_attn_cfg = None;
228        }
229        Ok(self)
230    }
231
232    /// Set the maximum number of sequences which can be run at once.
233    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
234        self.max_num_seqs = max_num_seqs;
235        self
236    }
237
238    /// Enable logging.
239    pub fn with_logging(mut self) -> Self {
240        self.with_logging = true;
241        self
242    }
243
244    /// Provide metadata to initialize the device mapper.
245    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
246        self.device_mapping = Some(device_mapping);
247        self
248    }
249
250    /// Path to read a UQFF file from.
251    pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
252        self.from_uqff = Some(path);
253        self
254    }
255
256    /// Automatically resize and pad images to this maximum edge length. Aspect ratio is preserved.
257    /// This is only supported on the Qwen2-VL and Idefics 2 models. Others handle this internally.
258    pub fn from_max_edge(mut self, max_edge: u32) -> Self {
259        self.max_edge = Some(max_edge);
260        self
261    }
262
263    /// Path to write a UQFF file to.
264    ///
265    /// The parent (part of the path excluding the filename) will determine where any other files
266    /// generated are written to. These can be used to load UQFF models standalone, and may include:
267    /// - `residual.safetensors`
268    /// - `tokenizer.json`
269    /// - `config.json`
270    /// - And others
271    pub fn write_uqff(mut self, path: PathBuf) -> Self {
272        self.write_uqff = Some(path);
273        self
274    }
275
276    /// Cache path for Hugging Face models downloaded locally
277    pub fn from_hf_cache_pathf(mut self, hf_cache_path: PathBuf) -> Self {
278        self.hf_cache_path = Some(hf_cache_path);
279        self
280    }
281
282    /// Set the main device to load this model onto. Automatic device mapping will be performed starting with this device.
283    pub fn with_device(mut self, device: Device) -> Self {
284        self.device = Some(device);
285        self
286    }
287
288    /// Path to a Matryoshka Transformer configuration CSV file.
289    pub fn with_matformer_config_path(mut self, path: PathBuf) -> Self {
290        self.matformer_config_path = Some(path);
291        self
292    }
293
294    /// Name of the slice to use from the Matryoshka Transformer configuration.
295    pub fn with_matformer_slice_name(mut self, name: String) -> Self {
296        self.matformer_slice_name = Some(name);
297        self
298    }
299
300    pub async fn build(self) -> anyhow::Result<Model> {
301        let config = VisionSpecificConfig {
302            prompt_chunksize: self.prompt_chunksize,
303            topology: self.topology,
304            write_uqff: self.write_uqff,
305            from_uqff: self.from_uqff,
306            max_edge: self.max_edge,
307            calibration_file: self.calibration_file,
308            imatrix: self.imatrix,
309            hf_cache_path: self.hf_cache_path,
310            matformer_config_path: self.matformer_config_path,
311            matformer_slice_name: self.matformer_slice_name,
312        };
313
314        if self.with_logging {
315            initialize_logging();
316        }
317
318        let loader = VisionLoaderBuilder::new(
319            config,
320            self.chat_template,
321            self.tokenizer_json,
322            Some(self.model_id),
323            self.jinja_explicit,
324        )
325        .build(self.loader_type);
326
327        // Load, into a Pipeline
328        let pipeline = loader.load_model_from_hf(
329            self.hf_revision,
330            self.token_source,
331            &self.dtype,
332            &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
333            !self.with_logging,
334            self.device_mapping
335                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_vision())),
336            self.isq,
337            self.paged_attn_cfg,
338        )?;
339
340        let scheduler_method = match self.paged_attn_cfg {
341            Some(_) => {
342                let config = pipeline
343                    .lock()
344                    .await
345                    .get_metadata()
346                    .cache_config
347                    .as_ref()
348                    .cloned();
349
350                if let Some(config) = config {
351                    SchedulerConfig::PagedAttentionMeta {
352                        max_num_seqs: self.max_num_seqs,
353                        config,
354                    }
355                } else {
356                    SchedulerConfig::DefaultScheduler {
357                        method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
358                    }
359                }
360            }
361            None => SchedulerConfig::DefaultScheduler {
362                method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
363            },
364        };
365
366        let mut runner = MistralRsBuilder::new(
367            pipeline,
368            scheduler_method,
369            self.throughput_logging,
370            self.search_bert_model,
371        );
372        if let Some(cb) = self.search_callback.clone() {
373            runner = runner.with_search_callback(cb);
374        }
375        for (name, cb) in &self.tool_callbacks {
376            runner = runner.with_tool_callback(name.clone(), cb.clone());
377        }
378        for (name, callback_with_tool) in &self.tool_callbacks_with_tools {
379            runner = runner.with_tool_callback_and_tool(
380                name.clone(),
381                callback_with_tool.callback.clone(),
382                callback_with_tool.tool.clone(),
383            );
384        }
385        let runner = runner.with_no_kv_cache(false).with_no_prefix_cache(false);
386
387        Ok(Model::new(runner.build().await))
388    }
389}
390
391#[derive(Clone)]
392/// Configure a UQFF text model with the various parameters for loading, running, and other inference behaviors.
393/// This wraps and implements `DerefMut` for the VisionModelBuilder, so users should take care to not call UQFF-related methods.
394pub struct UqffVisionModelBuilder(VisionModelBuilder);
395
396impl UqffVisionModelBuilder {
397    /// A few defaults are applied here:
398    /// - Token source is from the cache (.cache/huggingface/token)
399    /// - Maximum number of sequences running is 32
400    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
401    pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
402        let mut inner = VisionModelBuilder::new(model_id);
403        inner = inner.from_uqff(uqff_file);
404        Self(inner)
405    }
406
407    pub async fn build(self) -> anyhow::Result<Model> {
408        self.0.build().await
409    }
410
411    /// This wraps the VisionModelBuilder, so users should take care to not call UQFF-related methods.
412    pub fn into_inner(self) -> VisionModelBuilder {
413        self.0
414    }
415}
416
417impl Deref for UqffVisionModelBuilder {
418    type Target = VisionModelBuilder;
419
420    fn deref(&self) -> &Self::Target {
421        &self.0
422    }
423}
424
425impl DerefMut for UqffVisionModelBuilder {
426    fn deref_mut(&mut self) -> &mut Self::Target {
427        &mut self.0
428    }
429}
430
431impl From<UqffVisionModelBuilder> for VisionModelBuilder {
432    fn from(value: UqffVisionModelBuilder) -> Self {
433        value.0
434    }
435}