1mod diffusion_loaders;
2mod normal_loaders;
3mod vision_loaders;
4
5use std::{
6 fmt::{self, Debug, Display},
7 path::PathBuf,
8 str::FromStr,
9 sync::Arc,
10};
11
12use anyhow::{Context, Result};
13use as_any::AsAny;
14use candle_core::{DType, Device};
15use itertools::Itertools;
16use mistralrs_quant::IsqType;
17use tokio::sync::Mutex;
18
19pub use normal_loaders::{
20 AutoLoader, DeepSeekV2Loader, DeepSeekV3Loader, Gemma2Loader, GemmaLoader, LlamaLoader,
21 MistralLoader, MixtralLoader, NormalLoaderType, NormalLoadingMetadata, NormalModel,
22 NormalModelLoader, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader, Starcoder2Loader,
23};
24
25use tracing::{info, warn};
26pub use vision_loaders::{
27 Gemma3Loader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, MiniCpmOLoader,
28 Mistral3Loader, Phi3VLoader, Phi4MMLoader, Qwen2VLLoader, Qwen2_5VLLoader, VLlamaLoader,
29 VisionLoaderType, VisionModel, VisionModelLoader,
30};
31
32pub use diffusion_loaders::{
33 DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, DiffusionModelPaths,
34 DiffusionModelPathsInner, FluxLoader,
35};
36
37use crate::{
38 paged_attention::{
39 calculate_cache_config, ModelConfigLike, DEFAULT_PAGED_ATTENTION_BLOCK_SIZE,
40 },
41 utils::debug::DeviceRepr,
42 DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, MemoryUsage, PagedAttentionConfig,
43 TryIntoDType,
44};
45
46use super::{paths::AdapterPaths, Pipeline};
47
48pub trait ModelPaths: AsAny + Debug + Send + Sync {
51 fn get_weight_filenames(&self) -> &[PathBuf];
53
54 fn get_config_filename(&self) -> &PathBuf;
58
59 fn get_tokenizer_filename(&self) -> &PathBuf;
63
64 fn get_template_filename(&self) -> &Option<PathBuf>;
68
69 fn get_gen_conf_filename(&self) -> Option<&PathBuf>;
71
72 fn get_preprocessor_config(&self) -> &Option<PathBuf>;
74
75 fn get_processor_config(&self) -> &Option<PathBuf>;
77
78 fn get_chat_template_explicit(&self) -> &Option<PathBuf>;
80
81 fn get_adapter_paths(&self) -> &AdapterPaths;
83}
84
85#[derive(Clone, Debug)]
86pub struct LocalModelPaths<P: Debug> {
88 pub tokenizer_filename: P,
89 pub config_filename: P,
90 pub template_filename: Option<P>,
91 pub filenames: Vec<P>,
92 pub adapter_paths: AdapterPaths,
93 pub gen_conf: Option<P>,
94 pub preprocessor_config: Option<P>,
95 pub processor_config: Option<P>,
96 pub chat_template_json_filename: Option<P>,
97}
98
99impl<P: Debug> LocalModelPaths<P> {
100 #[allow(clippy::too_many_arguments)]
101 pub fn new(
102 tokenizer_filename: P,
103 config_filename: P,
104 template_filename: P,
105 filenames: Vec<P>,
106 adapter_paths: AdapterPaths,
107 gen_conf: Option<P>,
108 preprocessor_config: Option<P>,
109 processor_config: Option<P>,
110 chat_template_json_filename: Option<P>,
111 ) -> Self {
112 Self {
113 tokenizer_filename,
114 config_filename,
115 template_filename: Some(template_filename),
116 filenames,
117 adapter_paths,
118 gen_conf,
119 preprocessor_config,
120 processor_config,
121 chat_template_json_filename,
122 }
123 }
124}
125
126impl ModelPaths for LocalModelPaths<PathBuf> {
127 fn get_config_filename(&self) -> &PathBuf {
128 &self.config_filename
129 }
130 fn get_tokenizer_filename(&self) -> &PathBuf {
131 &self.tokenizer_filename
132 }
133 fn get_weight_filenames(&self) -> &[PathBuf] {
134 &self.filenames
135 }
136 fn get_template_filename(&self) -> &Option<PathBuf> {
137 &self.template_filename
138 }
139 fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
140 self.gen_conf.as_ref()
141 }
142 fn get_preprocessor_config(&self) -> &Option<PathBuf> {
143 &self.preprocessor_config
144 }
145 fn get_processor_config(&self) -> &Option<PathBuf> {
146 &self.processor_config
147 }
148 fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
149 &self.chat_template_json_filename
150 }
151 fn get_adapter_paths(&self) -> &AdapterPaths {
152 &self.adapter_paths
153 }
154}
155
/// Where an access token should be taken from.
///
/// Textual forms (accepted by `FromStr`, produced by `Display`):
/// `literal:<value>`, `env[:<VAR>]` (defaults to `HUGGING_FACE_HUB_TOKEN`),
/// `path:<file>`, `cache`, and `none`.
#[derive(Clone, Debug)]
pub enum TokenSource {
    /// Token value given inline (`literal:<value>`).
    Literal(String),
    /// Name of the environment variable to read (`env:<VAR>`).
    EnvVar(String),
    /// File path to read the token from (`path:<file>`).
    Path(String),
    /// `cache` — presumably the locally cached Hub token; confirm in the consumer.
    CacheToken,
    /// `none` — no token.
    None,
}

