1use candle_core::Device;
2use mistralrs_core::*;
3use std::{
4 num::NonZeroUsize,
5 ops::{Deref, DerefMut},
6 path::PathBuf,
7};
8
9use crate::{best_device, Model};
10
11#[derive(Clone)]
12pub struct VisionModelBuilder {
14 pub(crate) model_id: String,
16 pub(crate) token_source: TokenSource,
17 pub(crate) hf_revision: Option<String>,
18 pub(crate) write_uqff: Option<PathBuf>,
19 pub(crate) from_uqff: Option<Vec<PathBuf>>,
20 pub(crate) calibration_file: Option<PathBuf>,
21 pub(crate) imatrix: Option<PathBuf>,
22 pub(crate) chat_template: Option<String>,
23 pub(crate) jinja_explicit: Option<String>,
24 pub(crate) tokenizer_json: Option<String>,
25 pub(crate) device_mapping: Option<DeviceMapSetting>,
26 pub(crate) max_edge: Option<u32>,
27 pub(crate) hf_cache_path: Option<PathBuf>,
28 pub(crate) search_bert_model: Option<BertEmbeddingModel>,
29 pub(crate) device: Option<Device>,
30
31 pub(crate) prompt_chunksize: Option<NonZeroUsize>,
33 pub(crate) topology: Option<Topology>,
34 pub(crate) loader_type: Option<VisionLoaderType>,
35 pub(crate) dtype: ModelDType,
36 pub(crate) force_cpu: bool,
37 pub(crate) isq: Option<IsqType>,
38 pub(crate) throughput_logging: bool,
39
40 pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
42 pub(crate) max_num_seqs: usize,
43 pub(crate) with_logging: bool,
44}
45
46impl VisionModelBuilder {
47 pub fn new(model_id: impl ToString) -> Self {
53 Self {
54 model_id: model_id.to_string(),
55 topology: None,
56 write_uqff: None,
57 from_uqff: None,
58 prompt_chunksize: None,
59 chat_template: None,
60 tokenizer_json: None,
61 max_edge: None,
62 loader_type: None,
63 dtype: ModelDType::Auto,
64 force_cpu: false,
65 token_source: TokenSource::CacheToken,
66 hf_revision: None,
67 isq: None,
68 max_num_seqs: 32,
69 with_logging: false,
70 device_mapping: None,
71 calibration_file: None,
72 imatrix: None,
73 jinja_explicit: None,
74 throughput_logging: false,
75 paged_attn_cfg: None,
76 hf_cache_path: None,
77 search_bert_model: None,
78 device: None,
79 }
80 }
81
82 pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
84 self.search_bert_model = Some(search_bert_model);
85 self
86 }
87
88 pub fn with_throughput_logging(mut self) -> Self {
90 self.throughput_logging = true;
91 self
92 }
93
94 pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
96 self.jinja_explicit = Some(jinja_explicit);
97 self
98 }
99
100 pub fn with_prompt_chunksize(mut self, prompt_chunksize: NonZeroUsize) -> Self {
102 self.prompt_chunksize = Some(prompt_chunksize);
103 self
104 }
105
106 pub fn with_topology(mut self, topology: Topology) -> Self {
108 self.topology = Some(topology);
109 self
110 }
111
112 pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
114 self.chat_template = Some(chat_template.to_string());
115 self
116 }
117
118 pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
120 self.tokenizer_json = Some(tokenizer_json.to_string());
121 self
122 }
123
124 pub fn with_loader_type(mut self, loader_type: VisionLoaderType) -> Self {
127 self.loader_type = Some(loader_type);
128 self
129 }
130
131 pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
133 self.dtype = dtype;
134 self
135 }
136
137 pub fn with_force_cpu(mut self) -> Self {
139 self.force_cpu = true;
140 self
141 }
142
143 pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
145 self.token_source = token_source;
146 self
147 }
148
149 pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
151 self.hf_revision = Some(revision.to_string());
152 self
153 }
154
155 pub fn with_isq(mut self, isq: IsqType) -> Self {
157 self.isq = Some(isq);
158 self
159 }
160
161 pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
163 self.calibration_file = Some(path);
164 self
165 }
166
167 pub fn with_paged_attn(
172 mut self,
173 paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
174 ) -> anyhow::Result<Self> {
175 if paged_attn_supported() {
176 self.paged_attn_cfg = Some(paged_attn_cfg()?);
177 } else {
178 self.paged_attn_cfg = None;
179 }
180 Ok(self)
181 }
182
183 pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
185 self.max_num_seqs = max_num_seqs;
186 self
187 }
188
189 pub fn with_logging(mut self) -> Self {
191 self.with_logging = true;
192 self
193 }
194
195 pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
197 self.device_mapping = Some(device_mapping);
198 self
199 }
200
201 pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
203 self.from_uqff = Some(path);
204 self
205 }
206
207 pub fn from_max_edge(mut self, max_edge: u32) -> Self {
210 self.max_edge = Some(max_edge);
211 self
212 }
213
214 pub fn write_uqff(mut self, path: PathBuf) -> Self {
223 self.write_uqff = Some(path);
224 self
225 }
226
227 pub fn from_hf_cache_pathf(mut self, hf_cache_path: PathBuf) -> Self {
229 self.hf_cache_path = Some(hf_cache_path);
230 self
231 }
232
233 pub fn with_device(mut self, device: Device) -> Self {
235 self.device = Some(device);
236 self
237 }
238
239 pub async fn build(self) -> anyhow::Result<Model> {
240 let config = VisionSpecificConfig {
241 prompt_chunksize: self.prompt_chunksize,
242 topology: self.topology,
243 write_uqff: self.write_uqff,
244 from_uqff: self.from_uqff,
245 max_edge: self.max_edge,
246 calibration_file: self.calibration_file,
247 imatrix: self.imatrix,
248 hf_cache_path: self.hf_cache_path,
249 };
250
251 if self.with_logging {
252 initialize_logging();
253 }
254
255 let loader = VisionLoaderBuilder::new(
256 config,
257 self.chat_template,
258 self.tokenizer_json,
259 Some(self.model_id),
260 self.jinja_explicit,
261 )
262 .build(self.loader_type);
263
264 let pipeline = loader.load_model_from_hf(
266 self.hf_revision,
267 self.token_source,
268 &self.dtype,
269 &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
270 !self.with_logging,
271 self.device_mapping
272 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_vision())),
273 self.isq,
274 self.paged_attn_cfg,
275 )?;
276
277 let scheduler_method = match self.paged_attn_cfg {
278 Some(_) => {
279 let config = pipeline
280 .lock()
281 .await
282 .get_metadata()
283 .cache_config
284 .as_ref()
285 .cloned();
286
287 if let Some(config) = config {
288 SchedulerConfig::PagedAttentionMeta {
289 max_num_seqs: self.max_num_seqs,
290 config,
291 }
292 } else {
293 SchedulerConfig::DefaultScheduler {
294 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
295 }
296 }
297 }
298 None => SchedulerConfig::DefaultScheduler {
299 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
300 },
301 };
302
303 let runner = MistralRsBuilder::new(
304 pipeline,
305 scheduler_method,
306 self.throughput_logging,
307 self.search_bert_model,
308 )
309 .with_no_kv_cache(false)
310 .with_no_prefix_cache(false);
311
312 Ok(Model::new(runner.build()))
313 }
314}
315
316#[derive(Clone)]
317pub struct UqffVisionModelBuilder(VisionModelBuilder);
320
321impl UqffVisionModelBuilder {
322 pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
327 let mut inner = VisionModelBuilder::new(model_id);
328 inner = inner.from_uqff(uqff_file);
329 Self(inner)
330 }
331
332 pub async fn build(self) -> anyhow::Result<Model> {
333 self.0.build().await
334 }
335
336 pub fn into_inner(self) -> VisionModelBuilder {
338 self.0
339 }
340}
341
342impl Deref for UqffVisionModelBuilder {
343 type Target = VisionModelBuilder;
344
345 fn deref(&self) -> &Self::Target {
346 &self.0
347 }
348}
349
350impl DerefMut for UqffVisionModelBuilder {
351 fn deref_mut(&mut self) -> &mut Self::Target {
352 &mut self.0
353 }
354}
355
356impl From<UqffVisionModelBuilder> for VisionModelBuilder {
357 fn from(value: UqffVisionModelBuilder) -> Self {
358 value.0
359 }
360}