1use candle_core::Device;
2use mistralrs_core::*;
3use mistralrs_core::{SearchCallback, Tool, ToolCallback};
4use std::collections::HashMap;
5use std::{
6 ops::{Deref, DerefMut},
7 path::PathBuf,
8 sync::Arc,
9};
10
11use crate::{best_device, Model};
12
13#[derive(Clone)]
15pub struct ToolCallbackWithTool {
16 pub callback: Arc<ToolCallback>,
17 pub tool: Tool,
18}
19
20#[derive(Clone)]
21pub struct VisionModelBuilder {
23 pub(crate) model_id: String,
25 pub(crate) token_source: TokenSource,
26 pub(crate) hf_revision: Option<String>,
27 pub(crate) write_uqff: Option<PathBuf>,
28 pub(crate) from_uqff: Option<Vec<PathBuf>>,
29 pub(crate) calibration_file: Option<PathBuf>,
30 pub(crate) imatrix: Option<PathBuf>,
31 pub(crate) chat_template: Option<String>,
32 pub(crate) jinja_explicit: Option<String>,
33 pub(crate) tokenizer_json: Option<String>,
34 pub(crate) device_mapping: Option<DeviceMapSetting>,
35 pub(crate) max_edge: Option<u32>,
36 pub(crate) hf_cache_path: Option<PathBuf>,
37 pub(crate) search_bert_model: Option<BertEmbeddingModel>,
38 pub(crate) search_callback: Option<Arc<SearchCallback>>,
39 pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
40 pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
41 pub(crate) device: Option<Device>,
42 pub(crate) matformer_config_path: Option<PathBuf>,
43 pub(crate) matformer_slice_name: Option<String>,
44
45 pub(crate) topology: Option<Topology>,
47 pub(crate) loader_type: Option<VisionLoaderType>,
48 pub(crate) dtype: ModelDType,
49 pub(crate) force_cpu: bool,
50 pub(crate) isq: Option<IsqType>,
51 pub(crate) throughput_logging: bool,
52
53 pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
55 pub(crate) max_num_seqs: usize,
56 pub(crate) with_logging: bool,
57}
58
59impl VisionModelBuilder {
60 pub fn new(model_id: impl ToString) -> Self {
66 Self {
67 model_id: model_id.to_string(),
68 topology: None,
69 write_uqff: None,
70 from_uqff: None,
71 chat_template: None,
72 tokenizer_json: None,
73 max_edge: None,
74 loader_type: None,
75 dtype: ModelDType::Auto,
76 force_cpu: false,
77 token_source: TokenSource::CacheToken,
78 hf_revision: None,
79 isq: None,
80 max_num_seqs: 32,
81 with_logging: false,
82 device_mapping: None,
83 calibration_file: None,
84 imatrix: None,
85 jinja_explicit: None,
86 throughput_logging: false,
87 paged_attn_cfg: None,
88 hf_cache_path: None,
89 search_bert_model: None,
90 search_callback: None,
91 tool_callbacks: HashMap::new(),
92 tool_callbacks_with_tools: HashMap::new(),
93 device: None,
94 matformer_config_path: None,
95 matformer_slice_name: None,
96 }
97 }
98
99 pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
101 self.search_bert_model = Some(search_bert_model);
102 self
103 }
104
105 pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
107 self.search_callback = Some(callback);
108 self
109 }
110
111 pub fn with_tool_callback(
112 mut self,
113 name: impl Into<String>,
114 callback: Arc<ToolCallback>,
115 ) -> Self {
116 self.tool_callbacks.insert(name.into(), callback);
117 self
118 }
119
120 pub fn with_tool_callback_and_tool(
123 mut self,
124 name: impl Into<String>,
125 callback: Arc<ToolCallback>,
126 tool: Tool,
127 ) -> Self {
128 let name = name.into();
129 self.tool_callbacks_with_tools
130 .insert(name, ToolCallbackWithTool { callback, tool });
131 self
132 }
133
134 pub fn with_throughput_logging(mut self) -> Self {
136 self.throughput_logging = true;
137 self
138 }
139
140 pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
142 self.jinja_explicit = Some(jinja_explicit);
143 self
144 }
145
146 pub fn with_topology(mut self, topology: Topology) -> Self {
148 self.topology = Some(topology);
149 self
150 }
151
152 pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
154 self.chat_template = Some(chat_template.to_string());
155 self
156 }
157
158 pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
160 self.tokenizer_json = Some(tokenizer_json.to_string());
161 self
162 }
163
164 pub fn with_loader_type(mut self, loader_type: VisionLoaderType) -> Self {
167 self.loader_type = Some(loader_type);
168 self
169 }
170
171 pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
173 self.dtype = dtype;
174 self
175 }
176
177 pub fn with_force_cpu(mut self) -> Self {
179 self.force_cpu = true;
180 self
181 }
182
183 pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
185 self.token_source = token_source;
186 self
187 }
188
189 pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
191 self.hf_revision = Some(revision.to_string());
192 self
193 }
194
195 pub fn with_isq(mut self, isq: IsqType) -> Self {
197 self.isq = Some(isq);
198 self
199 }
200
201 pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
203 self.calibration_file = Some(path);
204 self
205 }
206
207 pub fn with_paged_attn(
212 mut self,
213 paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
214 ) -> anyhow::Result<Self> {
215 if paged_attn_supported() {
216 self.paged_attn_cfg = Some(paged_attn_cfg()?);
217 } else {
218 self.paged_attn_cfg = None;
219 }
220 Ok(self)
221 }
222
223 pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
225 self.max_num_seqs = max_num_seqs;
226 self
227 }
228
229 pub fn with_logging(mut self) -> Self {
231 self.with_logging = true;
232 self
233 }
234
235 pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
237 self.device_mapping = Some(device_mapping);
238 self
239 }
240
241 #[deprecated(
242 note = "Use `UqffTextModelBuilder` to load a UQFF model instead of the generic `from_uqff`"
243 )]
244 pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
252 self.from_uqff = Some(path);
253 self
254 }
255
256 pub fn from_max_edge(mut self, max_edge: u32) -> Self {
259 self.max_edge = Some(max_edge);
260 self
261 }
262
263 pub fn write_uqff(mut self, path: PathBuf) -> Self {
274 self.write_uqff = Some(path);
275 self
276 }
277
278 pub fn from_hf_cache_pathf(mut self, hf_cache_path: PathBuf) -> Self {
280 self.hf_cache_path = Some(hf_cache_path);
281 self
282 }
283
284 pub fn with_device(mut self, device: Device) -> Self {
286 self.device = Some(device);
287 self
288 }
289
290 pub fn with_matformer_config_path(mut self, path: PathBuf) -> Self {
292 self.matformer_config_path = Some(path);
293 self
294 }
295
296 pub fn with_matformer_slice_name(mut self, name: String) -> Self {
298 self.matformer_slice_name = Some(name);
299 self
300 }
301
302 pub async fn build(self) -> anyhow::Result<Model> {
303 let config = VisionSpecificConfig {
304 topology: self.topology,
305 write_uqff: self.write_uqff,
306 from_uqff: self.from_uqff,
307 max_edge: self.max_edge,
308 calibration_file: self.calibration_file,
309 imatrix: self.imatrix,
310 hf_cache_path: self.hf_cache_path,
311 matformer_config_path: self.matformer_config_path,
312 matformer_slice_name: self.matformer_slice_name,
313 };
314
315 if self.with_logging {
316 initialize_logging();
317 }
318
319 let loader = VisionLoaderBuilder::new(
320 config,
321 self.chat_template,
322 self.tokenizer_json,
323 Some(self.model_id),
324 self.jinja_explicit,
325 )
326 .build(self.loader_type);
327
328 let pipeline = loader.load_model_from_hf(
330 self.hf_revision,
331 self.token_source,
332 &self.dtype,
333 &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
334 !self.with_logging,
335 self.device_mapping
336 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_vision())),
337 self.isq,
338 self.paged_attn_cfg,
339 )?;
340
341 let scheduler_method = match self.paged_attn_cfg {
342 Some(_) => {
343 let config = pipeline
344 .lock()
345 .await
346 .get_metadata()
347 .cache_config
348 .as_ref()
349 .cloned();
350
351 if let Some(config) = config {
352 SchedulerConfig::PagedAttentionMeta {
353 max_num_seqs: self.max_num_seqs,
354 config,
355 }
356 } else {
357 SchedulerConfig::DefaultScheduler {
358 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
359 }
360 }
361 }
362 None => SchedulerConfig::DefaultScheduler {
363 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
364 },
365 };
366
367 let mut runner = MistralRsBuilder::new(
368 pipeline,
369 scheduler_method,
370 self.throughput_logging,
371 self.search_bert_model,
372 );
373 if let Some(cb) = self.search_callback.clone() {
374 runner = runner.with_search_callback(cb);
375 }
376 for (name, cb) in &self.tool_callbacks {
377 runner = runner.with_tool_callback(name.clone(), cb.clone());
378 }
379 for (name, callback_with_tool) in &self.tool_callbacks_with_tools {
380 runner = runner.with_tool_callback_and_tool(
381 name.clone(),
382 callback_with_tool.callback.clone(),
383 callback_with_tool.tool.clone(),
384 );
385 }
386 let runner = runner.with_no_kv_cache(false).with_no_prefix_cache(false);
387
388 Ok(Model::new(runner.build().await))
389 }
390}
391
392#[derive(Clone)]
393pub struct UqffVisionModelBuilder(VisionModelBuilder);
396
397impl UqffVisionModelBuilder {
398 pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
403 let mut inner = VisionModelBuilder::new(model_id);
404 inner.from_uqff = Some(uqff_file);
405 Self(inner)
406 }
407
408 pub async fn build(self) -> anyhow::Result<Model> {
409 self.0.build().await
410 }
411
412 pub fn into_inner(self) -> VisionModelBuilder {
414 self.0
415 }
416}
417
418impl Deref for UqffVisionModelBuilder {
419 type Target = VisionModelBuilder;
420
421 fn deref(&self) -> &Self::Target {
422 &self.0
423 }
424}
425
426impl DerefMut for UqffVisionModelBuilder {
427 fn deref_mut(&mut self) -> &mut Self::Target {
428 &mut self.0
429 }
430}
431
432impl From<UqffVisionModelBuilder> for VisionModelBuilder {
433 fn from(value: UqffVisionModelBuilder) -> Self {
434 value.0
435 }
436}