impl FromStr for TokenSource {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Only the first ':' separates the kind from its value, so values may
        // themselves contain ':'.
        let (kind, value) = match s.split_once(':') {
            Some((kind, value)) => (kind, Some(value)),
            None => (s, None),
        };
        // For kinds whose value is mandatory: the value, or the given error message.
        let required = |missing: &str| {
            value
                .map(str::to_string)
                .ok_or_else(|| missing.to_string())
        };

        match kind {
            "literal" => required("Expected a value for 'literal'").map(TokenSource::Literal),
            "env" => Ok(TokenSource::EnvVar(
                value.unwrap_or("HUGGING_FACE_HUB_TOKEN").to_string(),
            )),
            "path" => required("Expected a value for 'path'").map(TokenSource::Path),
            "cache" => Ok(TokenSource::CacheToken),
            "none" => Ok(TokenSource::None),
            _ => Err("Invalid token source format".to_string()),
        }
    }
}

impl fmt::Display for TokenSource {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Literal(token) => write!(f, "literal:{}", token),
            Self::EnvVar(var) => write!(f, "env:{}", var),
            Self::Path(path) => write!(f, "path:{}", path),
            Self::CacheToken => f.write_str("cache"),
            Self::None => f.write_str("none"),
        }
    }
}
204
/// What kind of model is being loaded: plain, GGUF-quantized, adapter-augmented, or a
/// composite (speculative target + draft, or AnyMoE) that wraps other `ModelKind`s.
///
/// The `Display` text of each variant is given by its `#[strum(to_string = ...)]`
/// attribute; `{field}` placeholders there are filled from the variant's fields.
#[derive(Clone, Default, derive_more::From, strum::Display)]
pub enum ModelKind {
    /// No quantization, no adapters. This is the `Default`.
    #[default]
    #[strum(to_string = "normal (no adapters)")]
    Normal,

    /// Quantized weights (format in `quant`), no adapters.
    #[strum(to_string = "gguf quantized from {quant} (no adapters)")]
    GgufQuantized { quant: QuantizationKind },

    /// Unquantized model with an adapter of kind `adapter`.
    #[strum(to_string = "{adapter}")]
    Adapter { adapter: AdapterKind },

    /// Quantized model (format in `quant`) with an adapter of kind `adapter`.
    #[strum(to_string = "{adapter}, gguf quantized from {quant}")]
    GgufAdapter {
        adapter: AdapterKind,
        quant: QuantizationKind,
    },

    /// Speculative setup pairing a `target` model kind with a `draft` model kind.
    #[strum(to_string = "speculative: target: `{target}`, draft: `{draft}`")]
    Speculative {
        target: Box<ModelKind>,
        draft: Box<ModelKind>,
    },

