mistralrs/vision_model.rs

use mistralrs_core::*;
use std::{
    num::NonZeroUsize,
    ops::{Deref, DerefMut},
    path::PathBuf,
};

use crate::{best_device, Model};

/// Configure a vision model with the various parameters for loading, running, and other inference behaviors.
#[derive(Clone)]
pub struct VisionModelBuilder {
    // Model loading
    pub(crate) model_id: String,
    pub(crate) token_source: TokenSource,
    pub(crate) hf_revision: Option<String>,
    pub(crate) write_uqff: Option<PathBuf>,
    pub(crate) from_uqff: Option<PathBuf>,
    pub(crate) calibration_file: Option<PathBuf>,
    pub(crate) imatrix: Option<PathBuf>,
    pub(crate) chat_template: Option<String>,
    pub(crate) jinja_explicit: Option<String>,
    pub(crate) tokenizer_json: Option<String>,
    pub(crate) device_mapping: Option<DeviceMapSetting>,
    pub(crate) max_edge: Option<u32>,
    pub(crate) hf_cache_path: Option<PathBuf>,
    pub(crate) search_bert_model: Option<BertEmbeddingModel>,

    // Model running
    pub(crate) use_flash_attn: bool,
    pub(crate) prompt_chunksize: Option<NonZeroUsize>,
    pub(crate) topology: Option<Topology>,
    pub(crate) loader_type: VisionLoaderType,
    pub(crate) dtype: ModelDType,
    pub(crate) force_cpu: bool,
    pub(crate) isq: Option<IsqType>,
    pub(crate) throughput_logging: bool,

    // Other settings
    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
    pub(crate) max_num_seqs: usize,
    pub(crate) with_logging: bool,
}

impl VisionModelBuilder {
    /// A few defaults are applied here:
    /// - Token source is from the cache (.cache/huggingface/token)
    /// - Maximum number of sequences running is 32
    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
    /// - By default, web searching compatible with the OpenAI `web_search_options` setting is disabled.
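    ///
    /// A minimal usage sketch; the model id and loader type below are illustrative
    /// assumptions, so substitute your own:
    ///
    /// ```no_run
    /// use mistralrs::{VisionLoaderType, VisionModelBuilder};
    ///
    /// # async fn run() -> anyhow::Result<()> {
    /// let model = VisionModelBuilder::new(
    ///     "microsoft/Phi-3.5-vision-instruct",
    ///     VisionLoaderType::Phi3V,
    /// )
    /// .with_logging()
    /// .build()
    /// .await?;
    /// # Ok(())
    /// # }
    /// ```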
    pub fn new(model_id: impl ToString, loader_type: VisionLoaderType) -> Self {
        Self {
            model_id: model_id.to_string(),
            use_flash_attn: cfg!(feature = "flash-attn"),
            topology: None,
            write_uqff: None,
            from_uqff: None,
            prompt_chunksize: None,
            chat_template: None,
            tokenizer_json: None,
            max_edge: None,
            loader_type,
            dtype: ModelDType::Auto,
            force_cpu: false,
            token_source: TokenSource::CacheToken,
            hf_revision: None,
            isq: None,
            max_num_seqs: 32,
            with_logging: false,
            device_mapping: None,
            calibration_file: None,
            imatrix: None,
            jinja_explicit: None,
            throughput_logging: false,
            paged_attn_cfg: None,
            hf_cache_path: None,
            search_bert_model: None,
        }
    }

    /// Enable web searching compatible with the OpenAI `web_search_options` setting. This uses the specified BERT embedding model.
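    ///
    /// For the default model, pass `BertEmbeddingModel::default()` (assuming the type's
    /// `Default` impl, which selects the crate's default embedding model).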
    pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
        self.search_bert_model = Some(search_bert_model);
        self
    }

    /// Enable runner throughput logging.
    pub fn with_throughput_logging(mut self) -> Self {
        self.throughput_logging = true;
        self
    }

    /// Explicit Jinja chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    /// Set the prompt batch size to use for inference.
    pub fn with_prompt_chunksize(mut self, prompt_chunksize: NonZeroUsize) -> Self {
        self.prompt_chunksize = Some(prompt_chunksize);
        self
    }

    /// Set the model topology for use during loading. If there is an overlap, the topology type is used over the ISQ type.
    pub fn with_topology(mut self, topology: Topology) -> Self {
        self.topology = Some(topology);
        self
    }

    /// Literal Jinja chat template OR path (ending in `.json`) to one.
    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
        self.chat_template = Some(chat_template.to_string());
        self
    }

    /// Path to a discrete `tokenizer.json` file.
    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
        self.tokenizer_json = Some(tokenizer_json.to_string());
        self
    }

    /// Load the model in a certain dtype.
    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
        self.dtype = dtype;
        self
    }

    /// Force usage of the CPU device. Do not use PagedAttention with this.
    pub fn with_force_cpu(mut self) -> Self {
        self.force_cpu = true;
        self
    }

    /// Source of the Hugging Face token.
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    /// Set the revision to use for a Hugging Face remote model.
    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
        self.hf_revision = Some(revision.to_string());
        self
    }

    /// Use ISQ of a certain type. If there is an overlap, the topology type is used over the ISQ type.
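    ///
    /// For example, `.with_isq(IsqType::Q4K)` requests 4-bit K-quant in-situ quantization
    /// at load time (the variant shown is one assumed option among the `IsqType` variants;
    /// pick whichever suits your hardware).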
    pub fn with_isq(mut self, isq: IsqType) -> Self {
        self.isq = Some(isq);
        self
    }

    /// Use this calibration file during ISQ.
    pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
        self.calibration_file = Some(path);
        self
    }

    /// Enable PagedAttention. Configure PagedAttention with a [`PagedAttentionConfig`] object, which
    /// can be created with sensible values with a [`PagedAttentionMetaBuilder`].
    ///
    /// If PagedAttention is not supported (query with [`paged_attn_supported`]), this will do nothing.
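    ///
    /// A sketch of typical usage, assuming [`PagedAttentionMetaBuilder`]'s defaults are acceptable:
    ///
    /// ```no_run
    /// # use mistralrs::{PagedAttentionMetaBuilder, VisionModelBuilder};
    /// # fn demo(builder: VisionModelBuilder) -> anyhow::Result<VisionModelBuilder> {
    /// let builder = builder.with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?;
    /// # Ok(builder)
    /// # }
    /// ```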
    pub fn with_paged_attn(
        mut self,
        paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
    ) -> anyhow::Result<Self> {
        if paged_attn_supported() {
            self.paged_attn_cfg = Some(paged_attn_cfg()?);
        } else {
            self.paged_attn_cfg = None;
        }
        Ok(self)
    }

    /// Set the maximum number of sequences which can be run at once.
    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
        self.max_num_seqs = max_num_seqs;
        self
    }

    /// Enable logging.
    pub fn with_logging(mut self) -> Self {
        self.with_logging = true;
        self
    }

    /// Provide metadata to initialize the device mapper.
    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
        self.device_mapping = Some(device_mapping);
        self
    }

    /// Path to read a UQFF file from.
    pub fn from_uqff(mut self, path: PathBuf) -> Self {
        self.from_uqff = Some(path);
        self
    }

    /// Automatically resize and pad images to this maximum edge length. Aspect ratio is preserved.
    /// This is only supported on the Qwen2-VL and Idefics 2 models. Others handle this internally.
    pub fn from_max_edge(mut self, max_edge: u32) -> Self {
        self.max_edge = Some(max_edge);
        self
    }

    /// Path to write a UQFF file to.
    ///
    /// The parent (part of the path excluding the filename) will determine where any other files
    /// generated are written to. These can be used to load UQFF models standalone, and may include:
    /// - `residual.safetensors`
    /// - `tokenizer.json`
    /// - `config.json`
    /// - And others
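    ///
    /// For example, `.write_uqff(PathBuf::from("uqff/model.uqff"))` writes the quantized
    /// weights, with the supporting files landing next to them in `uqff/` (the path shown
    /// is illustrative).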
    pub fn write_uqff(mut self, path: PathBuf) -> Self {
        self.write_uqff = Some(path);
        self
    }

    /// Cache path for Hugging Face models downloaded locally.
    pub fn from_hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

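    /// Load the model according to this configuration and construct a ready-to-use [`Model`].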
    pub async fn build(self) -> anyhow::Result<Model> {
        let config = VisionSpecificConfig {
            use_flash_attn: self.use_flash_attn,
            prompt_chunksize: self.prompt_chunksize,
            topology: self.topology,
            write_uqff: self.write_uqff,
            from_uqff: self.from_uqff,
            max_edge: self.max_edge,
            calibration_file: self.calibration_file,
            imatrix: self.imatrix,
            hf_cache_path: self.hf_cache_path,
        };

        if self.with_logging {
            initialize_logging();
        }

        let loader = VisionLoaderBuilder::new(
            config,
            self.chat_template,
            self.tokenizer_json,
            Some(self.model_id),
            self.jinja_explicit,
        )
        .build(self.loader_type);

        // Load the model into a pipeline.
        let pipeline = loader.load_model_from_hf(
            self.hf_revision,
            self.token_source,
            &self.dtype,
            &best_device(self.force_cpu)?,
            !self.with_logging,
            self.device_mapping
                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_vision())),
            self.isq,
            self.paged_attn_cfg,
        )?;

        // Choose the scheduler: PagedAttention metadata when it was configured above,
        // otherwise the default fixed-capacity scheduler.
        let scheduler_method = match self.paged_attn_cfg {
            Some(_) => {
                // The cache config is guaranteed to exist when PagedAttention was enabled.
                let config = pipeline
                    .lock()
                    .await
                    .get_metadata()
                    .cache_config
                    .as_ref()
                    .unwrap()
                    .clone();

                SchedulerConfig::PagedAttentionMeta {
                    max_num_seqs: self.max_num_seqs,
                    config,
                }
            }
            None => SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
            },
        };

        let runner = MistralRsBuilder::new(
            pipeline,
            scheduler_method,
            self.throughput_logging,
            self.search_bert_model,
        )
        .with_no_kv_cache(false)
        .with_no_prefix_cache(false);

        Ok(Model::new(runner.build()))
    }
}

/// Configure a UQFF vision model with the various parameters for loading, running, and other inference behaviors.
/// This wraps and implements `DerefMut` for the [`VisionModelBuilder`], so users should take care to not call UQFF-related methods.
#[derive(Clone)]
pub struct UqffVisionModelBuilder(VisionModelBuilder);

impl UqffVisionModelBuilder {
    /// A few defaults are applied here:
    /// - Token source is from the cache (.cache/huggingface/token)
    /// - Maximum number of sequences running is 32
    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
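    ///
    /// A minimal sketch; the model id, loader type, and UQFF filename below are
    /// illustrative assumptions:
    ///
    /// ```no_run
    /// use mistralrs::{UqffVisionModelBuilder, VisionLoaderType};
    /// use std::path::PathBuf;
    ///
    /// # async fn run() -> anyhow::Result<()> {
    /// let model = UqffVisionModelBuilder::new(
    ///     "EricB/Phi-3.5-vision-instruct-UQFF",
    ///     VisionLoaderType::Phi3V,
    ///     PathBuf::from("phi3.5-vision-instruct-q4k.uqff"),
    /// )
    /// .build()
    /// .await?;
    /// # Ok(())
    /// # }
    /// ```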
    pub fn new(model_id: impl ToString, loader_type: VisionLoaderType, uqff_file: PathBuf) -> Self {
        let mut inner = VisionModelBuilder::new(model_id, loader_type);
        inner = inner.from_uqff(uqff_file);
        Self(inner)
    }

    pub async fn build(self) -> anyhow::Result<Model> {
        self.0.build().await
    }

    /// Consume this wrapper and return the inner [`VisionModelBuilder`]. Users should take
    /// care to not call UQFF-related methods on the returned builder.
    pub fn into_inner(self) -> VisionModelBuilder {
        self.0
    }
}

impl Deref for UqffVisionModelBuilder {
    type Target = VisionModelBuilder;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for UqffVisionModelBuilder {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl From<UqffVisionModelBuilder> for VisionModelBuilder {
    fn from(value: UqffVisionModelBuilder) -> Self {
        value.0
    }
}