1use candle_core::Device;
2use mistralrs_core::*;
3use mistralrs_core::{SearchCallback, Tool, ToolCallback};
4use std::collections::HashMap;
5use std::{
6 num::NonZeroUsize,
7 ops::{Deref, DerefMut},
8 path::PathBuf,
9 sync::Arc,
10};
11
12use crate::{best_device, Model};
13
14#[derive(Clone)]
16pub struct ToolCallbackWithTool {
17 pub callback: Arc<ToolCallback>,
18 pub tool: Tool,
19}
20
21#[derive(Clone)]
22pub struct VisionModelBuilder {
24 pub(crate) model_id: String,
26 pub(crate) token_source: TokenSource,
27 pub(crate) hf_revision: Option<String>,
28 pub(crate) write_uqff: Option<PathBuf>,
29 pub(crate) from_uqff: Option<Vec<PathBuf>>,
30 pub(crate) calibration_file: Option<PathBuf>,
31 pub(crate) imatrix: Option<PathBuf>,
32 pub(crate) chat_template: Option<String>,
33 pub(crate) jinja_explicit: Option<String>,
34 pub(crate) tokenizer_json: Option<String>,
35 pub(crate) device_mapping: Option<DeviceMapSetting>,
36 pub(crate) max_edge: Option<u32>,
37 pub(crate) hf_cache_path: Option<PathBuf>,
38 pub(crate) search_bert_model: Option<BertEmbeddingModel>,
39 pub(crate) search_callback: Option<Arc<SearchCallback>>,
40 pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
41 pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
42 pub(crate) device: Option<Device>,
43 pub(crate) matformer_config_path: Option<PathBuf>,
44 pub(crate) matformer_slice_name: Option<String>,
45
46 pub(crate) prompt_chunksize: Option<NonZeroUsize>,
48 pub(crate) topology: Option<Topology>,
49 pub(crate) loader_type: Option<VisionLoaderType>,
50 pub(crate) dtype: ModelDType,
51 pub(crate) force_cpu: bool,
52 pub(crate) isq: Option<IsqType>,
53 pub(crate) throughput_logging: bool,
54
55 pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
57 pub(crate) max_num_seqs: usize,
58 pub(crate) with_logging: bool,
59}
60
61impl VisionModelBuilder {
62 pub fn new(model_id: impl ToString) -> Self {
68 Self {
69 model_id: model_id.to_string(),
70 topology: None,
71 write_uqff: None,
72 from_uqff: None,
73 prompt_chunksize: None,
74 chat_template: None,
75 tokenizer_json: None,
76 max_edge: None,
77 loader_type: None,
78 dtype: ModelDType::Auto,
79 force_cpu: false,
80 token_source: TokenSource::CacheToken,
81 hf_revision: None,
82 isq: None,
83 max_num_seqs: 32,
84 with_logging: false,
85 device_mapping: None,
86 calibration_file: None,
87 imatrix: None,
88 jinja_explicit: None,
89 throughput_logging: false,
90 paged_attn_cfg: None,
91 hf_cache_path: None,
92 search_bert_model: None,
93 search_callback: None,
94 tool_callbacks: HashMap::new(),
95 tool_callbacks_with_tools: HashMap::new(),
96 device: None,
97 matformer_config_path: None,
98 matformer_slice_name: None,
99 }
100 }
101
102 pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
104 self.search_bert_model = Some(search_bert_model);
105 self
106 }
107
108 pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
110 self.search_callback = Some(callback);
111 self
112 }
113
114 pub fn with_tool_callback(
115 mut self,
116 name: impl Into<String>,
117 callback: Arc<ToolCallback>,
118 ) -> Self {
119 self.tool_callbacks.insert(name.into(), callback);
120 self
121 }
122
123 pub fn with_tool_callback_and_tool(
126 mut self,
127 name: impl Into<String>,
128 callback: Arc<ToolCallback>,
129 tool: Tool,
130 ) -> Self {
131 let name = name.into();
132 self.tool_callbacks_with_tools
133 .insert(name, ToolCallbackWithTool { callback, tool });
134 self
135 }
136
137 pub fn with_throughput_logging(mut self) -> Self {
139 self.throughput_logging = true;
140 self
141 }
142
143 pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
145 self.jinja_explicit = Some(jinja_explicit);
146 self
147 }
148
149 pub fn with_prompt_chunksize(mut self, prompt_chunksize: NonZeroUsize) -> Self {
151 self.prompt_chunksize = Some(prompt_chunksize);
152 self
153 }
154
155 pub fn with_topology(mut self, topology: Topology) -> Self {
157 self.topology = Some(topology);
158 self
159 }
160
161 pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
163 self.chat_template = Some(chat_template.to_string());
164 self
165 }
166
167 pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
169 self.tokenizer_json = Some(tokenizer_json.to_string());
170 self
171 }
172
173 pub fn with_loader_type(mut self, loader_type: VisionLoaderType) -> Self {
176 self.loader_type = Some(loader_type);
177 self
178 }
179
180 pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
182 self.dtype = dtype;
183 self
184 }
185
186 pub fn with_force_cpu(mut self) -> Self {
188 self.force_cpu = true;
189 self
190 }
191
192 pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
194 self.token_source = token_source;
195 self
196 }
197
198 pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
200 self.hf_revision = Some(revision.to_string());
201 self
202 }
203
204 pub fn with_isq(mut self, isq: IsqType) -> Self {
206 self.isq = Some(isq);
207 self
208 }
209
210 pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
212 self.calibration_file = Some(path);
213 self
214 }
215
216 pub fn with_paged_attn(
221 mut self,
222 paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
223 ) -> anyhow::Result<Self> {
224 if paged_attn_supported() {
225 self.paged_attn_cfg = Some(paged_attn_cfg()?);
226 } else {
227 self.paged_attn_cfg = None;
228 }
229 Ok(self)
230 }
231
232 pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
234 self.max_num_seqs = max_num_seqs;
235 self
236 }
237
238 pub fn with_logging(mut self) -> Self {
240 self.with_logging = true;
241 self
242 }
243
244 pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
246 self.device_mapping = Some(device_mapping);
247 self
248 }
249
250 pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
252 self.from_uqff = Some(path);
253 self
254 }
255
256 pub fn from_max_edge(mut self, max_edge: u32) -> Self {
259 self.max_edge = Some(max_edge);
260 self
261 }
262
263 pub fn write_uqff(mut self, path: PathBuf) -> Self {
272 self.write_uqff = Some(path);
273 self
274 }
275
276 pub fn from_hf_cache_pathf(mut self, hf_cache_path: PathBuf) -> Self {
278 self.hf_cache_path = Some(hf_cache_path);
279 self
280 }
281
282 pub fn with_device(mut self, device: Device) -> Self {
284 self.device = Some(device);
285 self
286 }
287
288 pub fn with_matformer_config_path(mut self, path: PathBuf) -> Self {
290 self.matformer_config_path = Some(path);
291 self
292 }
293
294 pub fn with_matformer_slice_name(mut self, name: String) -> Self {
296 self.matformer_slice_name = Some(name);
297 self
298 }
299
300 pub async fn build(self) -> anyhow::Result<Model> {
301 let config = VisionSpecificConfig {
302 prompt_chunksize: self.prompt_chunksize,
303 topology: self.topology,
304 write_uqff: self.write_uqff,
305 from_uqff: self.from_uqff,
306 max_edge: self.max_edge,
307 calibration_file: self.calibration_file,
308 imatrix: self.imatrix,
309 hf_cache_path: self.hf_cache_path,
310 matformer_config_path: self.matformer_config_path,
311 matformer_slice_name: self.matformer_slice_name,
312 };
313
314 if self.with_logging {
315 initialize_logging();
316 }
317
318 let loader = VisionLoaderBuilder::new(
319 config,
320 self.chat_template,
321 self.tokenizer_json,
322 Some(self.model_id),
323 self.jinja_explicit,
324 )
325 .build(self.loader_type);
326
327 let pipeline = loader.load_model_from_hf(
329 self.hf_revision,
330 self.token_source,
331 &self.dtype,
332 &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
333 !self.with_logging,
334 self.device_mapping
335 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_vision())),
336 self.isq,
337 self.paged_attn_cfg,
338 )?;
339
340 let scheduler_method = match self.paged_attn_cfg {
341 Some(_) => {
342 let config = pipeline
343 .lock()
344 .await
345 .get_metadata()
346 .cache_config
347 .as_ref()
348 .cloned();
349
350 if let Some(config) = config {
351 SchedulerConfig::PagedAttentionMeta {
352 max_num_seqs: self.max_num_seqs,
353 config,
354 }
355 } else {
356 SchedulerConfig::DefaultScheduler {
357 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
358 }
359 }
360 }
361 None => SchedulerConfig::DefaultScheduler {
362 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
363 },
364 };
365
366 let mut runner = MistralRsBuilder::new(
367 pipeline,
368 scheduler_method,
369 self.throughput_logging,
370 self.search_bert_model,
371 );
372 if let Some(cb) = self.search_callback.clone() {
373 runner = runner.with_search_callback(cb);
374 }
375 for (name, cb) in &self.tool_callbacks {
376 runner = runner.with_tool_callback(name.clone(), cb.clone());
377 }
378 for (name, callback_with_tool) in &self.tool_callbacks_with_tools {
379 runner = runner.with_tool_callback_and_tool(
380 name.clone(),
381 callback_with_tool.callback.clone(),
382 callback_with_tool.tool.clone(),
383 );
384 }
385 let runner = runner.with_no_kv_cache(false).with_no_prefix_cache(false);
386
387 Ok(Model::new(runner.build().await))
388 }
389}
390
391#[derive(Clone)]
392pub struct UqffVisionModelBuilder(VisionModelBuilder);
395
396impl UqffVisionModelBuilder {
397 pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
402 let mut inner = VisionModelBuilder::new(model_id);
403 inner = inner.from_uqff(uqff_file);
404 Self(inner)
405 }
406
407 pub async fn build(self) -> anyhow::Result<Model> {
408 self.0.build().await
409 }
410
411 pub fn into_inner(self) -> VisionModelBuilder {
413 self.0
414 }
415}
416
417impl Deref for UqffVisionModelBuilder {
418 type Target = VisionModelBuilder;
419
420 fn deref(&self) -> &Self::Target {
421 &self.0
422 }
423}
424
425impl DerefMut for UqffVisionModelBuilder {
426 fn deref_mut(&mut self) -> &mut Self::Target {
427 &mut self.0
428 }
429}
430
431impl From<UqffVisionModelBuilder> for VisionModelBuilder {
432 fn from(value: UqffVisionModelBuilder) -> Self {
433 value.0
434 }
435}