    /// AnyMoE wrapper around a `target` model kind.
    #[strum(to_string = "anymoe: target: `{target}`")]
    AnyMoe { target: Box<ModelKind> },
}
233
/// Quantization format of a model's weights.
///
/// `Display` renders the variant name in kebab-case (e.g. `Gguf` -> "gguf").
///
/// NOTE: this derives `strum::EnumMessage`, and `PrettyName::pretty_name` returns
/// `get_documentation()` (a variant's `///` doc comment) when one exists. The variant
/// comments below are deliberately plain `//` comments so they do not alter the pretty
/// names, which currently fall back to the `Display` form.
#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
pub enum QuantizationKind {
    // GGML format.
    Ggml,
    // GGUF format.
    Gguf,
    // GPTQ format.
    Gptq,
}
244
/// Kind of adapter applied on top of a base model.
///
/// `Display` renders the variant name in kebab-case (e.g. `XLora` -> "x-lora").
///
/// NOTE: this derives `strum::EnumMessage`, and `PrettyName::pretty_name` returns a
/// variant's `///` doc comment when present. The variant comments below are plain `//`
/// comments so the pretty names keep falling back to the `Display` form.
#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
pub enum AdapterKind {
    // LoRA adapter.
    Lora,
    // X-LoRA adapter.
    XLora,
}
253
254pub trait PrettyName: strum::EnumMessage + ToString {
256 fn pretty_name(&self) -> String {
257 match self.get_documentation() {
258 Some(s) => s.to_string(),
259 None => self.to_string(),
262 }
263 }
264}
265
266impl PrettyName for AdapterKind {}
267impl PrettyName for QuantizationKind {}
268
269impl ModelKind {
270 pub fn is_quantized(&self) -> bool {
272 self.quantized_kind().iter().any(|q| q.is_some())
273 }
274
275 pub fn is_quantized_and(&self, mut f: impl FnMut(QuantizationKind) -> bool) -> bool {
276 self.quantized_kind().iter().any(|q| q.is_some_and(&mut f))
277 }
278
279 pub fn quantized_kind(&self) -> Vec<Option<QuantizationKind>> {
280 use ModelKind::*;
281
282 match self {
283 Normal | Adapter { .. } => vec![None],
284 GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)],
285 Speculative { target, draft } => {
286 let t = *target.clone();
287 let d = *draft.clone();
288
289 [t.quantized_kind(), d.quantized_kind()].concat()
290 }
291 AnyMoe { target } => target.quantized_kind(),
292 }
293 }
294
295 pub fn is_adapted(&self) -> bool {
297 self.adapted_kind().iter().any(|a| a.is_some())
298 }
299
300 pub fn is_adapted_and(&self, mut f: impl FnMut(AdapterKind) -> bool) -> bool {
301 self.adapted_kind().iter().any(|a| a.is_some_and(&mut f))
302 }
303
304 pub fn adapted_kind(&self) -> Vec<Option<AdapterKind>> {
305 use ModelKind::*;
306
307 match self {
308 Normal | GgufQuantized { .. } => vec![None],
309 Adapter { adapter } | GgufAdapter { adapter, .. } => vec![Some(*adapter)],
310 Speculative { target, draft } => {
311 let t = *target.clone();
312 let d = *draft.clone();
313
314 [t.adapted_kind(), d.adapted_kind()].concat()
315 }
316 AnyMoe { target } => target.adapted_kind(),
317 }
318 }
319}
320
// Converts a byte count to whole mebibytes (integer division by 1024 * 1024).
macro_rules! b_to_mb {
    ($bytes:expr) => {{
        $bytes / (1024 * 1024)
    }};
}
326
/// Workload limits used by automatic device mapping to size activations and the KV
/// cache, for text-only or vision models.
#[derive(Debug, Clone)]
pub enum AutoDeviceMapParams {
    Text {
        max_seq_len: usize,
        max_batch_size: usize,
    },
    Vision {
        max_seq_len: usize,
        max_batch_size: usize,
        max_image_shape: (usize, usize),
        max_num_images: usize,
    },
}

impl AutoDeviceMapParams {
    pub const DEFAULT_MAX_SEQ_LEN: usize = 4 * 1024;
    pub const DEFAULT_MAX_BATCH_SIZE: usize = 1;
    pub const DEFAULT_MAX_NUM_IMAGES: usize = 1;
    pub const DEFAULT_MAX_IMAGE_LENGTH: usize = 1024;

