//! mistralrs_core/pipeline/loaders/mod.rs

1mod diffusion_loaders;
2mod normal_loaders;
3mod vision_loaders;
4
5use std::{
6    fmt::{self, Debug, Display},
7    path::PathBuf,
8    str::FromStr,
9    sync::Arc,
10};
11
12use anyhow::{Context, Result};
13use as_any::AsAny;
14use candle_core::{DType, Device};
15use itertools::Itertools;
16use mistralrs_quant::IsqType;
17use tokio::sync::Mutex;
18
19pub use normal_loaders::{
20    AutoLoader, DeepSeekV2Loader, DeepSeekV3Loader, Gemma2Loader, GemmaLoader, LlamaLoader,
21    MistralLoader, MixtralLoader, NormalLoaderType, NormalLoadingMetadata, NormalModel,
22    NormalModelLoader, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader, Qwen3Loader,
23    Qwen3MoELoader, Starcoder2Loader,
24};
25
26use tracing::{info, warn};
27pub use vision_loaders::{
28    Gemma3Loader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, MiniCpmOLoader,
29    Mistral3Loader, Phi3VLoader, Phi4MMLoader, Qwen2VLLoader, Qwen2_5VLLoader, VLlama4Loader,
30    VLlamaLoader, VisionLoaderType, VisionModel, VisionModelLoader,
31};
32
33pub use diffusion_loaders::{
34    DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, DiffusionModelPaths,
35    DiffusionModelPathsInner, FluxLoader,
36};
37
38use crate::{
39    paged_attention::{
40        calculate_cache_config, ModelConfigLike, DEFAULT_PAGED_ATTENTION_BLOCK_SIZE,
41    },
42    utils::debug::DeviceRepr,
43    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, MemoryUsage, PagedAttentionConfig,
44    TryIntoDType,
45};
46
47use super::{paths::AdapterPaths, Pipeline};
48
/// `ModelPaths` abstracts the mechanism to get all necessary files for running a model. For
/// example `LocalModelPaths` implements `ModelPaths` when all files are in the local file system.
pub trait ModelPaths: AsAny + Debug + Send + Sync {
    /// Model weights files (multiple files supported).
    fn get_weight_filenames(&self) -> &[PathBuf];

    /// Retrieve the [`PretrainedConfig`] file.
    ///
    /// [`PretrainedConfig`]: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/configuration#transformers.PretrainedConfig
    fn get_config_filename(&self) -> &PathBuf;

    /// A serialised [`tokenizers.Tokenizer`] HuggingFace object.
    ///
    /// [`tokenizers.Tokenizer`]: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/tokenizer
    fn get_tokenizer_filename(&self) -> &PathBuf;

    /// File where the content is expected to deserialize to [`ChatTemplate`].
    /// `None` when no template file is available.
    ///
    /// [`ChatTemplate`]: crate::ChatTemplate
    fn get_template_filename(&self) -> &Option<PathBuf>;

    /// Filepath for general model configuration. `None` when absent.
    fn get_gen_conf_filename(&self) -> Option<&PathBuf>;

    /// Get the preprocessor config (for the vision models). This is used to pre process images.
    fn get_preprocessor_config(&self) -> &Option<PathBuf>;

    /// Get the processor config (for the vision models). This is primarily used for the chat template.
    fn get_processor_config(&self) -> &Option<PathBuf>;

    /// Get the explicit chat template (a standalone JSON file), if provided.
    fn get_chat_template_explicit(&self) -> &Option<PathBuf>;

