1use candle_core::Device;
2use mistralrs_core::*;
3use mistralrs_core::{SearchCallback, Tool, ToolCallback};
4use std::collections::HashMap;
5use std::{
6 ops::{Deref, DerefMut},
7 path::PathBuf,
8 sync::Arc,
9};
10
11use crate::{best_device, Model};
12
13#[derive(Clone)]
15pub struct ToolCallbackWithTool {
16 pub callback: Arc<ToolCallback>,
17 pub tool: Tool,
18}
19
20#[derive(Clone)]
21pub struct VisionModelBuilder {
23 pub(crate) model_id: String,
25 pub(crate) token_source: TokenSource,
26 pub(crate) hf_revision: Option<String>,
27 pub(crate) write_uqff: Option<PathBuf>,
28 pub(crate) from_uqff: Option<Vec<PathBuf>>,
29 pub(crate) calibration_file: Option<PathBuf>,
30 pub(crate) imatrix: Option<PathBuf>,
31 pub(crate) chat_template: Option<String>,
32 pub(crate) jinja_explicit: Option<String>,
33 pub(crate) tokenizer_json: Option<String>,
34 pub(crate) device_mapping: Option<DeviceMapSetting>,
35 pub(crate) max_edge: Option<u32>,
36 pub(crate) hf_cache_path: Option<PathBuf>,
37 pub(crate) search_bert_model: Option<BertEmbeddingModel>,
38 pub(crate) search_callback: Option<Arc<SearchCallback>>,
39 pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
40 pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
41 pub(crate) device: Option<Device>,
42 pub(crate) matformer_config_path: Option<PathBuf>,
43 pub(crate) matformer_slice_name: Option<String>,
44
45 pub(crate) topology: Option<Topology>,
47 pub(crate) loader_type: Option<VisionLoaderType>,
48 pub(crate) dtype: ModelDType,
49 pub(crate) force_cpu: bool,
50 pub(crate) isq: Option<IsqType>,
51 pub(crate) throughput_logging: bool,
52
53 pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
55 pub(crate) max_num_seqs: usize,
56 pub(crate) with_logging: bool,
57 pub(crate) prefix_cache_n: Option<usize>,
58}
59
60impl VisionModelBuilder {
61 pub fn new(model_id: impl ToString) -> Self {
67 Self {
68 model_id: model_id.to_string(),
69 topology: None,
70 write_uqff: None,
71 from_uqff: None,
72 chat_template: None,
73 tokenizer_json: None,
74 max_edge: None,
75 loader_type: None,
76 dtype: ModelDType::Auto,
77 force_cpu: false,
78 token_source: TokenSource::CacheToken,
79 hf_revision: None,
80 isq: None,
81 max_num_seqs: 32,
82 with_logging: false,
83 device_mapping: None,
84 calibration_file: None,
85 imatrix: None,
86 jinja_explicit: None,
87 throughput_logging: false,
88 paged_attn_cfg: None,
89 hf_cache_path: None,
90 search_bert_model: None,
91 search_callback: None,
92 tool_callbacks: HashMap::new(),
93 tool_callbacks_with_tools: HashMap::new(),
94 device: None,
95 matformer_config_path: None,
96 matformer_slice_name: None,
97 prefix_cache_n: None,
98 }
99 }
100
101 pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
103 self.search_bert_model = Some(search_bert_model);
104 self
105 }
106
107 pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
109 self.search_callback = Some(callback);
110 self
111 }
112
113 pub fn with_tool_callback(
114 mut self,
115 name: impl Into<String>,
116 callback: Arc<ToolCallback>,
117 ) -> Self {
118 self.tool_callbacks.insert(name.into(), callback);
119 self
120 }
121
122 pub fn with_tool_callback_and_tool(
125 mut self,
126 name: impl Into<String>,
127 callback: Arc<ToolCallback>,
128 tool: Tool,
129 ) -> Self {
130 let name = name.into();
131 self.tool_callbacks_with_tools
132 .insert(name, ToolCallbackWithTool { callback, tool });
133 self
134 }
135
136 pub fn with_throughput_logging(mut self) -> Self {
138 self.throughput_logging = true;
139 self
140 }
141
142 pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
144 self.jinja_explicit = Some(jinja_explicit);
145 self
146 }
147
148 pub fn with_topology(mut self, topology: Topology) -> Self {
150 self.topology = Some(topology);
151 self
152 }
153
154 pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
156 self.chat_template = Some(chat_template.to_string());
157 self
158 }
159
160 pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
162 self.tokenizer_json = Some(tokenizer_json.to_string());
163 self
164 }
165
166 pub fn with_loader_type(mut self, loader_type: VisionLoaderType) -> Self {
169 self.loader_type = Some(loader_type);
170 self
171 }
172
173 pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
175 self.dtype = dtype;
176 self
177 }
178
179 pub fn with_force_cpu(mut self) -> Self {
181 self.force_cpu = true;
182 self
183 }
184
185 pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
187 self.token_source = token_source;
188 self
189 }
190
191 pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
193 self.hf_revision = Some(revision.to_string());
194 self
195 }
196
197 pub fn with_isq(mut self, isq: IsqType) -> Self {
199 self.isq = Some(isq);
200 self
201 }
202
203 pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
205 self.calibration_file = Some(path);
206 self
207 }
208
209 pub fn with_paged_attn(
214 mut self,
215 paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
216 ) -> anyhow::Result<Self> {
217 if paged_attn_supported() {
218 self.paged_attn_cfg = Some(paged_attn_cfg()?);
219 } else {
220 self.paged_attn_cfg = None;
221 }
222 Ok(self)
223 }
224
225 pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
227 self.max_num_seqs = max_num_seqs;
228 self
229 }
230
231 pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
233 self.prefix_cache_n = n_seqs;
234 self
235 }
236
237 pub fn with_logging(mut self) -> Self {
239 self.with_logging = true;
240 self
241 }
242
243 pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
245 self.device_mapping = Some(device_mapping);
246 self
247 }
248
249 #[deprecated(
250 note = "Use `UqffTextModelBuilder` to load a UQFF model instead of the generic `from_uqff`"
251 )]
252 pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
260 self.from_uqff = Some(path);
261 self
262 }
263
264 pub fn from_max_edge(mut self, max_edge: u32) -> Self {
267 self.max_edge = Some(max_edge);
268 self
269 }
270
271 pub fn write_uqff(mut self, path: PathBuf) -> Self {
282 self.write_uqff = Some(path);
283 self
284 }
285
286 pub fn from_hf_cache_pathf(mut self, hf_cache_path: PathBuf) -> Self {
288 self.hf_cache_path = Some(hf_cache_path);
289 self
290 }
291
292 pub fn with_device(mut self, device: Device) -> Self {
294 self.device = Some(device);
295 self
296 }
297
298 pub fn with_matformer_config_path(mut self, path: PathBuf) -> Self {
300 self.matformer_config_path = Some(path);
301 self
302 }
303
304 pub fn with_matformer_slice_name(mut self, name: String) -> Self {
306 self.matformer_slice_name = Some(name);
307 self
308 }
309
310 pub async fn build(self) -> anyhow::Result<Model> {
311 let config = VisionSpecificConfig {
312 topology: self.topology,
313 write_uqff: self.write_uqff,
314 from_uqff: self.from_uqff,
315 max_edge: self.max_edge,
316 calibration_file: self.calibration_file,
317 imatrix: self.imatrix,
318 hf_cache_path: self.hf_cache_path,
319 matformer_config_path: self.matformer_config_path,
320 matformer_slice_name: self.matformer_slice_name,
321 };
322
323 if self.with_logging {
324 initialize_logging();
325 }
326
327 let loader = VisionLoaderBuilder::new(
328 config,
329 self.chat_template,
330 self.tokenizer_json,
331 Some(self.model_id),
332 self.jinja_explicit,
333 )
334 .build(self.loader_type);
335
336 let pipeline = loader.load_model_from_hf(
338 self.hf_revision,
339 self.token_source,
340 &self.dtype,
341 &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
342 !self.with_logging,
343 self.device_mapping
344 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_vision())),
345 self.isq,
346 self.paged_attn_cfg,
347 )?;
348
349 let scheduler_method = match self.paged_attn_cfg {
350 Some(_) => {
351 let config = pipeline
352 .lock()
353 .await
354 .get_metadata()
355 .cache_config
356 .as_ref()
357 .cloned();
358
359 if let Some(config) = config {
360 SchedulerConfig::PagedAttentionMeta {
361 max_num_seqs: self.max_num_seqs,
362 config,
363 }
364 } else {
365 SchedulerConfig::DefaultScheduler {
366 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
367 }
368 }
369 }
370 None => SchedulerConfig::DefaultScheduler {
371 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
372 },
373 };
374
375 let mut runner = MistralRsBuilder::new(
376 pipeline,
377 scheduler_method,
378 self.throughput_logging,
379 self.search_bert_model,
380 );
381 if let Some(cb) = self.search_callback.clone() {
382 runner = runner.with_search_callback(cb);
383 }
384 for (name, cb) in &self.tool_callbacks {
385 runner = runner.with_tool_callback(name.clone(), cb.clone());
386 }
387 for (name, callback_with_tool) in &self.tool_callbacks_with_tools {
388 runner = runner.with_tool_callback_and_tool(
389 name.clone(),
390 callback_with_tool.callback.clone(),
391 callback_with_tool.tool.clone(),
392 );
393 }
394 let mut runner = runner
395 .with_no_kv_cache(false)
396 .with_no_prefix_cache(self.prefix_cache_n.is_none());
397
398 if let Some(n) = self.prefix_cache_n {
399 runner = runner.with_prefix_cache_n(n)
400 }
401
402 Ok(Model::new(runner.build().await))
403 }
404}
405
406#[derive(Clone)]
407pub struct UqffVisionModelBuilder(VisionModelBuilder);
410
411impl UqffVisionModelBuilder {
412 pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
417 let mut inner = VisionModelBuilder::new(model_id);
418 inner.from_uqff = Some(uqff_file);
419 Self(inner)
420 }
421
422 pub async fn build(self) -> anyhow::Result<Model> {
423 self.0.build().await
424 }
425
426 pub fn into_inner(self) -> VisionModelBuilder {
428 self.0
429 }
430}
431
432impl Deref for UqffVisionModelBuilder {
433 type Target = VisionModelBuilder;
434
435 fn deref(&self) -> &Self::Target {
436 &self.0
437 }
438}
439
440impl DerefMut for UqffVisionModelBuilder {
441 fn deref_mut(&mut self) -> &mut Self::Target {
442 &mut self.0
443 }
444}
445
446impl From<UqffVisionModelBuilder> for VisionModelBuilder {
447 fn from(value: UqffVisionModelBuilder) -> Self {
448 value.0
449 }
450}