// mistralrs/vision_model.rs

use candle_core::Device;
use mistralrs_core::*;
use std::{
    num::NonZeroUsize,
    ops::{Deref, DerefMut},
    path::PathBuf,
};

use crate::{best_device, Model};

#[derive(Clone)]
/// Configure a vision model with the various parameters for loading, running, and other inference behaviors.
pub struct VisionModelBuilder {
    // Loading model
    pub(crate) model_id: String,
    pub(crate) token_source: TokenSource,
    pub(crate) hf_revision: Option<String>,
    pub(crate) write_uqff: Option<PathBuf>,
    pub(crate) from_uqff: Option<Vec<PathBuf>>,
    pub(crate) calibration_file: Option<PathBuf>,
    pub(crate) imatrix: Option<PathBuf>,
    pub(crate) chat_template: Option<String>,
    pub(crate) jinja_explicit: Option<String>,
    pub(crate) tokenizer_json: Option<String>,
    pub(crate) device_mapping: Option<DeviceMapSetting>,
    pub(crate) max_edge: Option<u32>,
    pub(crate) hf_cache_path: Option<PathBuf>,
    pub(crate) search_bert_model: Option<BertEmbeddingModel>,
    pub(crate) device: Option<Device>,

    // Model running
    pub(crate) prompt_chunksize: Option<NonZeroUsize>,
    pub(crate) topology: Option<Topology>,
    pub(crate) loader_type: Option<VisionLoaderType>,
    pub(crate) dtype: ModelDType,
    pub(crate) force_cpu: bool,
    pub(crate) isq: Option<IsqType>,
    pub(crate) throughput_logging: bool,

    // Other things
    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
    pub(crate) max_num_seqs: usize,
    pub(crate) with_logging: bool,
}

impl VisionModelBuilder {
    /// A few defaults are applied here:
    /// - Token source is from the cache (.cache/huggingface/token)
    /// - Maximum number of sequences running is 32
    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
    /// - By default, web searching compatible with the OpenAI `web_search_options` setting is disabled.
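    ///
    /// A minimal usage sketch (the model ID is illustrative; any vision model on the Hugging Face Hub works):
    /// ```no_run
    /// # use mistralrs::VisionModelBuilder;
    /// # async fn demo() -> anyhow::Result<()> {
    /// let model = VisionModelBuilder::new("some-org/some-vision-model")
    ///     .with_logging()
    ///     .build()
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```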
    pub fn new(model_id: impl ToString) -> Self {
        Self {
            model_id: model_id.to_string(),
            topology: None,
            write_uqff: None,
            from_uqff: None,
            prompt_chunksize: None,
            chat_template: None,
            tokenizer_json: None,
            max_edge: None,
            loader_type: None,
            dtype: ModelDType::Auto,
            force_cpu: false,
            token_source: TokenSource::CacheToken,
            hf_revision: None,
            isq: None,
            max_num_seqs: 32,
            with_logging: false,
            device_mapping: None,
            calibration_file: None,
            imatrix: None,
            jinja_explicit: None,
            throughput_logging: false,
            paged_attn_cfg: None,
            hf_cache_path: None,
            search_bert_model: None,
            device: None,
        }
    }

    /// Enable searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified or the default.
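    ///
    /// A sketch using the default embedding model (assumes `BertEmbeddingModel` provides a `Default` impl):
    /// ```no_run
    /// # use mistralrs::{BertEmbeddingModel, VisionModelBuilder};
    /// let builder = VisionModelBuilder::new("some/model-id")
    ///     .with_search(BertEmbeddingModel::default());
    /// ```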
    pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
        self.search_bert_model = Some(search_bert_model);
        self
    }

    /// Enable runner throughput logging.
    pub fn with_throughput_logging(mut self) -> Self {
        self.throughput_logging = true;
        self
    }

    /// Explicit JINJA chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    /// Set the prompt batch size to use for inference.
    pub fn with_prompt_chunksize(mut self, prompt_chunksize: NonZeroUsize) -> Self {
        self.prompt_chunksize = Some(prompt_chunksize);
        self
    }

    /// Set the model topology for use during loading. If there is an overlap, the topology type is used over the ISQ type.
    pub fn with_topology(mut self, topology: Topology) -> Self {
        self.topology = Some(topology);
        self
    }

    /// Literal Jinja chat template OR a path (ending in `.json`) to one.
    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
        self.chat_template = Some(chat_template.to_string());
        self
    }

    /// Path to a discrete `tokenizer.json` file.
    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
        self.tokenizer_json = Some(tokenizer_json.to_string());
        self
    }

    /// Manually set the model loader type. Otherwise, it will attempt to automatically
    /// determine the loader type.
    pub fn with_loader_type(mut self, loader_type: VisionLoaderType) -> Self {
        self.loader_type = Some(loader_type);
        self
    }

    /// Load the model in a certain dtype.
    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
        self.dtype = dtype;
        self
    }

    /// Force usage of the CPU device. Do not use PagedAttention with this.
    pub fn with_force_cpu(mut self) -> Self {
        self.force_cpu = true;
        self
    }

    /// Source of the Hugging Face token.
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    /// Set the revision to use for a Hugging Face remote model.
    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
        self.hf_revision = Some(revision.to_string());
        self
    }

    /// Use ISQ of a certain type. If there is an overlap, the topology type is used over the ISQ type.
    pub fn with_isq(mut self, isq: IsqType) -> Self {
        self.isq = Some(isq);
        self
    }

    /// Use this calibration file during ISQ.
    pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
        self.calibration_file = Some(path);
        self
    }

    /// Enable PagedAttention. Configure PagedAttention with a [`PagedAttentionConfig`] object, which
    /// can be created with sensible values with a [`PagedAttentionMetaBuilder`].
    ///
    /// If PagedAttention is not supported (query with [`paged_attn_supported`]), this will do nothing.
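    ///
    /// A minimal sketch using the builder's defaults:
    /// ```no_run
    /// # use mistralrs::{PagedAttentionMetaBuilder, VisionModelBuilder};
    /// # fn demo() -> anyhow::Result<()> {
    /// let builder = VisionModelBuilder::new("some/model-id")
    ///     .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?;
    /// # Ok(())
    /// # }
    /// ```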
    pub fn with_paged_attn(
        mut self,
        paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
    ) -> anyhow::Result<Self> {
        if paged_attn_supported() {
            self.paged_attn_cfg = Some(paged_attn_cfg()?);
        } else {
            self.paged_attn_cfg = None;
        }
        Ok(self)
    }

    /// Set the maximum number of sequences which can be run at once.
    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
        self.max_num_seqs = max_num_seqs;
        self
    }

    /// Enable logging.
    pub fn with_logging(mut self) -> Self {
        self.with_logging = true;
        self
    }

    /// Provide metadata to initialize the device mapper.
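    ///
    /// For example, automatic mapping with the vision defaults (this mirrors what `build`
    /// falls back to when no mapping is provided):
    /// ```no_run
    /// # use mistralrs::{AutoDeviceMapParams, DeviceMapSetting, VisionModelBuilder};
    /// let builder = VisionModelBuilder::new("some/model-id")
    ///     .with_device_mapping(DeviceMapSetting::Auto(AutoDeviceMapParams::default_vision()));
    /// ```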
    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
        self.device_mapping = Some(device_mapping);
        self
    }

    /// Paths to read UQFF files from.
    pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
        self.from_uqff = Some(path);
        self
    }

    /// Automatically resize and pad images to this maximum edge length. Aspect ratio is preserved.
    /// This is only supported on the Qwen2-VL and Idefics 2 models. Others handle this internally.
    pub fn from_max_edge(mut self, max_edge: u32) -> Self {
        self.max_edge = Some(max_edge);
        self
    }

    /// Path to write a UQFF file to.
    ///
    /// The parent (part of the path excluding the filename) will determine where any other files
    /// generated are written to. These can be used to load UQFF models standalone, and may include:
    /// - `residual.safetensors`
    /// - `tokenizer.json`
    /// - `config.json`
    /// - And others
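    ///
    /// A sketch of quantizing with ISQ and serializing the result to UQFF (the model ID and
    /// output path are illustrative):
    /// ```no_run
    /// # use std::path::PathBuf;
    /// # use mistralrs::{IsqType, VisionModelBuilder};
    /// # async fn demo() -> anyhow::Result<()> {
    /// let model = VisionModelBuilder::new("some/model-id")
    ///     .with_isq(IsqType::Q4K)
    ///     .write_uqff(PathBuf::from("out/model.uqff"))
    ///     .build()
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```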
    pub fn write_uqff(mut self, path: PathBuf) -> Self {
        self.write_uqff = Some(path);
        self
    }

    /// Cache path for Hugging Face models downloaded locally.
    pub fn from_hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    /// Set the main device to load this model onto. Automatic device mapping will be performed starting with this device.
    pub fn with_device(mut self, device: Device) -> Self {
        self.device = Some(device);
        self
    }

    /// Load the model and construct a ready-to-use [`Model`].
    pub async fn build(self) -> anyhow::Result<Model> {
        let config = VisionSpecificConfig {
            prompt_chunksize: self.prompt_chunksize,
            topology: self.topology,
            write_uqff: self.write_uqff,
            from_uqff: self.from_uqff,
            max_edge: self.max_edge,
            calibration_file: self.calibration_file,
            imatrix: self.imatrix,
            hf_cache_path: self.hf_cache_path,
        };

        if self.with_logging {
            initialize_logging();
        }

        let loader = VisionLoaderBuilder::new(
            config,
            self.chat_template,
            self.tokenizer_json,
            Some(self.model_id),
            self.jinja_explicit,
        )
        .build(self.loader_type);

        // Load the model into a Pipeline
        let pipeline = loader.load_model_from_hf(
            self.hf_revision,
            self.token_source,
            &self.dtype,
            &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
            !self.with_logging,
            self.device_mapping
                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_vision())),
            self.isq,
            self.paged_attn_cfg,
        )?;

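        // Use PagedAttention scheduling only if the pipeline actually produced a KV-cache
        // config; otherwise fall back to the default fixed-capacity scheduler.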
        let scheduler_method = match self.paged_attn_cfg {
            Some(_) => {
                let config = pipeline
                    .lock()
                    .await
                    .get_metadata()
                    .cache_config
                    .as_ref()
                    .cloned();

                if let Some(config) = config {
                    SchedulerConfig::PagedAttentionMeta {
                        max_num_seqs: self.max_num_seqs,
                        config,
                    }
                } else {
                    SchedulerConfig::DefaultScheduler {
                        method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
                    }
                }
            }
            None => SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
            },
        };

        let runner = MistralRsBuilder::new(
            pipeline,
            scheduler_method,
            self.throughput_logging,
            self.search_bert_model,
        )
        .with_no_kv_cache(false)
        .with_no_prefix_cache(false);

        Ok(Model::new(runner.build()))
    }
}

#[derive(Clone)]
/// Configure a UQFF vision model with the various parameters for loading, running, and other inference behaviors.
/// This wraps the [`VisionModelBuilder`] and implements `Deref`/`DerefMut` for it, so take care not to call UQFF-related methods on the inner builder.
pub struct UqffVisionModelBuilder(VisionModelBuilder);

impl UqffVisionModelBuilder {
    /// A few defaults are applied here:
    /// - Token source is from the cache (.cache/huggingface/token)
    /// - Maximum number of sequences running is 32
    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
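    ///
    /// A minimal usage sketch (the model ID and UQFF filename are hypothetical):
    /// ```no_run
    /// # use std::path::PathBuf;
    /// # use mistralrs::UqffVisionModelBuilder;
    /// # async fn demo() -> anyhow::Result<()> {
    /// let model = UqffVisionModelBuilder::new(
    ///     "some-org/some-vision-model-UQFF",
    ///     vec![PathBuf::from("model-q4k-0.uqff")],
    /// )
    /// .build()
    /// .await?;
    /// # Ok(())
    /// # }
    /// ```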
    pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
        let mut inner = VisionModelBuilder::new(model_id);
        inner = inner.from_uqff(uqff_file);
        Self(inner)
    }

    pub async fn build(self) -> anyhow::Result<Model> {
        self.0.build().await
    }

    /// Consume this wrapper and return the inner [`VisionModelBuilder`]. Take care not to call UQFF-related methods on it.
    pub fn into_inner(self) -> VisionModelBuilder {
        self.0
    }
}

impl Deref for UqffVisionModelBuilder {
    type Target = VisionModelBuilder;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for UqffVisionModelBuilder {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl From<UqffVisionModelBuilder> for VisionModelBuilder {
    fn from(value: UqffVisionModelBuilder) -> Self {
        value.0
    }
}