    /// Get adapter paths.
    fn get_adapter_paths(&self) -> &AdapterPaths;
}
85
#[derive(Clone, Debug)]
/// All local paths and metadata necessary to load a model.
pub struct LocalModelPaths<P: Debug> {
    /// Serialized tokenizer file (see [`ModelPaths::get_tokenizer_filename`]).
    pub tokenizer_filename: P,
    /// Model configuration file (see [`ModelPaths::get_config_filename`]).
    pub config_filename: P,
    /// Chat template file, if any (see [`ModelPaths::get_template_filename`]).
    pub template_filename: Option<P>,
    /// Model weight files (see [`ModelPaths::get_weight_filenames`]).
    pub filenames: Vec<P>,
    /// Adapter (LoRA / X-LoRA) paths (see [`ModelPaths::get_adapter_paths`]).
    pub adapter_paths: AdapterPaths,
    /// General model configuration file, if present.
    pub gen_conf: Option<P>,
    /// Preprocessor config (vision models), if present.
    pub preprocessor_config: Option<P>,
    /// Processor config (vision models), if present.
    pub processor_config: Option<P>,
    /// Explicit chat template JSON file, if present.
    pub chat_template_json_filename: Option<P>,
}
99
100impl<P: Debug> LocalModelPaths<P> {
101    #[allow(clippy::too_many_arguments)]
102    pub fn new(
103        tokenizer_filename: P,
104        config_filename: P,
105        template_filename: P,
106        filenames: Vec<P>,
107        adapter_paths: AdapterPaths,
108        gen_conf: Option<P>,
109        preprocessor_config: Option<P>,
110        processor_config: Option<P>,
111        chat_template_json_filename: Option<P>,
112    ) -> Self {
113        Self {
114            tokenizer_filename,
115            config_filename,
116            template_filename: Some(template_filename),
117            filenames,
118            adapter_paths,
119            gen_conf,
120            preprocessor_config,
121            processor_config,
122            chat_template_json_filename,
123        }
124    }
125}
126
127impl ModelPaths for LocalModelPaths<PathBuf> {
128    fn get_config_filename(&self) -> &PathBuf {
129        &self.config_filename
130    }
131    fn get_tokenizer_filename(&self) -> &PathBuf {
132        &self.tokenizer_filename
133    }
134    fn get_weight_filenames(&self) -> &[PathBuf] {
135        &self.filenames
136    }
137    fn get_template_filename(&self) -> &Option<PathBuf> {
138        &self.template_filename
139    }
140    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
141        self.gen_conf.as_ref()
142    }
143    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
144        &self.preprocessor_config
145    }
146    fn get_processor_config(&self) -> &Option<PathBuf> {
147        &self.processor_config
148    }
149    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
150        &self.chat_template_json_filename
151    }
152    fn get_adapter_paths(&self) -> &AdapterPaths {
153        &self.adapter_paths
154    }
155}
156
#[derive(Debug, Clone)]
/// The source of the HF token.
pub enum TokenSource {
    Literal(String),
    EnvVar(String),
    Path(String),
    CacheToken,
    None,
}

impl FromStr for TokenSource {
    type Err = String;

    /// Parse a token source of the form `kind` or `kind:value`, where `kind`
    /// is one of `literal`, `env`, `path`, `cache`, or `none`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Split into the kind tag and an optional value after the first ':'.
        let (kind, value) = match s.split_once(':') {
            Some((kind, value)) => (kind, Some(value)),
            None => (s, None),
        };

        match kind {
            "literal" => value
                .map(|v| TokenSource::Literal(v.to_string()))
                .ok_or_else(|| "Expected a value for 'literal'".to_string()),
            // `env` defaults to the standard HF hub token env var.
            "env" => Ok(TokenSource::EnvVar(
                value.unwrap_or("HUGGING_FACE_HUB_TOKEN").to_string(),
            )),
            "path" => value
                .map(|v| TokenSource::Path(v.to_string()))
                .ok_or_else(|| "Expected a value for 'path'".to_string()),
            "cache" => Ok(TokenSource::CacheToken),
            "none" => Ok(TokenSource::None),
            _ => Err("Invalid token source format".to_string()),
        }
    }
}

impl fmt::Display for TokenSource {
    /// Inverse of `FromStr`: render as `kind` or `kind:value`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TokenSource::Literal(value) => write!(f, "literal:{value}"),
            TokenSource::EnvVar(value) => write!(f, "env:{value}"),
            TokenSource::Path(value) => write!(f, "path:{value}"),
            TokenSource::CacheToken => write!(f, "cache"),
            TokenSource::None => write!(f, "none"),
        }
    }
}
205
/// The kind of model to build.
///
/// The `strum` `to_string` patterns below drive the derived `Display` impl,
/// so these strings are user-facing.
#[derive(Clone, Default, derive_more::From, strum::Display)]
pub enum ModelKind {
    /// Plain model: no quantization, no adapters.
    #[default]
    #[strum(to_string = "normal (no adapters)")]
    Normal,

    /// GGUF-quantized model without adapters.
    #[strum(to_string = "gguf quantized from {quant} (no adapters)")]
    GgufQuantized { quant: QuantizationKind },

    /// Unquantized model with an adapter (LoRA / X-LoRA).
    #[strum(to_string = "{adapter}")]
    Adapter { adapter: AdapterKind },