    /// Maximum sequence length, regardless of variant.
    pub fn max_seq_len(&self) -> usize {
        let (Self::Text { max_seq_len, .. } | Self::Vision { max_seq_len, .. }) = self;
        *max_seq_len
    }

    /// Text parameters built from the `DEFAULT_*` constants.
    pub fn default_text() -> Self {
        Self::Text {
            max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
            max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
        }
    }

    /// Vision parameters built from the `DEFAULT_*` constants, with a square maximum
    /// image of `DEFAULT_MAX_IMAGE_LENGTH` on each side.
    pub fn default_vision() -> Self {
        let side = Self::DEFAULT_MAX_IMAGE_LENGTH;
        Self::Vision {
            max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
            max_image_shape: (side, side),
            max_num_images: Self::DEFAULT_MAX_NUM_IMAGES,
            max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
        }
    }
}

impl Display for AutoDeviceMapParams {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Text {
                max_batch_size,
                max_seq_len,
            } => write!(
                f,
                "text[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}]"
            ),
            Self::Vision {
                max_num_images,
                max_image_shape,
                max_batch_size,
                max_seq_len,
            } => write!(
                f,
                "vision[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}, max_image_shape: {max_image_shape:?}, max_num_images: {max_num_images}]"
            ),
        }
    }
}
397
/// Sub-models that automatic device mapping does not split across devices; they are
/// reported in the mapping log as loaded on a single device.
#[derive(Debug, Clone)]
pub(crate) enum NonMappedSubModel {
    Vision,
}

