1use std::{
2 fs::{self, File},
3 num::NonZeroUsize,
4 path::PathBuf,
5 str::FromStr,
6};
7
8use mistralrs_quant::MULTI_LORA_DELIMITER;
9
10use crate::{
11 get_toml_selected_model_dtype,
12 pipeline::{GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, NormalSpecificConfig},
13 toml_selector::get_toml_selected_model_device_map_params,
14 AutoDeviceMapParams, DiffusionLoaderBuilder, DiffusionSpecificConfig, GGUFSpecificConfig,
15 Loader, ModelDType, ModelSelected, NormalLoaderBuilder, TomlLoaderArgs, TomlSelector, Topology,
16 VisionLoaderBuilder, VisionSpecificConfig, GGUF_MULTI_FILE_DELIMITER,
17 UQFF_MULTI_FILE_DELIMITER,
18};
19
/// Collects runtime options for constructing a model [`Loader`] via
/// [`loader_from_model_selected`]. Built with [`LoaderBuilder::new`] and the
/// `with_*` setters, then consumed by [`LoaderBuilder::build`].
pub struct LoaderBuilder {
    // Which model (format/architecture/adapter combination) to load.
    model: ModelSelected,
    // If true, the KV cache is disabled; forwarded to the underlying loader builders.
    no_kv_cache: bool,
    // Optional chat template override, forwarded to the underlying loader builders.
    chat_template: Option<String>,
    // Optional explicit JINJA chat template, forwarded to the underlying loader builders.
    jinja_explicit: Option<String>,
    // Whether flash attention is enabled in the per-model specific configs.
    use_flash_attn: bool,
    // Optional prompt-processing chunk size, forwarded into the specific configs.
    prompt_chunksize: Option<NonZeroUsize>,
}
29
30impl LoaderBuilder {
31 pub fn new(model: ModelSelected) -> Self {
32 Self {
33 model,
34 no_kv_cache: false,
35 chat_template: None,
36 use_flash_attn: false,
37 prompt_chunksize: None,
38 jinja_explicit: None,
39 }
40 }
41
42 pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
43 self.no_kv_cache = no_kv_cache;
44 self
45 }
46 pub fn with_chat_template(mut self, chat_template: Option<String>) -> Self {
47 self.chat_template = chat_template;
48 self
49 }
50 pub fn with_jinja_explicit(mut self, jinja_explicit: Option<String>) -> Self {
51 self.jinja_explicit = jinja_explicit;
52 self
53 }
54 pub fn with_use_flash_attn(mut self, use_flash_attn: bool) -> Self {
55 self.use_flash_attn = use_flash_attn;
56 self
57 }
58 pub fn with_prompt_chunksize(mut self, prompt_chunksize: Option<NonZeroUsize>) -> Self {
59 self.prompt_chunksize = prompt_chunksize;
60 self
61 }
62
63 pub fn build(self) -> anyhow::Result<Box<dyn Loader>> {
64 loader_from_model_selected(self)
65 }
66}
67
68pub fn get_tgt_non_granular_index(model: &ModelSelected) -> Option<usize> {
69 match model {
70 ModelSelected::Plain { .. }
71 | ModelSelected::Lora { .. }
72 | ModelSelected::GGUF { .. }
73 | ModelSelected::LoraGGUF { .. }
74 | ModelSelected::GGML { .. }
75 | ModelSelected::LoraGGML { .. }
76 | ModelSelected::Toml { .. }
77 | ModelSelected::VisionPlain { .. }
78 | ModelSelected::DiffusionPlain { .. } => None,
79 ModelSelected::XLora {
80 tgt_non_granular_index,
81 ..
82 }
83 | ModelSelected::XLoraGGUF {
84 tgt_non_granular_index,
85 ..
86 }
87 | ModelSelected::XLoraGGML {
88 tgt_non_granular_index,
89 ..
90 } => *tgt_non_granular_index,
91 }
92}
93
94pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result<ModelDType> {
95 match model {
96 ModelSelected::Plain { dtype, .. }
97 | ModelSelected::Lora { dtype, .. }
98 | ModelSelected::XLora { dtype, .. }
99 | ModelSelected::VisionPlain { dtype, .. }
100 | ModelSelected::DiffusionPlain { dtype, .. }
101 | ModelSelected::GGML { dtype, .. }
102 | ModelSelected::GGUF { dtype, .. }
103 | ModelSelected::XLoraGGUF { dtype, .. }
104 | ModelSelected::XLoraGGML { dtype, .. }
105 | ModelSelected::LoraGGUF { dtype, .. }
106 | ModelSelected::LoraGGML { dtype, .. } => Ok(*dtype),
107 ModelSelected::Toml { file } => {
108 let selector: TomlSelector = toml::from_str(
109 &fs::read_to_string(file.clone())
110 .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
111 )?;
112 Ok(get_toml_selected_model_dtype(&selector))
113 }
114 }
115}
116
117pub fn get_auto_device_map_params(model: &ModelSelected) -> anyhow::Result<AutoDeviceMapParams> {
118 match model {
119 ModelSelected::Plain {
120 max_seq_len,
121 max_batch_size,
122 ..
123 }
124 | ModelSelected::Lora {
125 max_seq_len,
126 max_batch_size,
127 ..
128 }
129 | ModelSelected::XLora {
130 max_seq_len,
131 max_batch_size,
132 ..
133 }
134 | ModelSelected::GGML {
135 max_seq_len,
136 max_batch_size,
137 ..
138 }
139 | ModelSelected::GGUF {
140 max_seq_len,
141 max_batch_size,
142 ..
143 }
144 | ModelSelected::XLoraGGUF {
145 max_seq_len,
146 max_batch_size,
147 ..
148 }
149 | ModelSelected::XLoraGGML {
150 max_seq_len,
151 max_batch_size,
152 ..
153 }
154 | ModelSelected::LoraGGUF {
155 max_seq_len,
156 max_batch_size,
157 ..
158 }
159 | ModelSelected::LoraGGML {
160 max_seq_len,
161 max_batch_size,
162 ..
163 } => Ok(AutoDeviceMapParams::Text {
164 max_seq_len: *max_seq_len,
165 max_batch_size: *max_batch_size,
166 }),
167 ModelSelected::VisionPlain {
168 max_seq_len,
169 max_batch_size,
170 max_image_length,
171 max_num_images,
172 ..
173 } => Ok(AutoDeviceMapParams::Vision {
174 max_seq_len: *max_seq_len,
175 max_batch_size: *max_batch_size,
176 max_image_shape: (*max_image_length, *max_image_length),
177 max_num_images: *max_num_images,
178 }),
179 ModelSelected::DiffusionPlain { .. } => Ok(AutoDeviceMapParams::default_text()),
180 ModelSelected::Toml { file } => {
181 let selector: TomlSelector = toml::from_str(
182 &fs::read_to_string(file.clone())
183 .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
184 )?;
185 get_toml_selected_model_device_map_params(&selector)
186 }
187 }
188}
189
190fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loader>> {
191 let use_flash_attn = args.use_flash_attn;
192 let loader: Box<dyn Loader> = match args.model {
193 ModelSelected::Toml { file } => {
194 let selector: TomlSelector = toml::from_str(
195 &fs::read_to_string(file.clone())
196 .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
197 )?;
198 let args = TomlLoaderArgs {
199 use_flash_attn,
200 chat_template: args.chat_template,
201 no_kv_cache: args.no_kv_cache,
202 prompt_chunksize: args.prompt_chunksize,
203 jinja_explicit: args.jinja_explicit,
204 };
205 (selector, args).try_into()?
206 }
207 ModelSelected::Plain {
208 model_id,
209 tokenizer_json,
210 arch,
211 dtype: _,
212 topology,
213 organization,
214 write_uqff,
215 from_uqff,
216 imatrix,
217 calibration_file,
218 max_seq_len: _,
219 max_batch_size: _,
220 hf_cache_path,
221 } => NormalLoaderBuilder::new(
222 NormalSpecificConfig {
223 use_flash_attn,
224 prompt_chunksize: args.prompt_chunksize,
225 topology: Topology::from_option_path(topology)?,
226 organization: organization.unwrap_or_default(),
227 write_uqff,
228 from_uqff: from_uqff.map(|x| {
229 x.split(UQFF_MULTI_FILE_DELIMITER)
230 .map(PathBuf::from_str)
231 .map(|x| x.unwrap())
232 .collect::<Vec<_>>()
233 }),
234 imatrix,
235 calibration_file,
236 hf_cache_path,
237 },
238 args.chat_template,
239 tokenizer_json,
240 Some(model_id),
241 args.no_kv_cache,
242 args.jinja_explicit,
243 )
244 .build(arch)?,
245 ModelSelected::XLora {
246 model_id,
247 xlora_model_id,
248 order,
249 tokenizer_json,
250 tgt_non_granular_index,
251 arch,
252 dtype: _,
253 topology,
254 write_uqff,
255 from_uqff,
256 max_seq_len: _,
257 max_batch_size: _,
258 hf_cache_path,
259 } => NormalLoaderBuilder::new(
260 NormalSpecificConfig {
261 use_flash_attn,
262 prompt_chunksize: args.prompt_chunksize,
263 topology: Topology::from_option_path(topology)?,
264 organization: Default::default(),
265 write_uqff,
266 from_uqff: from_uqff.map(|x| {
267 x.split(UQFF_MULTI_FILE_DELIMITER)
268 .map(PathBuf::from_str)
269 .map(|x| x.unwrap())
270 .collect::<Vec<_>>()
271 }),
272 imatrix: None,
273 calibration_file: None,
274 hf_cache_path,
275 },
276 args.chat_template,
277 tokenizer_json,
278 model_id,
279 args.no_kv_cache,
280 args.jinja_explicit,
281 )
282 .with_xlora(
283 xlora_model_id,
284 serde_json::from_reader(
285 File::open(order.clone())
286 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
287 )?,
288 args.no_kv_cache,
289 tgt_non_granular_index,
290 )
291 .build(arch)?,
292 ModelSelected::Lora {
293 model_id,
294 tokenizer_json,
295 adapter_model_id,
296 arch,
297 dtype: _,
298 topology,
299 write_uqff,
300 from_uqff,
301 max_seq_len: _,
302 max_batch_size: _,
303 hf_cache_path,
304 } => NormalLoaderBuilder::new(
305 NormalSpecificConfig {
306 use_flash_attn,
307 prompt_chunksize: args.prompt_chunksize,
308 topology: Topology::from_option_path(topology)?,
309 organization: Default::default(),
310 write_uqff,
311 from_uqff: from_uqff.map(|x| {
312 x.split(UQFF_MULTI_FILE_DELIMITER)
313 .map(PathBuf::from_str)
314 .map(|x| x.unwrap())
315 .collect::<Vec<_>>()
316 }),
317 imatrix: None,
318 calibration_file: None,
319 hf_cache_path,
320 },
321 args.chat_template,
322 tokenizer_json,
323 model_id,
324 args.no_kv_cache,
325 args.jinja_explicit,
326 )
327 .with_lora(
328 adapter_model_id
329 .split(MULTI_LORA_DELIMITER)
330 .map(ToString::to_string)
331 .collect(),
332 )
333 .build(arch)?,
334 ModelSelected::GGUF {
335 tok_model_id,
336 quantized_model_id,
337 quantized_filename,
338 topology,
339 ..
340 } => GGUFLoaderBuilder::new(
341 args.chat_template,
342 tok_model_id,
343 quantized_model_id,
344 quantized_filename
345 .split(GGUF_MULTI_FILE_DELIMITER)
346 .map(ToOwned::to_owned)
347 .collect::<Vec<_>>(),
348 GGUFSpecificConfig {
349 prompt_chunksize: args.prompt_chunksize,
350 topology: Topology::from_option_path(topology)?,
351 },
352 args.no_kv_cache,
353 args.jinja_explicit,
354 )
355 .build(),
356 ModelSelected::XLoraGGUF {
357 tok_model_id,
358 quantized_model_id,
359 quantized_filename,
360 xlora_model_id,
361 order,
362 tgt_non_granular_index,
363 topology,
364 ..
365 } => GGUFLoaderBuilder::new(
366 args.chat_template,
367 tok_model_id,
368 quantized_model_id,
369 quantized_filename
370 .split(GGUF_MULTI_FILE_DELIMITER)
371 .map(ToOwned::to_owned)
372 .collect::<Vec<_>>(),
373 GGUFSpecificConfig {
374 prompt_chunksize: args.prompt_chunksize,
375 topology: Topology::from_option_path(topology)?,
376 },
377 args.no_kv_cache,
378 args.jinja_explicit,
379 )
380 .with_xlora(
381 xlora_model_id,
382 serde_json::from_reader(
383 File::open(order.clone())
384 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
385 )?,
386 args.no_kv_cache,
387 tgt_non_granular_index,
388 )
389 .build(),
390 ModelSelected::LoraGGUF {
391 tok_model_id,
392 quantized_model_id,
393 quantized_filename,
394 adapters_model_id,
395 order,
396 topology,
397 ..
398 } => GGUFLoaderBuilder::new(
399 args.chat_template,
400 tok_model_id,
401 quantized_model_id,
402 quantized_filename
403 .split(GGUF_MULTI_FILE_DELIMITER)
404 .map(ToOwned::to_owned)
405 .collect::<Vec<_>>(),
406 GGUFSpecificConfig {
407 prompt_chunksize: args.prompt_chunksize,
408 topology: Topology::from_option_path(topology)?,
409 },
410 args.no_kv_cache,
411 args.jinja_explicit,
412 )
413 .with_lora(
414 adapters_model_id,
415 serde_json::from_reader(
416 File::open(order.clone())
417 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
418 )?,
419 )
420 .build(),
421 ModelSelected::GGML {
422 tok_model_id,
423 tokenizer_json,
424 quantized_model_id,
425 quantized_filename,
426 gqa,
427 topology,
428 ..
429 } => GGMLLoaderBuilder::new(
430 GGMLSpecificConfig {
431 gqa,
432 prompt_chunksize: args.prompt_chunksize,
433 topology: Topology::from_option_path(topology)?,
434 },
435 args.chat_template,
436 tokenizer_json,
437 Some(tok_model_id),
438 quantized_model_id,
439 quantized_filename,
440 args.no_kv_cache,
441 args.jinja_explicit,
442 )
443 .build(),
444 ModelSelected::XLoraGGML {
445 tok_model_id,
446 tokenizer_json,
447 quantized_model_id,
448 quantized_filename,
449 xlora_model_id,
450 order,
451 tgt_non_granular_index,
452 gqa,
453 topology,
454 ..
455 } => GGMLLoaderBuilder::new(
456 GGMLSpecificConfig {
457 gqa,
458 prompt_chunksize: args.prompt_chunksize,
459 topology: Topology::from_option_path(topology)?,
460 },
461 args.chat_template,
462 tokenizer_json,
463 tok_model_id,
464 quantized_model_id,
465 quantized_filename,
466 args.no_kv_cache,
467 args.jinja_explicit,
468 )
469 .with_xlora(
470 xlora_model_id,
471 serde_json::from_reader(
472 File::open(order.clone())
473 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
474 )?,
475 args.no_kv_cache,
476 tgt_non_granular_index,
477 )
478 .build(),
479 ModelSelected::LoraGGML {
480 tok_model_id,
481 tokenizer_json,
482 quantized_model_id,
483 quantized_filename,
484 adapters_model_id,
485 order,
486 gqa,
487 topology,
488 ..
489 } => GGMLLoaderBuilder::new(
490 GGMLSpecificConfig {
491 gqa,
492 prompt_chunksize: args.prompt_chunksize,
493 topology: Topology::from_option_path(topology)?,
494 },
495 args.chat_template,
496 tokenizer_json,
497 tok_model_id,
498 quantized_model_id,
499 quantized_filename,
500 args.no_kv_cache,
501 args.jinja_explicit,
502 )
503 .with_lora(
504 adapters_model_id,
505 serde_json::from_reader(
506 File::open(order.clone())
507 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
508 )?,
509 )
510 .build(),
511 ModelSelected::VisionPlain {
512 model_id,
513 tokenizer_json,
514 arch,
515 dtype: _,
516 topology,
517 write_uqff,
518 from_uqff,
519 max_edge,
520 calibration_file,
521 max_seq_len: _,
522 max_batch_size: _,
523 max_num_images: _,
524 max_image_length: _,
525 hf_cache_path,
526 imatrix,
527 } => VisionLoaderBuilder::new(
528 VisionSpecificConfig {
529 use_flash_attn,
530 prompt_chunksize: args.prompt_chunksize,
531 topology: Topology::from_option_path(topology)?,
532 write_uqff,
533 from_uqff: from_uqff.map(|x| {
534 x.split(UQFF_MULTI_FILE_DELIMITER)
535 .map(PathBuf::from_str)
536 .map(|x| x.unwrap())
537 .collect::<Vec<_>>()
538 }),
539 max_edge,
540 calibration_file,
541 imatrix,
542 hf_cache_path,
543 },
544 args.chat_template,
545 tokenizer_json,
546 Some(model_id),
547 args.jinja_explicit,
548 )
549 .build(arch),
550 ModelSelected::DiffusionPlain {
551 model_id,
552 arch,
553 dtype: _,
554 } => {
555 DiffusionLoaderBuilder::new(DiffusionSpecificConfig { use_flash_attn }, Some(model_id))
556 .build(arch)
557 }
558 };
559 Ok(loader)
560}