    /// GGUF-quantized model with an adapter.
    #[strum(to_string = "{adapter}, gguf quantized from {quant}")]
    GgufAdapter {
        adapter: AdapterKind,
        quant: QuantizationKind,
    },

    /// Speculative decoding: a target model plus a draft model.
    #[strum(to_string = "speculative: target: `{target}`, draft: `{draft}`")]
    Speculative {
        target: Box<ModelKind>,
        draft: Box<ModelKind>,
    },

    /// AnyMoE wrapper around a target model.
    #[strum(to_string = "anymoe: target: `{target}`")]
    AnyMoe { target: Box<ModelKind> },
}
234
// NOTE: the `///` variant docs below are surfaced at runtime through
// `strum::EnumMessage::get_documentation` (see `PrettyName`), so they are
// user-facing text — do not edit them casually.
#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
pub enum QuantizationKind {
    /// GGML
    Ggml,
    /// GGUF
    Gguf,
    /// GPTQ
    Gptq,
}
245
// NOTE: as with `QuantizationKind`, the `///` variant docs are surfaced at
// runtime via `strum::EnumMessage::get_documentation` (see `PrettyName`).
#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
pub enum AdapterKind {
    /// LoRA
    Lora,
    /// X-LoRA
    XLora,
}
254
255// For the proper name as formatted via doc comment for a variant
256pub trait PrettyName: strum::EnumMessage + ToString {
257    fn pretty_name(&self) -> String {
258        match self.get_documentation() {
259            Some(s) => s.to_string(),
260            // Instead of panic via expect(),
261            // fallback to default kebab-case:
262            None => self.to_string(),
263        }
264    }
265}
266
267impl PrettyName for AdapterKind {}
268impl PrettyName for QuantizationKind {}
269
270impl ModelKind {
271    // Quantized helpers:
272    pub fn is_quantized(&self) -> bool {
273        self.quantized_kind().iter().any(|q| q.is_some())
274    }
275
276    pub fn is_quantized_and(&self, mut f: impl FnMut(QuantizationKind) -> bool) -> bool {
277        self.quantized_kind().iter().any(|q| q.is_some_and(&mut f))
278    }
279
280    pub fn quantized_kind(&self) -> Vec<Option<QuantizationKind>> {
281        use ModelKind::*;
282
283        match self {
284            Normal | Adapter { .. } => vec![None],
285            GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)],
286            Speculative { target, draft } => {
287                let t = *target.clone();
288                let d = *draft.clone();
289
290                [t.quantized_kind(), d.quantized_kind()].concat()
291            }
292            AnyMoe { target } => target.quantized_kind(),
293        }
294    }
295
296    // Adapter helpers:
297    pub fn is_adapted(&self) -> bool {
298        self.adapted_kind().iter().any(|a| a.is_some())
299    }
300
301    pub fn is_adapted_and(&self, mut f: impl FnMut(AdapterKind) -> bool) -> bool {
302        self.adapted_kind().iter().any(|a| a.is_some_and(&mut f))
303    }
304
305    pub fn adapted_kind(&self) -> Vec<Option<AdapterKind>> {
306        use ModelKind::*;
307
308        match self {
309            Normal | GgufQuantized { .. } => vec![None],
310            Adapter { adapter } | GgufAdapter { adapter, .. } => vec![Some(*adapter)],
311            Speculative { target, draft } => {
312                let t = *target.clone();
313                let d = *draft.clone();
314
315                [t.adapted_kind(), d.adapted_kind()].concat()
316            }
317            AnyMoe { target } => target.adapted_kind(),
318        }
319    }
320}
321
// Convert a byte count to mebibytes (integer division).
macro_rules! b_to_mb {
    ($x:expr) => {
        $x / (1024 * 1024)
    };
}
327
#[derive(Debug, Clone)]
pub enum AutoDeviceMapParams {
    Text {
        max_seq_len: usize,
        max_batch_size: usize,
    },
    Vision {
        max_seq_len: usize,
        max_batch_size: usize,
        max_image_shape: (usize, usize),
        max_num_images: usize,
    },
}

impl AutoDeviceMapParams {
    pub const DEFAULT_MAX_SEQ_LEN: usize = 4 * 1024;
    pub const DEFAULT_MAX_BATCH_SIZE: usize = 1;
    pub const DEFAULT_MAX_NUM_IMAGES: usize = 1;
    pub const DEFAULT_MAX_IMAGE_LENGTH: usize = 1024;