impl Display for NonMappedSubModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Vision => "vision",
        };
        f.write_str(label)
    }
}
410
411fn calculate_key_block_shape(
412 model_config: &dyn ModelConfigLike,
413 dtype: DType,
414 block_size: usize,
415) -> (usize, usize, usize, usize) {
416 let element_size = dtype.size_in_bytes();
417 let x = 16 / element_size;
418 (
419 model_config.num_kv_heads(),
420 model_config.k_head_dim() / x,
421 block_size,
422 x,
423 )
424}
425
426fn calculate_value_block_shape(
427 model_config: &dyn ModelConfigLike,
428 block_size: usize,
429) -> (usize, usize, usize) {
430 (
431 model_config.num_kv_heads(),
432 model_config.v_head_dim(),
433 block_size,
434 )
435}
436
437pub trait DeviceMappedModelLoader {
438 fn non_mapped_max_act_size_elems(
441 &self,
442 config: &str,
443 params: &AutoDeviceMapParams,
444 ) -> Result<usize>;
445 fn mapped_max_act_size_elems(
447 &self,
448 config: &str,
449 params: &AutoDeviceMapParams,
450 prompt_chunksize: usize,
451 ) -> Result<usize>;
452 fn non_mapped_size_in_bytes(
454 &self,
455 config: &str,
456 dtype: DType,
457 weight_pack_factor: usize,
458 ) -> Result<usize>;
459 fn layer_sizes_in_bytes(
461 &self,
462 config: &str,
463 dtype: DType,
464 weight_pack_factor: usize,
465 ) -> Result<Vec<usize>>;
466 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
467 None
468 }
469 fn num_layers(&self, config: &str) -> Result<usize>;
470 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>>;
471
472 #[allow(clippy::too_many_arguments)]
473 fn get_device_layers(
474 &self,
475 config: &str,
476 num_layers: usize,
477 mut layer_sizes_in_bytes: Vec<usize>,
478 non_mapped_size_in_bytes: usize,
479 total_model_size_in_bytes: usize,
480 devices: &[Device],
481 dtype: DType,
482 params: &AutoDeviceMapParams,
483 prompt_chunksize: usize,
484 paged_attn_config: Option<&PagedAttentionConfig>,
485 ) -> Result<DeviceMapMetadata> {
486 let mapped_max_act_size_in_bytes =
487 self.mapped_max_act_size_elems(config, params, prompt_chunksize)?
488 * dtype.size_in_bytes();
489 let non_mapped_max_act_size_in_bytes =
490 self.non_mapped_max_act_size_elems(config, params)? * dtype.size_in_bytes();
491
492 let mut remaining_to_map = total_model_size_in_bytes;
493
494 let max_seq_len = match params {
495 AutoDeviceMapParams::Text { max_seq_len, .. }
496 | AutoDeviceMapParams::Vision { max_seq_len, .. } => *max_seq_len,
497 };
498 let max_batch_size = match params {
499 AutoDeviceMapParams::Text { max_batch_size, .. }
500 | AutoDeviceMapParams::Vision { max_batch_size, .. } => *max_batch_size,
501 };
502
503 let model_cfg = self.model_config(config)?;
504 let kv_cache_size_elems = match paged_attn_config {
505 Some(paged_attn_config) => {
506 let cache_config = calculate_cache_config(
507 paged_attn_config.mem_gpu,
508 paged_attn_config.mem_cpu,
509 Some(
510 paged_attn_config
511 .block_size
512 .unwrap_or(DEFAULT_PAGED_ATTENTION_BLOCK_SIZE),
513 ),
514 dtype,
515 &*model_cfg,
516 &devices[0],
517 &devices.iter().map(|x| Some(x.clone())).collect::<Vec<_>>(),
518 true,
519 )?;
520
521 let key_block_shape =
522 calculate_key_block_shape(&*model_cfg, dtype, cache_config.block_size);
523 let key_block_size = cache_config.num_gpu_blocks
524 * key_block_shape.0
525 * key_block_shape.1
526 * key_block_shape.2
527 * key_block_shape.3;
528
529 let value_block_shape = calculate_value_block_shape(
530 &*self.model_config(config)?,
531 cache_config.block_size,
532 );
533 let value_block_size = cache_config.num_gpu_blocks
534 * value_block_shape.0
535 * value_block_shape.1
536 * value_block_shape.2;
537
538 key_block_size + value_block_size
539 }
540 None => {
541 let key_block_shape = [
543 max_batch_size,
544 model_cfg.num_kv_heads(),
545 max_seq_len,
546 model_cfg.k_head_dim(),
547 ];
548 let value_block_shape = [
549 max_batch_size,
550 model_cfg.num_kv_heads(),
551 max_seq_len,
552 model_cfg.v_head_dim(),
553 ];
554
555 key_block_shape.into_iter().product::<usize>()
556 + value_block_shape.iter().product::<usize>()
557 }
558 };
559 let kv_cache_size_in_bytes = kv_cache_size_elems * dtype.size_in_bytes();
560
561 let mut per_layer_avail = Vec::new();
562 for dev in [devices, &[Device::Cpu]].concat() {
564 let avail = MemoryUsage.get_memory_available(&dev)?;
565 per_layer_avail.push((avail, dev));
566 }
567 per_layer_avail.reverse();
569
570 layer_sizes_in_bytes.reverse();
572
573 let mut device_layers = Vec::new();
574
575 info!("Using automatic device mapping parameters: {params}.");
576 if let Some(sub_models) = self.non_mapped_sub_models() {
577 let (_, last) = per_layer_avail.last().unwrap();
578 info!(
579 "The following sub-models will not be device mapped and will be loaded on {}: {}",
580 last.device_pretty_repr(),
581 sub_models.iter().map(|x| x.to_string()).join(", ")
582 );
583 }
584
585 let mut current_ordinal = 0;
586 let mut current_layer = 0;
587 let per_layer_avail_cpy = per_layer_avail.clone();
588 let mut mapping_includes_cpu = false;
589 while remaining_to_map > 0 && !per_layer_avail.is_empty() {
590 let (device_capacity, device) = per_layer_avail
591 .pop()
592 .context("No more devices to map to. The model does not fit on this system.")?;
593 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
595 let device_capacity = (device_capacity as f64 * 0.90) as usize;
596
597 #[allow(clippy::if_same_then_else)]
604 let layers_on_device = if current_ordinal == 0
605 && device_capacity
606 >= remaining_to_map
607 + non_mapped_max_act_size_in_bytes.max(mapped_max_act_size_in_bytes)
608 + non_mapped_size_in_bytes
609 + kv_cache_size_in_bytes * (num_layers - current_layer)
610 {
611 remaining_to_map = 0;
612
613 num_layers - current_layer
614 } else if current_ordinal != 0
615 && device_capacity
616 >= remaining_to_map
617 + mapped_max_act_size_in_bytes
618 + kv_cache_size_in_bytes * (num_layers - current_layer)
619 {
620 remaining_to_map = 0;
621
622 num_layers - current_layer
623 } else {
624 let mut used_capacity = mapped_max_act_size_in_bytes;
626 let mut used_capacity_no_act = 0;
627 let mut layers_on_device = 0;
628
629 if current_ordinal == 0 {
631 used_capacity = used_capacity.max(non_mapped_max_act_size_in_bytes);
633 used_capacity += non_mapped_size_in_bytes;
634 used_capacity_no_act += non_mapped_size_in_bytes;
635 }
636
637 while let Some(&last) = layer_sizes_in_bytes.last() {
638 let delta = last + kv_cache_size_in_bytes;
639 if used_capacity + delta > device_capacity {
640 break;
641 }
642 let _ = layer_sizes_in_bytes.pop().unwrap();
643 used_capacity += delta;
644 used_capacity_no_act += delta;
645 layers_on_device += 1;
646 }
647
648 if layers_on_device > 0 {
651 remaining_to_map = remaining_to_map.saturating_sub(used_capacity_no_act);
652 } else {
653 warn!(
654 "Device {} can fit 0 layers. Consider reducing auto map params from current: {params} (ex. reducing max seq len or max num images)",
655 device.device_pretty_repr(),
656 );
657 current_ordinal += 1;
658 continue;
659 }
660 layers_on_device
661 };
662
663 if !device.is_cpu() {
665 device_layers.push(DeviceLayerMapMetadata {
666 ordinal: current_ordinal,
667 layers: layers_on_device,
668 });
669 current_ordinal += 1;
670 } else {
671 mapping_includes_cpu = true;
672 }
673
674 current_layer += layers_on_device;
675 }
676 if remaining_to_map > 0 {
677 anyhow::bail!(
678 "This model does not fit on the devices {:?}, and exceeds total capacity by {}MB. Auto device mapping params: {params}",
679 per_layer_avail_cpy
680 .iter()
681 .rev()
682 .map(|(avail, dev)| format!(
683 "{} (avail: {}MB)",
684 dev.device_pretty_repr(),
685 avail / (1024 * 1024),
686 ))
687 .collect::<Vec<_>>(),
688 b_to_mb!(remaining_to_map)
689 );
690 }
691
692 if paged_attn_config.is_some_and(|_| mapping_includes_cpu) {
695 return self.get_device_layers(
696 config,
697 num_layers,
698 layer_sizes_in_bytes,
699 non_mapped_size_in_bytes,
700 total_model_size_in_bytes,
701 devices,
702 dtype,
703 params,
704 prompt_chunksize,
705 None,
706 );
707 }
708
709 Ok(DeviceMapMetadata::from_num_device_layers(device_layers))
710 }
711}
712
713pub trait Loader: Send + Sync {
734 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
738 fn load_model_from_hf(
739 &self,
740 revision: Option<String>,
741 token_source: TokenSource,
742 dtype: &dyn TryIntoDType,
743 device: &Device,
744 silent: bool,
745 mapper: DeviceMapSetting,
746 in_situ_quant: Option<IsqType>,
747 paged_attn_config: Option<PagedAttentionConfig>,
748 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;
749
750 #[allow(
753 clippy::type_complexity,
754 clippy::too_many_arguments,
755 clippy::borrowed_box
756 )]
757 fn load_model_from_path(
758 &self,
759 paths: &Box<dyn ModelPaths>,
760 dtype: &dyn TryIntoDType,
761 device: &Device,
762 silent: bool,
763 mapper: DeviceMapSetting,
764 in_situ_quant: Option<IsqType>,
765 paged_attn_config: Option<PagedAttentionConfig>,
766 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;
767
768 fn get_id(&self) -> String;
769 fn get_kind(&self) -> ModelKind;
770}