1use std::{
2 fs::{self, File},
3 num::NonZeroUsize,
4};
5
6use mistralrs_quant::MULTI_LORA_DELIMITER;
7
8use crate::{
9 get_toml_selected_model_dtype,
10 pipeline::{GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, NormalSpecificConfig},
11 toml_selector::get_toml_selected_model_device_map_params,
12 AutoDeviceMapParams, DiffusionLoaderBuilder, DiffusionSpecificConfig, GGUFSpecificConfig,
13 Loader, ModelDType, ModelSelected, NormalLoaderBuilder, TomlLoaderArgs, TomlSelector, Topology,
14 VisionLoaderBuilder, VisionSpecificConfig, GGUF_MULTI_FILE_DELIMITER,
15};
16
/// Builder for constructing a [`Loader`] from a [`ModelSelected`] configuration.
///
/// Create with [`LoaderBuilder::new`], customize via the `with_*` methods, then
/// call [`LoaderBuilder::build`] to obtain the loader.
pub struct LoaderBuilder {
    // Which model/backend variant to load (plain, LoRA, X-LoRA, GGUF, GGML,
    // vision, diffusion, or a TOML selector file).
    model: ModelSelected,
    // When true, the KV cache is disabled; forwarded to each loader builder.
    no_kv_cache: bool,
    // Optional chat template override; forwarded to the loader builders.
    chat_template: Option<String>,
    // Optional explicit Jinja template; forwarded alongside `chat_template`.
    // NOTE(review): precedence between this and `chat_template` is decided by
    // the downstream loaders — confirm there before documenting further.
    jinja_explicit: Option<String>,
    // Whether to request flash attention in the loader-specific configs.
    use_flash_attn: bool,
    // Optional prompt chunk size; forwarded via the loader-specific configs.
    prompt_chunksize: Option<NonZeroUsize>,
}
26
27impl LoaderBuilder {
28 pub fn new(model: ModelSelected) -> Self {
29 Self {
30 model,
31 no_kv_cache: false,
32 chat_template: None,
33 use_flash_attn: false,
34 prompt_chunksize: None,
35 jinja_explicit: None,
36 }
37 }
38
39 pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
40 self.no_kv_cache = no_kv_cache;
41 self
42 }
43 pub fn with_chat_template(mut self, chat_template: Option<String>) -> Self {
44 self.chat_template = chat_template;
45 self
46 }
47 pub fn with_jinja_explicit(mut self, jinja_explicit: Option<String>) -> Self {
48 self.jinja_explicit = jinja_explicit;
49 self
50 }
51 pub fn with_use_flash_attn(mut self, use_flash_attn: bool) -> Self {
52 self.use_flash_attn = use_flash_attn;
53 self
54 }
55 pub fn with_prompt_chunksize(mut self, prompt_chunksize: Option<NonZeroUsize>) -> Self {
56 self.prompt_chunksize = prompt_chunksize;
57 self
58 }
59
60 pub fn build(self) -> anyhow::Result<Box<dyn Loader>> {
61 loader_from_model_selected(self)
62 }
63}
64
65pub fn get_tgt_non_granular_index(model: &ModelSelected) -> Option<usize> {
66 match model {
67 ModelSelected::Plain { .. }
68 | ModelSelected::Lora { .. }
69 | ModelSelected::GGUF { .. }
70 | ModelSelected::LoraGGUF { .. }
71 | ModelSelected::GGML { .. }
72 | ModelSelected::LoraGGML { .. }
73 | ModelSelected::Toml { .. }
74 | ModelSelected::VisionPlain { .. }
75 | ModelSelected::DiffusionPlain { .. } => None,
76 ModelSelected::XLora {
77 tgt_non_granular_index,
78 ..
79 }
80 | ModelSelected::XLoraGGUF {
81 tgt_non_granular_index,
82 ..
83 }
84 | ModelSelected::XLoraGGML {
85 tgt_non_granular_index,
86 ..
87 } => *tgt_non_granular_index,
88 }
89}
90
91pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result<ModelDType> {
92 match model {
93 ModelSelected::Plain { dtype, .. }
94 | ModelSelected::Lora { dtype, .. }
95 | ModelSelected::XLora { dtype, .. }
96 | ModelSelected::VisionPlain { dtype, .. }
97 | ModelSelected::DiffusionPlain { dtype, .. }
98 | ModelSelected::GGML { dtype, .. }
99 | ModelSelected::GGUF { dtype, .. }
100 | ModelSelected::XLoraGGUF { dtype, .. }
101 | ModelSelected::XLoraGGML { dtype, .. }
102 | ModelSelected::LoraGGUF { dtype, .. }
103 | ModelSelected::LoraGGML { dtype, .. } => Ok(*dtype),
104 ModelSelected::Toml { file } => {
105 let selector: TomlSelector = toml::from_str(
106 &fs::read_to_string(file.clone())
107 .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
108 )?;
109 Ok(get_toml_selected_model_dtype(&selector))
110 }
111 }
112}
113
114pub fn get_auto_device_map_params(model: &ModelSelected) -> anyhow::Result<AutoDeviceMapParams> {
115 match model {
116 ModelSelected::Plain {
117 max_seq_len,
118 max_batch_size,
119 ..
120 }
121 | ModelSelected::Lora {
122 max_seq_len,
123 max_batch_size,
124 ..
125 }
126 | ModelSelected::XLora {
127 max_seq_len,
128 max_batch_size,
129 ..
130 }
131 | ModelSelected::GGML {
132 max_seq_len,
133 max_batch_size,
134 ..
135 }
136 | ModelSelected::GGUF {
137 max_seq_len,
138 max_batch_size,
139 ..
140 }
141 | ModelSelected::XLoraGGUF {
142 max_seq_len,
143 max_batch_size,
144 ..
145 }
146 | ModelSelected::XLoraGGML {
147 max_seq_len,
148 max_batch_size,
149 ..
150 }
151 | ModelSelected::LoraGGUF {
152 max_seq_len,
153 max_batch_size,
154 ..
155 }
156 | ModelSelected::LoraGGML {
157 max_seq_len,
158 max_batch_size,
159 ..
160 } => Ok(AutoDeviceMapParams::Text {
161 max_seq_len: *max_seq_len,
162 max_batch_size: *max_batch_size,
163 }),
164 ModelSelected::VisionPlain {
165 max_seq_len,
166 max_batch_size,
167 max_image_length,
168 max_num_images,
169 ..
170 } => Ok(AutoDeviceMapParams::Vision {
171 max_seq_len: *max_seq_len,
172 max_batch_size: *max_batch_size,
173 max_image_shape: (*max_image_length, *max_image_length),
174 max_num_images: *max_num_images,
175 }),
176 ModelSelected::DiffusionPlain { .. } => Ok(AutoDeviceMapParams::default_text()),
177 ModelSelected::Toml { file } => {
178 let selector: TomlSelector = toml::from_str(
179 &fs::read_to_string(file.clone())
180 .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
181 )?;
182 get_toml_selected_model_device_map_params(&selector)
183 }
184 }
185}
186
187fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loader>> {
188 let use_flash_attn = args.use_flash_attn;
189 let loader: Box<dyn Loader> = match args.model {
190 ModelSelected::Toml { file } => {
191 let selector: TomlSelector = toml::from_str(
192 &fs::read_to_string(file.clone())
193 .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
194 )?;
195 let args = TomlLoaderArgs {
196 use_flash_attn,
197 chat_template: args.chat_template,
198 no_kv_cache: args.no_kv_cache,
199 prompt_chunksize: args.prompt_chunksize,
200 jinja_explicit: args.jinja_explicit,
201 };
202 (selector, args).try_into()?
203 }
204 ModelSelected::Plain {
205 model_id,
206 tokenizer_json,
207 arch,
208 dtype: _,
209 topology,
210 organization,
211 write_uqff,
212 from_uqff,
213 imatrix,
214 calibration_file,
215 max_seq_len: _,
216 max_batch_size: _,
217 hf_cache_path,
218 } => NormalLoaderBuilder::new(
219 NormalSpecificConfig {
220 use_flash_attn,
221 prompt_chunksize: args.prompt_chunksize,
222 topology: Topology::from_option_path(topology)?,
223 organization: organization.unwrap_or_default(),
224 write_uqff,
225 from_uqff,
226 imatrix,
227 calibration_file,
228 hf_cache_path,
229 },
230 args.chat_template,
231 tokenizer_json,
232 Some(model_id),
233 args.no_kv_cache,
234 args.jinja_explicit,
235 )
236 .build(arch)?,
237 ModelSelected::XLora {
238 model_id,
239 xlora_model_id,
240 order,
241 tokenizer_json,
242 tgt_non_granular_index,
243 arch,
244 dtype: _,
245 topology,
246 write_uqff,
247 from_uqff,
248 max_seq_len: _,
249 max_batch_size: _,
250 hf_cache_path,
251 } => NormalLoaderBuilder::new(
252 NormalSpecificConfig {
253 use_flash_attn,
254 prompt_chunksize: args.prompt_chunksize,
255 topology: Topology::from_option_path(topology)?,
256 organization: Default::default(),
257 write_uqff,
258 from_uqff,
259 imatrix: None,
260 calibration_file: None,
261 hf_cache_path,
262 },
263 args.chat_template,
264 tokenizer_json,
265 model_id,
266 args.no_kv_cache,
267 args.jinja_explicit,
268 )
269 .with_xlora(
270 xlora_model_id,
271 serde_json::from_reader(
272 File::open(order.clone())
273 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
274 )?,
275 args.no_kv_cache,
276 tgt_non_granular_index,
277 )
278 .build(arch)?,
279 ModelSelected::Lora {
280 model_id,
281 tokenizer_json,
282 adapter_model_id,
283 arch,
284 dtype: _,
285 topology,
286 write_uqff,
287 from_uqff,
288 max_seq_len: _,
289 max_batch_size: _,
290 hf_cache_path,
291 } => NormalLoaderBuilder::new(
292 NormalSpecificConfig {
293 use_flash_attn,
294 prompt_chunksize: args.prompt_chunksize,
295 topology: Topology::from_option_path(topology)?,
296 organization: Default::default(),
297 write_uqff,
298 from_uqff,
299 imatrix: None,
300 calibration_file: None,
301 hf_cache_path,
302 },
303 args.chat_template,
304 tokenizer_json,
305 model_id,
306 args.no_kv_cache,
307 args.jinja_explicit,
308 )
309 .with_lora(
310 adapter_model_id
311 .split(MULTI_LORA_DELIMITER)
312 .map(ToString::to_string)
313 .collect(),
314 )
315 .build(arch)?,
316 ModelSelected::GGUF {
317 tok_model_id,
318 quantized_model_id,
319 quantized_filename,
320 topology,
321 ..
322 } => GGUFLoaderBuilder::new(
323 args.chat_template,
324 tok_model_id,
325 quantized_model_id,
326 quantized_filename
327 .split(GGUF_MULTI_FILE_DELIMITER)
328 .map(ToOwned::to_owned)
329 .collect::<Vec<_>>(),
330 GGUFSpecificConfig {
331 prompt_chunksize: args.prompt_chunksize,
332 topology: Topology::from_option_path(topology)?,
333 },
334 args.no_kv_cache,
335 args.jinja_explicit,
336 )
337 .build(),
338 ModelSelected::XLoraGGUF {
339 tok_model_id,
340 quantized_model_id,
341 quantized_filename,
342 xlora_model_id,
343 order,
344 tgt_non_granular_index,
345 topology,
346 ..
347 } => GGUFLoaderBuilder::new(
348 args.chat_template,
349 tok_model_id,
350 quantized_model_id,
351 quantized_filename
352 .split(GGUF_MULTI_FILE_DELIMITER)
353 .map(ToOwned::to_owned)
354 .collect::<Vec<_>>(),
355 GGUFSpecificConfig {
356 prompt_chunksize: args.prompt_chunksize,
357 topology: Topology::from_option_path(topology)?,
358 },
359 args.no_kv_cache,
360 args.jinja_explicit,
361 )
362 .with_xlora(
363 xlora_model_id,
364 serde_json::from_reader(
365 File::open(order.clone())
366 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
367 )?,
368 args.no_kv_cache,
369 tgt_non_granular_index,
370 )
371 .build(),
372 ModelSelected::LoraGGUF {
373 tok_model_id,
374 quantized_model_id,
375 quantized_filename,
376 adapters_model_id,
377 order,
378 topology,
379 ..
380 } => GGUFLoaderBuilder::new(
381 args.chat_template,
382 tok_model_id,
383 quantized_model_id,
384 quantized_filename
385 .split(GGUF_MULTI_FILE_DELIMITER)
386 .map(ToOwned::to_owned)
387 .collect::<Vec<_>>(),
388 GGUFSpecificConfig {
389 prompt_chunksize: args.prompt_chunksize,
390 topology: Topology::from_option_path(topology)?,
391 },
392 args.no_kv_cache,
393 args.jinja_explicit,
394 )
395 .with_lora(
396 adapters_model_id,
397 serde_json::from_reader(
398 File::open(order.clone())
399 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
400 )?,
401 )
402 .build(),
403 ModelSelected::GGML {
404 tok_model_id,
405 tokenizer_json,
406 quantized_model_id,
407 quantized_filename,
408 gqa,
409 topology,
410 ..
411 } => GGMLLoaderBuilder::new(
412 GGMLSpecificConfig {
413 gqa,
414 prompt_chunksize: args.prompt_chunksize,
415 topology: Topology::from_option_path(topology)?,
416 },
417 args.chat_template,
418 tokenizer_json,
419 Some(tok_model_id),
420 quantized_model_id,
421 quantized_filename,
422 args.no_kv_cache,
423 args.jinja_explicit,
424 )
425 .build(),
426 ModelSelected::XLoraGGML {
427 tok_model_id,
428 tokenizer_json,
429 quantized_model_id,
430 quantized_filename,
431 xlora_model_id,
432 order,
433 tgt_non_granular_index,
434 gqa,
435 topology,
436 ..
437 } => GGMLLoaderBuilder::new(
438 GGMLSpecificConfig {
439 gqa,
440 prompt_chunksize: args.prompt_chunksize,
441 topology: Topology::from_option_path(topology)?,
442 },
443 args.chat_template,
444 tokenizer_json,
445 tok_model_id,
446 quantized_model_id,
447 quantized_filename,
448 args.no_kv_cache,
449 args.jinja_explicit,
450 )
451 .with_xlora(
452 xlora_model_id,
453 serde_json::from_reader(
454 File::open(order.clone())
455 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
456 )?,
457 args.no_kv_cache,
458 tgt_non_granular_index,
459 )
460 .build(),
461 ModelSelected::LoraGGML {
462 tok_model_id,
463 tokenizer_json,
464 quantized_model_id,
465 quantized_filename,
466 adapters_model_id,
467 order,
468 gqa,
469 topology,
470 ..
471 } => GGMLLoaderBuilder::new(
472 GGMLSpecificConfig {
473 gqa,
474 prompt_chunksize: args.prompt_chunksize,
475 topology: Topology::from_option_path(topology)?,
476 },
477 args.chat_template,
478 tokenizer_json,
479 tok_model_id,
480 quantized_model_id,
481 quantized_filename,
482 args.no_kv_cache,
483 args.jinja_explicit,
484 )
485 .with_lora(
486 adapters_model_id,
487 serde_json::from_reader(
488 File::open(order.clone())
489 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
490 )?,
491 )
492 .build(),
493 ModelSelected::VisionPlain {
494 model_id,
495 tokenizer_json,
496 arch,
497 dtype: _,
498 topology,
499 write_uqff,
500 from_uqff,
501 max_edge,
502 calibration_file,
503 max_seq_len: _,
504 max_batch_size: _,
505 max_num_images: _,
506 max_image_length: _,
507 hf_cache_path,
508 imatrix,
509 } => VisionLoaderBuilder::new(
510 VisionSpecificConfig {
511 use_flash_attn,
512 prompt_chunksize: args.prompt_chunksize,
513 topology: Topology::from_option_path(topology)?,
514 write_uqff,
515 from_uqff,
516 max_edge,
517 calibration_file,
518 imatrix,
519 hf_cache_path,
520 },
521 args.chat_template,
522 tokenizer_json,
523 Some(model_id),
524 args.jinja_explicit,
525 )
526 .build(arch),
527 ModelSelected::DiffusionPlain {
528 model_id,
529 arch,
530 dtype: _,
531 } => {
532 DiffusionLoaderBuilder::new(DiffusionSpecificConfig { use_flash_attn }, Some(model_id))
533 .build(arch)
534 }
535 };
536 Ok(loader)
537}