    /// The maximum sequence length, common to both variants.
    pub fn max_seq_len(&self) -> usize {
        // Both variants carry `max_seq_len`, so the or-pattern is irrefutable.
        let (Self::Text { max_seq_len, .. } | Self::Vision { max_seq_len, .. }) = self;
        *max_seq_len
    }

    /// Text-model defaults: 4k max sequence length, batch size 1.
    pub fn default_text() -> Self {
        Self::Text {
            max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
            max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
        }
    }

    /// Vision-model defaults: 4k max sequence length, batch size 1,
    /// one 1024x1024 image.
    pub fn default_vision() -> Self {
        Self::Vision {
            max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
            max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
            max_num_images: Self::DEFAULT_MAX_NUM_IMAGES,
            max_image_shape: (
                Self::DEFAULT_MAX_IMAGE_LENGTH,
                Self::DEFAULT_MAX_IMAGE_LENGTH,
            ),
        }
    }
}

impl Display for AutoDeviceMapParams {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Text {
                max_seq_len,
                max_batch_size,
            } => write!(
                f,
                "text[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}]"
            ),
            Self::Vision {
                max_seq_len,
                max_batch_size,
                max_image_shape,
                max_num_images,
            } => write!(
                f,
                "vision[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}, max_image_shape: {max_image_shape:?}, max_num_images: {max_num_images}]"
            ),
        }
    }
}
398
#[derive(Clone, Debug)]
pub(crate) enum NonMappedSubModel {
    Vision,
}

