1use mistralrs_core::*;
2use std::{
3 num::NonZeroUsize,
4 ops::{Deref, DerefMut},
5 path::PathBuf,
6};
7
8use crate::{best_device, Model};
9
10#[derive(Clone)]
11pub struct VisionModelBuilder {
13 pub(crate) model_id: String,
15 pub(crate) token_source: TokenSource,
16 pub(crate) hf_revision: Option<String>,
17 pub(crate) write_uqff: Option<PathBuf>,
18 pub(crate) from_uqff: Option<PathBuf>,
19 pub(crate) calibration_file: Option<PathBuf>,
20 pub(crate) imatrix: Option<PathBuf>,
21 pub(crate) chat_template: Option<String>,
22 pub(crate) jinja_explicit: Option<String>,
23 pub(crate) tokenizer_json: Option<String>,
24 pub(crate) device_mapping: Option<DeviceMapSetting>,
25 pub(crate) max_edge: Option<u32>,
26 pub(crate) hf_cache_path: Option<PathBuf>,
27 pub(crate) search_bert_model: Option<BertEmbeddingModel>,
28
29 pub(crate) use_flash_attn: bool,
31 pub(crate) prompt_chunksize: Option<NonZeroUsize>,
32 pub(crate) topology: Option<Topology>,
33 pub(crate) loader_type: VisionLoaderType,
34 pub(crate) dtype: ModelDType,
35 pub(crate) force_cpu: bool,
36 pub(crate) isq: Option<IsqType>,
37 pub(crate) throughput_logging: bool,
38
39 pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
41 pub(crate) max_num_seqs: usize,
42 pub(crate) with_logging: bool,
43}
44
45impl VisionModelBuilder {
46 pub fn new(model_id: impl ToString, loader_type: VisionLoaderType) -> Self {
52 Self {
53 model_id: model_id.to_string(),
54 use_flash_attn: cfg!(feature = "flash-attn"),
55 topology: None,
56 write_uqff: None,
57 from_uqff: None,
58 prompt_chunksize: None,
59 chat_template: None,
60 tokenizer_json: None,
61 max_edge: None,
62 loader_type,
63 dtype: ModelDType::Auto,
64 force_cpu: false,
65 token_source: TokenSource::CacheToken,
66 hf_revision: None,
67 isq: None,
68 max_num_seqs: 32,
69 with_logging: false,
70 device_mapping: None,
71 calibration_file: None,
72 imatrix: None,
73 jinja_explicit: None,
74 throughput_logging: false,
75 paged_attn_cfg: None,
76 hf_cache_path: None,
77 search_bert_model: None,
78 }
79 }
80
81 pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
83 self.search_bert_model = Some(search_bert_model);
84 self
85 }
86
87 pub fn with_throughput_logging(mut self) -> Self {
89 self.throughput_logging = true;
90 self
91 }
92
93 pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
95 self.jinja_explicit = Some(jinja_explicit);
96 self
97 }
98
99 pub fn with_prompt_chunksize(mut self, prompt_chunksize: NonZeroUsize) -> Self {
101 self.prompt_chunksize = Some(prompt_chunksize);
102 self
103 }
104
105 pub fn with_topology(mut self, topology: Topology) -> Self {
107 self.topology = Some(topology);
108 self
109 }
110
111 pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
113 self.chat_template = Some(chat_template.to_string());
114 self
115 }
116
117 pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
119 self.tokenizer_json = Some(tokenizer_json.to_string());
120 self
121 }
122
123 pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
125 self.dtype = dtype;
126 self
127 }
128
129 pub fn with_force_cpu(mut self) -> Self {
131 self.force_cpu = true;
132 self
133 }
134
135 pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
137 self.token_source = token_source;
138 self
139 }
140
141 pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
143 self.hf_revision = Some(revision.to_string());
144 self
145 }
146
147 pub fn with_isq(mut self, isq: IsqType) -> Self {
149 self.isq = Some(isq);
150 self
151 }
152
153 pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
155 self.calibration_file = Some(path);
156 self
157 }
158
159 pub fn with_paged_attn(
164 mut self,
165 paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
166 ) -> anyhow::Result<Self> {
167 if paged_attn_supported() {
168 self.paged_attn_cfg = Some(paged_attn_cfg()?);
169 } else {
170 self.paged_attn_cfg = None;
171 }
172 Ok(self)
173 }
174
175 pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
177 self.max_num_seqs = max_num_seqs;
178 self
179 }
180
181 pub fn with_logging(mut self) -> Self {
183 self.with_logging = true;
184 self
185 }
186
187 pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
189 self.device_mapping = Some(device_mapping);
190 self
191 }
192
193 pub fn from_uqff(mut self, path: PathBuf) -> Self {
195 self.from_uqff = Some(path);
196 self
197 }
198
199 pub fn from_max_edge(mut self, max_edge: u32) -> Self {
202 self.max_edge = Some(max_edge);
203 self
204 }
205
206 pub fn write_uqff(mut self, path: PathBuf) -> Self {
215 self.write_uqff = Some(path);
216 self
217 }
218
219 pub fn from_hf_cache_pathf(mut self, hf_cache_path: PathBuf) -> Self {
221 self.hf_cache_path = Some(hf_cache_path);
222 self
223 }
224
225 pub async fn build(self) -> anyhow::Result<Model> {
226 let config = VisionSpecificConfig {
227 use_flash_attn: self.use_flash_attn,
228 prompt_chunksize: self.prompt_chunksize,
229 topology: self.topology,
230 write_uqff: self.write_uqff,
231 from_uqff: self.from_uqff,
232 max_edge: self.max_edge,
233 calibration_file: self.calibration_file,
234 imatrix: self.imatrix,
235 hf_cache_path: self.hf_cache_path,
236 };
237
238 if self.with_logging {
239 initialize_logging();
240 }
241
242 let loader = VisionLoaderBuilder::new(
243 config,
244 self.chat_template,
245 self.tokenizer_json,
246 Some(self.model_id),
247 self.jinja_explicit,
248 )
249 .build(self.loader_type);
250
251 let pipeline = loader.load_model_from_hf(
253 self.hf_revision,
254 self.token_source,
255 &self.dtype,
256 &best_device(self.force_cpu)?,
257 !self.with_logging,
258 self.device_mapping
259 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_vision())),
260 self.isq,
261 self.paged_attn_cfg,
262 )?;
263
264 let scheduler_method = match self.paged_attn_cfg {
265 Some(_) => {
266 let config = pipeline
267 .lock()
268 .await
269 .get_metadata()
270 .cache_config
271 .as_ref()
272 .unwrap()
273 .clone();
274
275 SchedulerConfig::PagedAttentionMeta {
276 max_num_seqs: self.max_num_seqs,
277 config,
278 }
279 }
280 None => SchedulerConfig::DefaultScheduler {
281 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
282 },
283 };
284
285 let runner = MistralRsBuilder::new(
286 pipeline,
287 scheduler_method,
288 self.throughput_logging,
289 self.search_bert_model,
290 )
291 .with_no_kv_cache(false)
292 .with_no_prefix_cache(false);
293
294 Ok(Model::new(runner.build()))
295 }
296}
297
298#[derive(Clone)]
299pub struct UqffVisionModelBuilder(VisionModelBuilder);
302
303impl UqffVisionModelBuilder {
304 pub fn new(model_id: impl ToString, loader_type: VisionLoaderType, uqff_file: PathBuf) -> Self {
309 let mut inner = VisionModelBuilder::new(model_id, loader_type);
310 inner = inner.from_uqff(uqff_file);
311 Self(inner)
312 }
313
314 pub async fn build(self) -> anyhow::Result<Model> {
315 self.0.build().await
316 }
317
318 pub fn into_inner(self) -> VisionModelBuilder {
320 self.0
321 }
322}
323
324impl Deref for UqffVisionModelBuilder {
325 type Target = VisionModelBuilder;
326
327 fn deref(&self) -> &Self::Target {
328 &self.0
329 }
330}
331
332impl DerefMut for UqffVisionModelBuilder {
333 fn deref_mut(&mut self) -> &mut Self::Target {
334 &mut self.0
335 }
336}
337
338impl From<UqffVisionModelBuilder> for VisionModelBuilder {
339 fn from(value: UqffVisionModelBuilder) -> Self {
340 value.0
341 }
342}