1mod diffusion_loaders;
2mod normal_loaders;
3mod vision_loaders;
4
5use std::{
6 fmt::{self, Debug, Display},
7 path::PathBuf,
8 str::FromStr,
9 sync::Arc,
10};
11
12use anyhow::{Context, Result};
13use as_any::AsAny;
14use candle_core::{DType, Device};
15use itertools::Itertools;
16use mistralrs_quant::IsqType;
17use tokio::sync::Mutex;
18
19pub use normal_loaders::{
20 AutoLoader, DeepSeekV2Loader, DeepSeekV3Loader, Gemma2Loader, GemmaLoader, LlamaLoader,
21 MistralLoader, MixtralLoader, NormalLoaderType, NormalLoadingMetadata, NormalModel,
22 NormalModelLoader, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader, Qwen3Loader,
23 Qwen3MoELoader, Starcoder2Loader,
24};
25
26use tracing::{info, warn};
27pub use vision_loaders::{
28 Gemma3Loader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, MiniCpmOLoader,
29 Mistral3Loader, Phi3VLoader, Phi4MMLoader, Qwen2VLLoader, Qwen2_5VLLoader, VLlama4Loader,
30 VLlamaLoader, VisionLoaderType, VisionModel, VisionModelLoader,
31};
32
33pub use diffusion_loaders::{
34 DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, DiffusionModelPaths,
35 DiffusionModelPathsInner, FluxLoader,
36};
37
38use crate::{
39 paged_attention::{
40 calculate_cache_config, ModelConfigLike, DEFAULT_PAGED_ATTENTION_BLOCK_SIZE,
41 },
42 utils::debug::DeviceRepr,
43 DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, MemoryUsage, PagedAttentionConfig,
44 TryIntoDType,
45};
46
47use super::{paths::AdapterPaths, Pipeline};
48
/// Locations of all files a model needs at load time: weights, tokenizer,
/// configuration, template/processor files, and adapter files.
pub trait ModelPaths: AsAny + Debug + Send + Sync {
    /// Paths to the model weight files (one or more shards).
    fn get_weight_filenames(&self) -> &[PathBuf];
    /// Path to the model configuration file.
    fn get_config_filename(&self) -> &PathBuf;
    /// Path to the tokenizer file.
    fn get_tokenizer_filename(&self) -> &PathBuf;
    /// Optional path to the template file.
    fn get_template_filename(&self) -> &Option<PathBuf>;
    /// Optional path to the generation config file.
    fn get_gen_conf_filename(&self) -> Option<&PathBuf>;
    /// Optional path to the preprocessor config file.
    fn get_preprocessor_config(&self) -> &Option<PathBuf>;
    /// Optional path to the processor config file.
    fn get_processor_config(&self) -> &Option<PathBuf>;
    /// Optional path to an explicitly supplied chat template JSON file.
    fn get_chat_template_explicit(&self) -> &Option<PathBuf>;
    /// Paths of any adapter files associated with the model.
    fn get_adapter_paths(&self) -> &AdapterPaths;
}
85
/// Concrete [`ModelPaths`] backed by local filesystem paths.
#[derive(Clone, Debug)]
pub struct LocalModelPaths<P: Debug> {
    /// Tokenizer file path.
    pub tokenizer_filename: P,
    /// Model configuration file path.
    pub config_filename: P,
    /// Template file path, if present.
    pub template_filename: Option<P>,
    /// Weight file paths (one or more shards).
    pub filenames: Vec<P>,
    /// Adapter file paths.
    pub adapter_paths: AdapterPaths,
    /// Generation config file path, if present.
    pub gen_conf: Option<P>,
    /// Preprocessor config file path, if present.
    pub preprocessor_config: Option<P>,
    /// Processor config file path, if present.
    pub processor_config: Option<P>,
    /// Explicit chat template JSON file path, if present.
    pub chat_template_json_filename: Option<P>,
}
99
impl<P: Debug> LocalModelPaths<P> {
    /// Construct a [`LocalModelPaths`] from its components.
    ///
    /// Note: `template_filename` is required by this constructor and stored
    /// wrapped in `Some(...)`; the remaining auxiliary files stay optional.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        tokenizer_filename: P,
        config_filename: P,
        template_filename: P,
        filenames: Vec<P>,
        adapter_paths: AdapterPaths,
        gen_conf: Option<P>,
        preprocessor_config: Option<P>,
        processor_config: Option<P>,
        chat_template_json_filename: Option<P>,
    ) -> Self {
        Self {
            tokenizer_filename,
            config_filename,
            template_filename: Some(template_filename),
            filenames,
            adapter_paths,
            gen_conf,
            preprocessor_config,
            processor_config,
            chat_template_json_filename,
        }
    }
}
126
127impl ModelPaths for LocalModelPaths<PathBuf> {
128 fn get_config_filename(&self) -> &PathBuf {
129 &self.config_filename
130 }
131 fn get_tokenizer_filename(&self) -> &PathBuf {
132 &self.tokenizer_filename
133 }
134 fn get_weight_filenames(&self) -> &[PathBuf] {
135 &self.filenames
136 }
137 fn get_template_filename(&self) -> &Option<PathBuf> {
138 &self.template_filename
139 }
140 fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
141 self.gen_conf.as_ref()
142 }
143 fn get_preprocessor_config(&self) -> &Option<PathBuf> {
144 &self.preprocessor_config
145 }
146 fn get_processor_config(&self) -> &Option<PathBuf> {
147 &self.processor_config
148 }
149 fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
150 &self.chat_template_json_filename
151 }
152 fn get_adapter_paths(&self) -> &AdapterPaths {
153 &self.adapter_paths
154 }
155}
156
/// Where to obtain an access token (e.g. for the Hugging Face Hub).
#[derive(Debug, Clone)]
pub enum TokenSource {
    /// The token value itself, given inline.
    Literal(String),
    /// Read the token from the named environment variable.
    EnvVar(String),
    /// Read the token from a file at this path.
    Path(String),
    /// Use the cached token.
    CacheToken,
    /// Do not use any token.
    None,
}
166
167impl FromStr for TokenSource {
168 type Err = String;
169
170 fn from_str(s: &str) -> Result<Self, Self::Err> {
171 let parts: Vec<&str> = s.splitn(2, ':').collect();
172 match parts[0] {
173 "literal" => parts
174 .get(1)
175 .map(|&value| TokenSource::Literal(value.to_string()))
176 .ok_or_else(|| "Expected a value for 'literal'".to_string()),
177 "env" => Ok(TokenSource::EnvVar(
178 parts
179 .get(1)
180 .unwrap_or(&"HUGGING_FACE_HUB_TOKEN")
181 .to_string(),
182 )),
183 "path" => parts
184 .get(1)
185 .map(|&value| TokenSource::Path(value.to_string()))
186 .ok_or_else(|| "Expected a value for 'path'".to_string()),
187 "cache" => Ok(TokenSource::CacheToken),
188 "none" => Ok(TokenSource::None),
189 _ => Err("Invalid token source format".to_string()),
190 }
191 }
192}
193
194impl fmt::Display for TokenSource {
195 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196 match self {
197 TokenSource::Literal(value) => write!(f, "literal:{}", value),
198 TokenSource::EnvVar(value) => write!(f, "env:{}", value),
199 TokenSource::Path(value) => write!(f, "path:{}", value),
200 TokenSource::CacheToken => write!(f, "cache"),
201 TokenSource::None => write!(f, "none"),
202 }
203 }
204}
205
/// The kind of model: plain, GGUF-quantized, adapter-augmented, or a
/// composite (speculative / AnyMoE) wrapping nested kinds.
#[derive(Clone, Default, derive_more::From, strum::Display)]
pub enum ModelKind {
    /// Plain model: no quantization, no adapters.
    #[default]
    #[strum(to_string = "normal (no adapters)")]
    Normal,