impl Display for NonMappedSubModel {
    /// Lower-case name used in device-mapping log messages.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Vision => "vision",
        };
        f.write_str(name)
    }
}
411
412fn calculate_key_block_shape(
413    model_config: &dyn ModelConfigLike,
414    dtype: DType,
415    block_size: usize,
416) -> (usize, usize, usize, usize) {
417    let element_size = dtype.size_in_bytes();
418    let x = 16 / element_size;
419    (
420        model_config.num_kv_heads(),
421        model_config.k_head_dim() / x,
422        block_size,
423        x,
424    )
425}
426
427fn calculate_value_block_shape(
428    model_config: &dyn ModelConfigLike,
429    block_size: usize,
430) -> (usize, usize, usize) {
431    (
432        model_config.num_kv_heads(),
433        model_config.v_head_dim(),
434        block_size,
435    )
436}
437
pub trait DeviceMappedModelLoader {
    /// Maximum activation size of non-mapped parts of this model.
    /// Useful for the vision models which may prefer to keep the vision components on the GPU.
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize>;
    /// Maximum activation size of mapped parts of the model
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize>;
    /// weight_pack_factor only applies to quantized weights.
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize>;
    /// weight_pack_factor only applies to quantized weights.
    /// Returns one entry per mappable layer.
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>>;
    /// Sub-models that are never device mapped; `None` (the default) means
    /// everything is mappable.
    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        None
    }
    /// Number of mappable layers for the given serialized model config.
    fn num_layers(&self, config: &str) -> Result<usize>;
    /// Parse the serialized config into a [`ModelConfigLike`].
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>>;

    /// Compute an automatic device mapping: greedily assign `num_layers`
    /// layers to `devices` in order (with the CPU appended as a fallback),
    /// accounting per device for layer weights, per-layer KV cache, and the
    /// peak activation sizes reported by the other trait methods. Errors if
    /// the model does not fit even with the CPU included.
    #[allow(clippy::too_many_arguments)]
    fn get_device_layers(
        &self,
        config: &str,
        num_layers: usize,
        mut layer_sizes_in_bytes: Vec<usize>,
        non_mapped_size_in_bytes: usize,
        total_model_size_in_bytes: usize,
        devices: &[Device],
        dtype: DType,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
        paged_attn_config: Option<&PagedAttentionConfig>,
    ) -> Result<DeviceMapMetadata> {
        // Convert peak activation sizes from elements to bytes.
        let mapped_max_act_size_in_bytes =
            self.mapped_max_act_size_elems(config, params, prompt_chunksize)?
                * dtype.size_in_bytes();
        let non_mapped_max_act_size_in_bytes =
            self.non_mapped_max_act_size_elems(config, params)? * dtype.size_in_bytes();

        let mut remaining_to_map = total_model_size_in_bytes;

        let max_seq_len = match params {
            AutoDeviceMapParams::Text { max_seq_len, .. }
            | AutoDeviceMapParams::Vision { max_seq_len, .. } => *max_seq_len,
        };
        let max_batch_size = match params {
            AutoDeviceMapParams::Text { max_batch_size, .. }
            | AutoDeviceMapParams::Vision { max_batch_size, .. } => *max_batch_size,
        };

        // Estimate the per-layer KV cache footprint in elements: either the
        // paged-attention block pools or a dense [batch, kv_heads, seq, head_dim]
        // pair of K/V tensors.
        let model_cfg = self.model_config(config)?;
        let kv_cache_size_elems = match paged_attn_config {
            Some(paged_attn_config) => {
                let cache_config = calculate_cache_config(
                    paged_attn_config.mem_gpu,
                    paged_attn_config.mem_cpu,
                    Some(
                        paged_attn_config
                            .block_size
                            .unwrap_or(DEFAULT_PAGED_ATTENTION_BLOCK_SIZE),
                    ),
                    dtype,
                    &*model_cfg,
                    &devices[0],
                    &devices.iter().map(|x| Some(x.clone())).collect::<Vec<_>>(),
                    true,
                )?;

                let key_block_shape =
                    calculate_key_block_shape(&*model_cfg, dtype, cache_config.block_size);
                let key_block_size = cache_config.num_gpu_blocks
                    * key_block_shape.0
                    * key_block_shape.1
                    * key_block_shape.2
                    * key_block_shape.3;

                // NOTE(review): this re-parses the config via `self.model_config`;
                // it could reuse `model_cfg` from above.
                let value_block_shape = calculate_value_block_shape(
                    &*self.model_config(config)?,
                    cache_config.block_size,
                );
                let value_block_size = cache_config.num_gpu_blocks
                    * value_block_shape.0
                    * value_block_shape.1
                    * value_block_shape.2;

                key_block_size + value_block_size
            }
            None => {
                // Non-paged KV cache
                let key_block_shape = [
                    max_batch_size,
                    model_cfg.num_kv_heads(),
                    max_seq_len,
                    model_cfg.k_head_dim(),
                ];
                let value_block_shape = [
                    max_batch_size,
                    model_cfg.num_kv_heads(),
                    max_seq_len,
                    model_cfg.v_head_dim(),
                ];

                key_block_shape.into_iter().product::<usize>()
                    + value_block_shape.iter().product::<usize>()
            }
        };
        let kv_cache_size_in_bytes = kv_cache_size_elems * dtype.size_in_bytes();

        // Query available memory on every candidate device.
        let mut per_layer_avail = Vec::new();
        // Always add the CPU as fallback
        for dev in [devices, &[Device::Cpu]].concat() {
            let avail = MemoryUsage.get_memory_available(&dev)?;
            per_layer_avail.push((avail, dev));
        }
        // Reverse so we don't use the cpu first!
        per_layer_avail.reverse();

        // Reverse layer sizes so we can pop
        layer_sizes_in_bytes.reverse();

        let mut device_layers = Vec::new();

        info!("Using automatic device mapping parameters: {params}.");
        if let Some(sub_models) = self.non_mapped_sub_models() {
            let (_, last) = per_layer_avail.last().unwrap();
            info!(
                "The following sub-models will not be device mapped and will be loaded on {}: {}",
                last.device_pretty_repr(),
                sub_models.iter().map(|x| x.to_string()).join(", ")
            );
        }

        // Greedily fill devices in order until all layer weights are placed.
        let mut current_ordinal = 0;
        let mut current_layer = 0;
        let per_layer_avail_cpy = per_layer_avail.clone();
        let mut mapping_includes_cpu = false;
        while remaining_to_map > 0 && !per_layer_avail.is_empty() {
            let (device_capacity, device) = per_layer_avail
                .pop()
                .context("No more devices to map to. The model does not fit on this system.")?;
            // Allow usage of at most 90% of the device's available memory.
            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
            let device_capacity = (device_capacity as f64 * 0.90) as usize;

            // Algorithm is to check the following:
            // 1) (no mapping) if *everything* fits on the first dev (non mapped and mapped)
            // 2) if the mapped activations plus remaining fits on the nth device
            // 3) common case, iteratively find the optimal amount of layers to put on the nth device
            //   - if this is the first dev: must hold the non-mapped act and non-mapped model
            //   - otherwise, must hold the mapped act
            #[allow(clippy::if_same_then_else)]
            let layers_on_device = if current_ordinal == 0
                && device_capacity
                    >= remaining_to_map
                        + non_mapped_max_act_size_in_bytes.max(mapped_max_act_size_in_bytes)
                        + non_mapped_size_in_bytes
                        + kv_cache_size_in_bytes * (num_layers - current_layer)
            {
                remaining_to_map = 0;

                num_layers - current_layer
            } else if current_ordinal != 0
                && device_capacity
                    >= remaining_to_map
                        + mapped_max_act_size_in_bytes
                        + kv_cache_size_in_bytes * (num_layers - current_layer)
            {
                remaining_to_map = 0;

                num_layers - current_layer
            } else {
                // All devices need to account for the max mapped act size
                let mut used_capacity = mapped_max_act_size_in_bytes;
                let mut used_capacity_no_act = 0;
                let mut layers_on_device = 0;

                // Device w/ ordinal 0 carries the non-mapped things
                if current_ordinal == 0 {
                    // Ensure the activations are properly handled
                    used_capacity = used_capacity.max(non_mapped_max_act_size_in_bytes);
                    used_capacity += non_mapped_size_in_bytes;
                    used_capacity_no_act += non_mapped_size_in_bytes;
                }

                // Pop layers (weights + KV cache) while they still fit.
                while let Some(&last) = layer_sizes_in_bytes.last() {
                    let delta = last + kv_cache_size_in_bytes;
                    if used_capacity + delta > device_capacity {
                        break;
                    }
                    let _ = layer_sizes_in_bytes.pop().unwrap();
                    used_capacity += delta;
                    used_capacity_no_act += delta;
                    layers_on_device += 1;
                }

                // Do not reduce amount to map if this device can't fit any
                // If the device cannot fit any layers, warn the user.
                if layers_on_device > 0 {
                    remaining_to_map = remaining_to_map.saturating_sub(used_capacity_no_act);
                } else {
                    warn!(
                        "Device {} can fit 0 layers. Consider reducing auto map params from current: {params} (ex. reducing max seq len or max num images)",
                        device.device_pretty_repr(),
                    );
                    current_ordinal += 1;
                    continue;
                }
                layers_on_device
            };

            // CPU mappings are automatically handled by the traditional device mapper, we can just leave them out here.
            if !device.is_cpu() {
                device_layers.push(DeviceLayerMapMetadata {
                    ordinal: current_ordinal,
                    layers: layers_on_device,
                });
                current_ordinal += 1;
            } else {
                mapping_includes_cpu = true;
            }

            current_layer += layers_on_device;
        }
        if remaining_to_map > 0 {
            anyhow::bail!(
                "This model does not fit on the devices {:?}, and exceeds total capacity by {}MB. Auto device mapping params: {params}",
                per_layer_avail_cpy
                    .iter()
                    .rev()
                    .map(|(avail, dev)| format!(
                        "{} (avail: {}MB)",
                        dev.device_pretty_repr(),
                        avail / (1024 * 1024),
                    ))
                    .collect::<Vec<_>>(),
                b_to_mb!(remaining_to_map)
            );
        }

        // TODO: PagedAttention is not supported with CPU for now.
        // Recalculate without PagedAttention metadata.
        if paged_attn_config.is_some_and(|_| mapping_includes_cpu) {
            return self.get_device_layers(
                config,
                num_layers,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                devices,
                dtype,
                params,
                prompt_chunksize,
                None,
            );
        }

        Ok(DeviceMapMetadata::from_num_device_layers(device_layers))
    }
}
713
/// The `Loader` trait abstracts the loading process. The primary entrypoint is the
/// `load_model` method.
///
/// # Example
/// ```no_run
/// use mistralrs_core::{Loader, TokenSource, DeviceMapSetting, AutoDeviceMapParams, ModelDType};
/// use candle_core::Device;
///
/// let loader: Box<dyn Loader> = todo!();
/// let pipeline = loader.load_model_from_hf(
///     None,
///     TokenSource::CacheToken,
///     &ModelDType::Auto,
///     &Device::cuda_if_available(0).unwrap(),
///     false,
///     DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
///     None,
///     None,
/// ).unwrap();
/// ```
pub trait Loader: Send + Sync {
    /// If `revision` is None, then it defaults to `main`.
    /// If `dtype` is None, then it defaults to the model default (usually BF16).
    /// If model is not found on HF, will attempt to resolve locally.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    /// Load a model from the specified paths.
    /// Also initializes `DEBUG`.
    #[allow(
        clippy::type_complexity,
        clippy::too_many_arguments,
        clippy::borrowed_box
    )]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    /// An identifier for the model this loader loads.
    fn get_id(&self) -> String;
    /// The [`ModelKind`] this loader produces.
    fn get_kind(&self) -> ModelKind;
}