    /// GGUF-quantized model without adapters.
    #[strum(to_string = "gguf quantized from {quant} (no adapters)")]
    GgufQuantized { quant: QuantizationKind },

    /// Model with an adapter (LoRA / X-LoRA) applied.
    #[strum(to_string = "{adapter}")]
    Adapter { adapter: AdapterKind },

    /// GGUF-quantized model with an adapter applied.
    #[strum(to_string = "{adapter}, gguf quantized from {quant}")]
    GgufAdapter {
        adapter: AdapterKind,
        quant: QuantizationKind,
    },

    /// Speculative decoding pair: a target model plus a draft model.
    #[strum(to_string = "speculative: target: `{target}`, draft: `{draft}`")]
    Speculative {
        target: Box<ModelKind>,
        draft: Box<ModelKind>,
    },

    /// AnyMoE wrapper around a target model kind.
    #[strum(to_string = "anymoe: target: `{target}`")]
    AnyMoe { target: Box<ModelKind> },
}
234
/// Quantized-weight file formats recognized by the loaders.
//
// NOTE: variants deliberately use `//` comments rather than `///` doc
// comments — this enum derives `strum::EnumMessage`, and doc comments would
// become `get_documentation()` output, changing `PrettyName::pretty_name`.
#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
pub enum QuantizationKind {
    // GGML file format.
    Ggml,
    // GGUF file format.
    Gguf,
    // GPTQ quantization.
    Gptq,
}
245
/// Adapter families that can be applied to a base model.
//
// NOTE: `//` comments on variants, not `///` — this enum derives
// `strum::EnumMessage`, so doc comments would alter `get_documentation()`
// and thus `PrettyName::pretty_name` output.
#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
pub enum AdapterKind {
    // LoRA adapter.
    Lora,
    // X-LoRA adapter.
    XLora,
}
254
255pub trait PrettyName: strum::EnumMessage + ToString {
257 fn pretty_name(&self) -> String {
258 match self.get_documentation() {
259 Some(s) => s.to_string(),
260 None => self.to_string(),
263 }
264 }
265}
266
// Marker impls: both kinds pick up `pretty_name` from the default method.
impl PrettyName for AdapterKind {}
impl PrettyName for QuantizationKind {}
269
270impl ModelKind {
271 pub fn is_quantized(&self) -> bool {
273 self.quantized_kind().iter().any(|q| q.is_some())
274 }
275
276 pub fn is_quantized_and(&self, mut f: impl FnMut(QuantizationKind) -> bool) -> bool {
277 self.quantized_kind().iter().any(|q| q.is_some_and(&mut f))
278 }
279
280 pub fn quantized_kind(&self) -> Vec<Option<QuantizationKind>> {
281 use ModelKind::*;
282
283 match self {
284 Normal | Adapter { .. } => vec![None],
285 GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)],
286 Speculative { target, draft } => {
287 let t = *target.clone();
288 let d = *draft.clone();
289
290 [t.quantized_kind(), d.quantized_kind()].concat()
291 }
292 AnyMoe { target } => target.quantized_kind(),
293 }
294 }
295
296 pub fn is_adapted(&self) -> bool {
298 self.adapted_kind().iter().any(|a| a.is_some())
299 }
300
301 pub fn is_adapted_and(&self, mut f: impl FnMut(AdapterKind) -> bool) -> bool {
302 self.adapted_kind().iter().any(|a| a.is_some_and(&mut f))
303 }
304
305 pub fn adapted_kind(&self) -> Vec<Option<AdapterKind>> {
306 use ModelKind::*;
307
308 match self {
309 Normal | GgufQuantized { .. } => vec![None],
310 Adapter { adapter } | GgufAdapter { adapter, .. } => vec![Some(*adapter)],
311 Speculative { target, draft } => {
312 let t = *target.clone();
313 let d = *draft.clone();
314
315 [t.adapted_kind(), d.adapted_kind()].concat()
316 }
317 AnyMoe { target } => target.adapted_kind(),
318 }
319 }
320}
321
/// Convert a byte count to mebibytes (integer division).
macro_rules! b_to_mb {
    ($x:expr) => {
        $x / (1024 * 1024)
    };
}
327
/// Workload assumptions used to drive automatic device mapping.
#[derive(Debug, Clone)]
pub enum AutoDeviceMapParams {
    /// Parameters for text-only models.
    Text {
        // Maximum sequence length to budget for.
        max_seq_len: usize,
        // Maximum concurrent batch size to budget for.
        max_batch_size: usize,
    },
    /// Parameters for vision models (adds image assumptions).
    Vision {
        max_seq_len: usize,
        max_batch_size: usize,
        // Largest (height, width) image expected.
        max_image_shape: (usize, usize),
        // Maximum number of images per request.
        max_num_images: usize,
    },
}
341
342impl AutoDeviceMapParams {
343 pub fn max_seq_len(&self) -> usize {
344 match self {
345 Self::Text { max_seq_len, .. } | Self::Vision { max_seq_len, .. } => *max_seq_len,
346 }
347 }
348}
349
350impl Display for AutoDeviceMapParams {
351 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
352 match self {
353 Self::Text {
354 max_seq_len,
355 max_batch_size,
356 } => write!(
357 f,
358 "text[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}]"
359 ),
360 Self::Vision {
361 max_seq_len,
362 max_batch_size,
363 max_image_shape,
364 max_num_images
365 } => write!(
366 f,
367 "vision[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}, max_image_shape: {max_image_shape:?}, max_num_images: {max_num_images}]"
368 ),
369 }
370 }
371}
372
373impl AutoDeviceMapParams {
374 pub const DEFAULT_MAX_SEQ_LEN: usize = 4 * 1024;
375 pub const DEFAULT_MAX_BATCH_SIZE: usize = 1;
376 pub const DEFAULT_MAX_NUM_IMAGES: usize = 1;
377 pub const DEFAULT_MAX_IMAGE_LENGTH: usize = 1024;
378
379 pub fn default_text() -> Self {
380 Self::Text {
381 max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
382 max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
383 }
384 }
385
386 pub fn default_vision() -> Self {
387 Self::Vision {
388 max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
389 max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
390 max_num_images: Self::DEFAULT_MAX_NUM_IMAGES,
391 max_image_shape: (
392 Self::DEFAULT_MAX_IMAGE_LENGTH,
393 Self::DEFAULT_MAX_IMAGE_LENGTH,
394 ),
395 }
396 }
397}
398
/// Sub-models that are excluded from layer-wise device mapping and loaded
/// whole on a single device.
#[derive(Clone, Debug)]
pub(crate) enum NonMappedSubModel {
    /// A vision tower/component.
    Vision,
}
403
404impl Display for NonMappedSubModel {
405 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
406 match self {
407 Self::Vision => write!(f, "vision"),
408 }
409 }
410}
411
412fn calculate_key_block_shape(
413 model_config: &dyn ModelConfigLike,
414 dtype: DType,
415 block_size: usize,
416) -> (usize, usize, usize, usize) {
417 let element_size = dtype.size_in_bytes();
418 let x = 16 / element_size;
419 (
420 model_config.num_kv_heads(),
421 model_config.k_head_dim() / x,
422 block_size,
423 x,
424 )
425}
426
427fn calculate_value_block_shape(
428 model_config: &dyn ModelConfigLike,
429 block_size: usize,
430) -> (usize, usize, usize) {
431 (
432 model_config.num_kv_heads(),
433 model_config.v_head_dim(),
434 block_size,
435 )
436}
437
438pub trait DeviceMappedModelLoader {
439 fn non_mapped_max_act_size_elems(
442 &self,
443 config: &str,
444 params: &AutoDeviceMapParams,
445 ) -> Result<usize>;
446 fn mapped_max_act_size_elems(
448 &self,
449 config: &str,
450 params: &AutoDeviceMapParams,
451 prompt_chunksize: usize,
452 ) -> Result<usize>;
453 fn non_mapped_size_in_bytes(
455 &self,
456 config: &str,
457 dtype: DType,
458 weight_pack_factor: usize,
459 ) -> Result<usize>;
460 fn layer_sizes_in_bytes(
462 &self,
463 config: &str,
464 dtype: DType,
465 weight_pack_factor: usize,
466 ) -> Result<Vec<usize>>;
467 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
468 None
469 }
470 fn num_layers(&self, config: &str) -> Result<usize>;
471 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>>;
472
473 #[allow(clippy::too_many_arguments)]
474 fn get_device_layers(
475 &self,
476 config: &str,
477 num_layers: usize,
478 mut layer_sizes_in_bytes: Vec<usize>,
479 non_mapped_size_in_bytes: usize,
480 total_model_size_in_bytes: usize,
481 devices: &[Device],
482 dtype: DType,
483 params: &AutoDeviceMapParams,
484 prompt_chunksize: usize,
485 paged_attn_config: Option<&PagedAttentionConfig>,
486 ) -> Result<DeviceMapMetadata> {
487 let mapped_max_act_size_in_bytes =
488 self.mapped_max_act_size_elems(config, params, prompt_chunksize)?
489 * dtype.size_in_bytes();
490 let non_mapped_max_act_size_in_bytes =
491 self.non_mapped_max_act_size_elems(config, params)? * dtype.size_in_bytes();
492
493 let mut remaining_to_map = total_model_size_in_bytes;
494
495 let max_seq_len = match params {
496 AutoDeviceMapParams::Text { max_seq_len, .. }
497 | AutoDeviceMapParams::Vision { max_seq_len, .. } => *max_seq_len,
498 };
499 let max_batch_size = match params {
500 AutoDeviceMapParams::Text { max_batch_size, .. }
501 | AutoDeviceMapParams::Vision { max_batch_size, .. } => *max_batch_size,
502 };
503
504 let model_cfg = self.model_config(config)?;
505 let kv_cache_size_elems = match paged_attn_config {
506 Some(paged_attn_config) => {
507 let cache_config = calculate_cache_config(
508 paged_attn_config.mem_gpu,
509 paged_attn_config.mem_cpu,
510 Some(
511 paged_attn_config
512 .block_size
513 .unwrap_or(DEFAULT_PAGED_ATTENTION_BLOCK_SIZE),
514 ),
515 dtype,
516 &*model_cfg,
517 &devices[0],
518 &devices.iter().map(|x| Some(x.clone())).collect::<Vec<_>>(),
519 true,
520 )?;
521
522 let key_block_shape =
523 calculate_key_block_shape(&*model_cfg, dtype, cache_config.block_size);
524 let key_block_size = cache_config.num_gpu_blocks
525 * key_block_shape.0
526 * key_block_shape.1
527 * key_block_shape.2
528 * key_block_shape.3;
529
530 let value_block_shape = calculate_value_block_shape(
531 &*self.model_config(config)?,
532 cache_config.block_size,
533 );
534 let value_block_size = cache_config.num_gpu_blocks
535 * value_block_shape.0
536 * value_block_shape.1
537 * value_block_shape.2;
538
539 key_block_size + value_block_size
540 }
541 None => {
542 let key_block_shape = [
544 max_batch_size,
545 model_cfg.num_kv_heads(),
546 max_seq_len,
547 model_cfg.k_head_dim(),
548 ];
549 let value_block_shape = [
550 max_batch_size,
551 model_cfg.num_kv_heads(),
552 max_seq_len,
553 model_cfg.v_head_dim(),
554 ];
555
556 key_block_shape.into_iter().product::<usize>()
557 + value_block_shape.iter().product::<usize>()
558 }
559 };
560 let kv_cache_size_in_bytes = kv_cache_size_elems * dtype.size_in_bytes();
561
562 let mut per_layer_avail = Vec::new();
563 for dev in [devices, &[Device::Cpu]].concat() {
565 let avail = MemoryUsage.get_memory_available(&dev)?;
566 per_layer_avail.push((avail, dev));
567 }
568 per_layer_avail.reverse();
570
571 layer_sizes_in_bytes.reverse();
573
574 let mut device_layers = Vec::new();
575
576 info!("Using automatic device mapping parameters: {params}.");
577 if let Some(sub_models) = self.non_mapped_sub_models() {
578 let (_, last) = per_layer_avail.last().unwrap();
579 info!(
580 "The following sub-models will not be device mapped and will be loaded on {}: {}",
581 last.device_pretty_repr(),
582 sub_models.iter().map(|x| x.to_string()).join(", ")
583 );
584 }
585
586 let mut current_ordinal = 0;
587 let mut current_layer = 0;
588 let per_layer_avail_cpy = per_layer_avail.clone();
589 let mut mapping_includes_cpu = false;
590 while remaining_to_map > 0 && !per_layer_avail.is_empty() {
591 let (device_capacity, device) = per_layer_avail
592 .pop()
593 .context("No more devices to map to. The model does not fit on this system.")?;
594 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
596 let device_capacity = (device_capacity as f64 * 0.90) as usize;
597
598 #[allow(clippy::if_same_then_else)]
605 let layers_on_device = if current_ordinal == 0
606 && device_capacity
607 >= remaining_to_map
608 + non_mapped_max_act_size_in_bytes.max(mapped_max_act_size_in_bytes)
609 + non_mapped_size_in_bytes
610 + kv_cache_size_in_bytes * (num_layers - current_layer)
611 {
612 remaining_to_map = 0;
613
614 num_layers - current_layer
615 } else if current_ordinal != 0
616 && device_capacity
617 >= remaining_to_map
618 + mapped_max_act_size_in_bytes
619 + kv_cache_size_in_bytes * (num_layers - current_layer)
620 {
621 remaining_to_map = 0;
622
623 num_layers - current_layer
624 } else {
625 let mut used_capacity = mapped_max_act_size_in_bytes;
627 let mut used_capacity_no_act = 0;
628 let mut layers_on_device = 0;
629
630 if current_ordinal == 0 {
632 used_capacity = used_capacity.max(non_mapped_max_act_size_in_bytes);
634 used_capacity += non_mapped_size_in_bytes;
635 used_capacity_no_act += non_mapped_size_in_bytes;
636 }
637
638 while let Some(&last) = layer_sizes_in_bytes.last() {
639 let delta = last + kv_cache_size_in_bytes;
640 if used_capacity + delta > device_capacity {
641 break;
642 }
643 let _ = layer_sizes_in_bytes.pop().unwrap();
644 used_capacity += delta;
645 used_capacity_no_act += delta;
646 layers_on_device += 1;
647 }
648
649 if layers_on_device > 0 {
652 remaining_to_map = remaining_to_map.saturating_sub(used_capacity_no_act);
653 } else {
654 warn!(
655 "Device {} can fit 0 layers. Consider reducing auto map params from current: {params} (ex. reducing max seq len or max num images)",
656 device.device_pretty_repr(),
657 );
658 current_ordinal += 1;
659 continue;
660 }
661 layers_on_device
662 };
663
664 if !device.is_cpu() {
666 device_layers.push(DeviceLayerMapMetadata {
667 ordinal: current_ordinal,
668 layers: layers_on_device,
669 });
670 current_ordinal += 1;
671 } else {
672 mapping_includes_cpu = true;
673 }
674
675 current_layer += layers_on_device;
676 }
677 if remaining_to_map > 0 {
678 anyhow::bail!(
679 "This model does not fit on the devices {:?}, and exceeds total capacity by {}MB. Auto device mapping params: {params}",
680 per_layer_avail_cpy
681 .iter()
682 .rev()
683 .map(|(avail, dev)| format!(
684 "{} (avail: {}MB)",
685 dev.device_pretty_repr(),
686 avail / (1024 * 1024),
687 ))
688 .collect::<Vec<_>>(),
689 b_to_mb!(remaining_to_map)
690 );
691 }
692
693 if paged_attn_config.is_some_and(|_| mapping_includes_cpu) {
696 return self.get_device_layers(
697 config,
698 num_layers,
699 layer_sizes_in_bytes,
700 non_mapped_size_in_bytes,
701 total_model_size_in_bytes,
702 devices,
703 dtype,
704 params,
705 prompt_chunksize,
706 None,
707 );
708 }
709
710 Ok(DeviceMapMetadata::from_num_device_layers(device_layers))
711 }
712}
713
/// Builds a ready-to-use [`Pipeline`] either from the Hugging Face Hub or
/// from already-resolved local paths.
pub trait Loader: Send + Sync {
    /// Load the model from Hugging Face, resolving file locations itself.
    ///
    /// - `revision`: optional model revision (NOTE(review): presumably falls
    ///   back to the default branch when `None` — confirm in implementors).
    /// - `token_source`: where to obtain an access token (see [`TokenSource`]).
    /// - `dtype`: strategy for resolving the load dtype.
    /// - `silent`: suppress progress output when `true`.
    /// - `mapper`: device-mapping configuration.
    /// - `in_situ_quant`: optional in-situ quantization to apply.
    /// - `paged_attn_config`: enables paged attention when `Some`.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    /// Load the model from pre-resolved [`ModelPaths`]; parameters mirror
    /// [`Self::load_model_from_hf`].
    #[allow(
        clippy::type_complexity,
        clippy::too_many_arguments,
        clippy::borrowed_box
    )]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    /// Identifier of the model being loaded.
    fn get_id(&self) -> String;
    /// The [`ModelKind`] of the model being loaded.
    fn get_kind(&self) -> ModelKind;
}