mistralrs_core/pipeline/loaders/
mod.rs

mod diffusion_loaders;
mod normal_loaders;
mod vision_loaders;

use std::{
    fmt::{self, Debug, Display},
    path::PathBuf,
    str::FromStr,
    sync::Arc,
};

use anyhow::{Context, Result};
use as_any::AsAny;
use candle_core::{DType, Device};
use itertools::Itertools;
use mistralrs_quant::IsqType;
use tokio::sync::Mutex;
use tracing::{info, warn};

pub use normal_loaders::{
    AutoLoader, DeepSeekV2Loader, DeepSeekV3Loader, Gemma2Loader, GemmaLoader, LlamaLoader,
    MistralLoader, MixtralLoader, NormalLoaderType, NormalLoadingMetadata, NormalModel,
    NormalModelLoader, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader, Starcoder2Loader,
};

pub use vision_loaders::{
    Gemma3Loader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, MiniCpmOLoader,
    Mistral3Loader, Phi3VLoader, Phi4MMLoader, Qwen2VLLoader, Qwen2_5VLLoader, VLlamaLoader,
    VisionLoaderType, VisionModel, VisionModelLoader,
};

pub use diffusion_loaders::{
    DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, DiffusionModelPaths,
    DiffusionModelPathsInner, FluxLoader,
};

use crate::{
    paged_attention::{
        calculate_cache_config, ModelConfigLike, DEFAULT_PAGED_ATTENTION_BLOCK_SIZE,
    },
    utils::debug::DeviceRepr,
    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, MemoryUsage, PagedAttentionConfig,
    TryIntoDType,
};

use super::{paths::AdapterPaths, Pipeline};
/// `ModelPaths` abstracts the mechanism to get all necessary files for running a model. For
/// example, `LocalModelPaths` implements `ModelPaths` when all files are in the local file system.
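///
/// For instance, the total weight size can be computed generically over any implementation
/// (a minimal sketch; `total_weight_bytes` is an illustrative helper, not part of the API):
///
/// ```no_run
/// use mistralrs_core::ModelPaths;
///
/// fn total_weight_bytes(paths: &dyn ModelPaths) -> std::io::Result<u64> {
///     // Sum the on-disk sizes of all weight files.
///     paths
///         .get_weight_filenames()
///         .iter()
///         .map(|file| std::fs::metadata(file).map(|m| m.len()))
///         .sum()
/// }
/// ```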
pub trait ModelPaths: AsAny + Debug + Send + Sync {
    /// Model weight files (multiple files supported).
    fn get_weight_filenames(&self) -> &[PathBuf];

    /// Retrieve the [`PretrainedConfig`] file.
    ///
    /// [`PretrainedConfig`]: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/configuration#transformers.PretrainedConfig
    fn get_config_filename(&self) -> &PathBuf;

    /// A serialised [`tokenizers.Tokenizer`] HuggingFace object.
    ///
    /// [`tokenizers.Tokenizer`]: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/tokenizer
    fn get_tokenizer_filename(&self) -> &PathBuf;

    /// File whose contents are expected to deserialize to [`ChatTemplate`].
    ///
    /// [`ChatTemplate`]: crate::ChatTemplate
    fn get_template_filename(&self) -> &Option<PathBuf>;

    /// Filepath for the generation configuration, if present.
    fn get_gen_conf_filename(&self) -> Option<&PathBuf>;

    /// Get the preprocessor config (for the vision models). This is used to preprocess images.
    fn get_preprocessor_config(&self) -> &Option<PathBuf>;

    /// Get the processor config (for the vision models). This is primarily used for the chat template.
    fn get_processor_config(&self) -> &Option<PathBuf>;

    /// Get the explicit chat template.
    fn get_chat_template_explicit(&self) -> &Option<PathBuf>;

    /// Get adapter paths.
    fn get_adapter_paths(&self) -> &AdapterPaths;
}

#[derive(Clone, Debug)]
/// All local paths and metadata necessary to load a model.
pub struct LocalModelPaths<P: Debug> {
    pub tokenizer_filename: P,
    pub config_filename: P,
    pub template_filename: Option<P>,
    pub filenames: Vec<P>,
    pub adapter_paths: AdapterPaths,
    pub gen_conf: Option<P>,
    pub preprocessor_config: Option<P>,
    pub processor_config: Option<P>,
    pub chat_template_json_filename: Option<P>,
}

impl<P: Debug> LocalModelPaths<P> {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        tokenizer_filename: P,
        config_filename: P,
        template_filename: P,
        filenames: Vec<P>,
        adapter_paths: AdapterPaths,
        gen_conf: Option<P>,
        preprocessor_config: Option<P>,
        processor_config: Option<P>,
        chat_template_json_filename: Option<P>,
    ) -> Self {
        Self {
            tokenizer_filename,
            config_filename,
            template_filename: Some(template_filename),
            filenames,
            adapter_paths,
            gen_conf,
            preprocessor_config,
            processor_config,
            chat_template_json_filename,
        }
    }
}

impl ModelPaths for LocalModelPaths<PathBuf> {
    fn get_config_filename(&self) -> &PathBuf {
        &self.config_filename
    }
    fn get_tokenizer_filename(&self) -> &PathBuf {
        &self.tokenizer_filename
    }
    fn get_weight_filenames(&self) -> &[PathBuf] {
        &self.filenames
    }
    fn get_template_filename(&self) -> &Option<PathBuf> {
        &self.template_filename
    }
    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
        self.gen_conf.as_ref()
    }
    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
        &self.preprocessor_config
    }
    fn get_processor_config(&self) -> &Option<PathBuf> {
        &self.processor_config
    }
    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
        &self.chat_template_json_filename
    }
    fn get_adapter_paths(&self) -> &AdapterPaths {
        &self.adapter_paths
    }
}

#[derive(Debug, Clone)]
/// The source of the HF token.
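///
/// Parsed from strings of the form `literal:<token>`, `env:<var>` (the variable name
/// defaults to `HUGGING_FACE_HUB_TOKEN`), `path:<file>`, `cache`, or `none`.
/// A minimal round trip:
///
/// ```
/// use std::str::FromStr;
/// use mistralrs_core::TokenSource;
///
/// let source = TokenSource::from_str("env:HF_TOKEN").unwrap();
/// assert_eq!(source.to_string(), "env:HF_TOKEN");
/// assert!(matches!(source, TokenSource::EnvVar(_)));
/// ```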
pub enum TokenSource {
    Literal(String),
    EnvVar(String),
    Path(String),
    CacheToken,
    None,
}

impl FromStr for TokenSource {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let parts: Vec<&str> = s.splitn(2, ':').collect();
        match parts[0] {
            "literal" => parts
                .get(1)
                .map(|&value| TokenSource::Literal(value.to_string()))
                .ok_or_else(|| "Expected a value for 'literal'".to_string()),
            "env" => Ok(TokenSource::EnvVar(
                parts
                    .get(1)
                    .unwrap_or(&"HUGGING_FACE_HUB_TOKEN")
                    .to_string(),
            )),
            "path" => parts
                .get(1)
                .map(|&value| TokenSource::Path(value.to_string()))
                .ok_or_else(|| "Expected a value for 'path'".to_string()),
            "cache" => Ok(TokenSource::CacheToken),
            "none" => Ok(TokenSource::None),
            _ => Err("Invalid token source format".to_string()),
        }
    }
}

impl fmt::Display for TokenSource {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TokenSource::Literal(value) => write!(f, "literal:{}", value),
            TokenSource::EnvVar(value) => write!(f, "env:{}", value),
            TokenSource::Path(value) => write!(f, "path:{}", value),
            TokenSource::CacheToken => write!(f, "cache"),
            TokenSource::None => write!(f, "none"),
        }
    }
}

/// The kind of model to build.
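///
/// For example, a GGUF-quantized base model (an illustrative sketch; the crate-root
/// re-exports are assumed):
///
/// ```ignore
/// use mistralrs_core::{ModelKind, QuantizationKind};
///
/// let kind = ModelKind::GgufQuantized { quant: QuantizationKind::Gguf };
/// assert!(kind.is_quantized());
/// // Rendered via the `strum` display attributes below:
/// assert_eq!(kind.to_string(), "gguf quantized from gguf (no adapters)");
/// ```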
#[derive(Clone, Default, derive_more::From, strum::Display)]
pub enum ModelKind {
    #[default]
    #[strum(to_string = "normal (no adapters)")]
    Normal,

    #[strum(to_string = "gguf quantized from {quant} (no adapters)")]
    GgufQuantized { quant: QuantizationKind },

    #[strum(to_string = "{adapter}")]
    Adapter { adapter: AdapterKind },

    #[strum(to_string = "{adapter}, gguf quantized from {quant}")]
    GgufAdapter {
        adapter: AdapterKind,
        quant: QuantizationKind,
    },

    #[strum(to_string = "speculative: target: `{target}`, draft: `{draft}`")]
    Speculative {
        target: Box<ModelKind>,
        draft: Box<ModelKind>,
    },

    #[strum(to_string = "anymoe: target: `{target}`")]
    AnyMoe { target: Box<ModelKind> },
}

#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
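/// The quantization format of the model weights.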
pub enum QuantizationKind {
    /// GGML
    Ggml,
    /// GGUF
    Gguf,
    /// GPTQ
    Gptq,
}

#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
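/// The type of adapter applied to a base model.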
pub enum AdapterKind {
    /// LoRA
    Lora,
    /// X-LoRA
    XLora,
}

/// The proper display name of a variant, as given by its doc comment via
/// [`strum::EnumMessage`]; falls back to the default kebab-case `Display` form.
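///
/// An illustrative sketch (assuming these items are re-exported at the crate root):
///
/// ```ignore
/// use mistralrs_core::{AdapterKind, PrettyName};
///
/// assert_eq!(AdapterKind::XLora.pretty_name(), "X-LoRA"); // from the variant's doc comment
/// assert_eq!(AdapterKind::XLora.to_string(), "x-lora"); // kebab-case fallback
/// ```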
pub trait PrettyName: strum::EnumMessage + ToString {
    fn pretty_name(&self) -> String {
        match self.get_documentation() {
            Some(s) => s.to_string(),
            // Instead of panicking via expect(),
            // fall back to the default kebab-case form:
            None => self.to_string(),
        }
    }
}

impl PrettyName for AdapterKind {}
impl PrettyName for QuantizationKind {}

impl ModelKind {
    // Quantized helpers:
    pub fn is_quantized(&self) -> bool {
        self.quantized_kind().iter().any(|q| q.is_some())
    }

    pub fn is_quantized_and(&self, mut f: impl FnMut(QuantizationKind) -> bool) -> bool {
        self.quantized_kind().iter().any(|q| q.is_some_and(&mut f))
    }

    pub fn quantized_kind(&self) -> Vec<Option<QuantizationKind>> {
        use ModelKind::*;

        match self {
            Normal | Adapter { .. } => vec![None],
            GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)],
            Speculative { target, draft } => {
                let t = *target.clone();
                let d = *draft.clone();

                [t.quantized_kind(), d.quantized_kind()].concat()
            }
            AnyMoe { target } => target.quantized_kind(),
        }
    }

    // Adapter helpers:
    pub fn is_adapted(&self) -> bool {
        self.adapted_kind().iter().any(|a| a.is_some())
    }

    pub fn is_adapted_and(&self, mut f: impl FnMut(AdapterKind) -> bool) -> bool {
        self.adapted_kind().iter().any(|a| a.is_some_and(&mut f))
    }

    pub fn adapted_kind(&self) -> Vec<Option<AdapterKind>> {
        use ModelKind::*;

        match self {
            Normal | GgufQuantized { .. } => vec![None],
            Adapter { adapter } | GgufAdapter { adapter, .. } => vec![Some(*adapter)],
            Speculative { target, draft } => {
                let t = *target.clone();
                let d = *draft.clone();

                [t.adapted_kind(), d.adapted_kind()].concat()
            }
            AnyMoe { target } => target.adapted_kind(),
        }
    }
}

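// Convert a byte count to mebibytes (integer division).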
macro_rules! b_to_mb {
    ($x:expr) => {
        $x / (1024 * 1024)
    };
}

#[derive(Debug, Clone)]
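/// Worst-case request shapes used to drive automatic device mapping.
///
/// ```
/// use mistralrs_core::AutoDeviceMapParams;
///
/// let params = AutoDeviceMapParams::default_text();
/// assert_eq!(params.max_seq_len(), AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN);
/// // Displays as: text[max_seq_len: 4096, max_batch_size: 1]
/// println!("{params}");
/// ```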
pub enum AutoDeviceMapParams {
    Text {
        max_seq_len: usize,
        max_batch_size: usize,
    },
    Vision {
        max_seq_len: usize,
        max_batch_size: usize,
        max_image_shape: (usize, usize),
        max_num_images: usize,
    },
}

impl AutoDeviceMapParams {
    pub fn max_seq_len(&self) -> usize {
        match self {
            Self::Text { max_seq_len, .. } | Self::Vision { max_seq_len, .. } => *max_seq_len,
        }
    }
}

impl Display for AutoDeviceMapParams {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Text {
                max_seq_len,
                max_batch_size,
            } => write!(
                f,
                "text[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}]"
            ),
            Self::Vision {
                max_seq_len,
                max_batch_size,
                max_image_shape,
                max_num_images,
            } => write!(
                f,
                "vision[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}, max_image_shape: {max_image_shape:?}, max_num_images: {max_num_images}]"
            ),
        }
    }
}

impl AutoDeviceMapParams {
    pub const DEFAULT_MAX_SEQ_LEN: usize = 4 * 1024;
    pub const DEFAULT_MAX_BATCH_SIZE: usize = 1;
    pub const DEFAULT_MAX_NUM_IMAGES: usize = 1;
    pub const DEFAULT_MAX_IMAGE_LENGTH: usize = 1024;

    pub fn default_text() -> Self {
        Self::Text {
            max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
            max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
        }
    }

    pub fn default_vision() -> Self {
        Self::Vision {
            max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
            max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
            max_num_images: Self::DEFAULT_MAX_NUM_IMAGES,
            max_image_shape: (
                Self::DEFAULT_MAX_IMAGE_LENGTH,
                Self::DEFAULT_MAX_IMAGE_LENGTH,
            ),
        }
    }
}

#[derive(Clone, Debug)]
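/// A sub-model that is excluded from layer-wise device mapping (e.g. a vision tower); it is
/// loaded whole on a single device.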
pub(crate) enum NonMappedSubModel {
    Vision,
}

impl Display for NonMappedSubModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Vision => write!(f, "vision"),
        }
    }
}

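/// Shape of one PagedAttention key-cache block:
/// `(num_kv_heads, k_head_dim / x, block_size, x)`, where `x = 16 / element_size` packs
/// 16 bytes of keys along the innermost dimension (following the vLLM cache layout).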
fn calculate_key_block_shape(
    model_config: &dyn ModelConfigLike,
    dtype: DType,
    block_size: usize,
) -> (usize, usize, usize, usize) {
    let element_size = dtype.size_in_bytes();
    let x = 16 / element_size;
    (
        model_config.num_kv_heads(),
        model_config.k_head_dim() / x,
        block_size,
        x,
    )
}

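/// Shape of one PagedAttention value-cache block: `(num_kv_heads, v_head_dim, block_size)`.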
fn calculate_value_block_shape(
    model_config: &dyn ModelConfigLike,
    block_size: usize,
) -> (usize, usize, usize) {
    (
        model_config.num_kv_heads(),
        model_config.v_head_dim(),
        block_size,
    )
}

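/// Sizing hooks that let a model loader participate in automatic device mapping: per-layer
/// and non-mapped weight sizes, worst-case activation sizes, and the model configuration
/// needed to size the KV cache.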
pub trait DeviceMappedModelLoader {
    /// Maximum activation size, in elements, of the non-mapped parts of this model.
    /// Useful for vision models, which may prefer to keep the vision components on the GPU.
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize>;
    /// Maximum activation size, in elements, of the mapped parts of the model.
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize>;
    /// `weight_pack_factor` only applies to quantized weights.
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize>;
    /// `weight_pack_factor` only applies to quantized weights.
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>>;
    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        None
    }
    fn num_layers(&self, config: &str) -> Result<usize>;
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>>;

    #[allow(clippy::too_many_arguments)]
    fn get_device_layers(
        &self,
        config: &str,
        num_layers: usize,
        mut layer_sizes_in_bytes: Vec<usize>,
        non_mapped_size_in_bytes: usize,
        total_model_size_in_bytes: usize,
        devices: &[Device],
        dtype: DType,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
        paged_attn_config: Option<&PagedAttentionConfig>,
    ) -> Result<DeviceMapMetadata> {
        let mapped_max_act_size_in_bytes =
            self.mapped_max_act_size_elems(config, params, prompt_chunksize)?
                * dtype.size_in_bytes();
        let non_mapped_max_act_size_in_bytes =
            self.non_mapped_max_act_size_elems(config, params)? * dtype.size_in_bytes();

        let mut remaining_to_map = total_model_size_in_bytes;

        let max_seq_len = match params {
            AutoDeviceMapParams::Text { max_seq_len, .. }
            | AutoDeviceMapParams::Vision { max_seq_len, .. } => *max_seq_len,
        };
        let max_batch_size = match params {
            AutoDeviceMapParams::Text { max_batch_size, .. }
            | AutoDeviceMapParams::Vision { max_batch_size, .. } => *max_batch_size,
        };

        let model_cfg = self.model_config(config)?;
        let kv_cache_size_elems = match paged_attn_config {
            Some(paged_attn_config) => {
                let cache_config = calculate_cache_config(
                    paged_attn_config.mem_gpu,
                    paged_attn_config.mem_cpu,
                    Some(
                        paged_attn_config
                            .block_size
                            .unwrap_or(DEFAULT_PAGED_ATTENTION_BLOCK_SIZE),
                    ),
                    dtype,
                    &*model_cfg,
                    &devices[0],
                    &devices.iter().map(|x| Some(x.clone())).collect::<Vec<_>>(),
                    true,
                )?;

                let key_block_shape =
                    calculate_key_block_shape(&*model_cfg, dtype, cache_config.block_size);
                let key_block_size = cache_config.num_gpu_blocks
                    * key_block_shape.0
                    * key_block_shape.1
                    * key_block_shape.2
                    * key_block_shape.3;

                let value_block_shape =
                    calculate_value_block_shape(&*model_cfg, cache_config.block_size);
                let value_block_size = cache_config.num_gpu_blocks
                    * value_block_shape.0
                    * value_block_shape.1
                    * value_block_shape.2;

                key_block_size + value_block_size
            }
            None => {
                // Non-paged KV cache
                let key_block_shape = [
                    max_batch_size,
                    model_cfg.num_kv_heads(),
                    max_seq_len,
                    model_cfg.k_head_dim(),
                ];
                let value_block_shape = [
                    max_batch_size,
                    model_cfg.num_kv_heads(),
                    max_seq_len,
                    model_cfg.v_head_dim(),
                ];

                key_block_shape.into_iter().product::<usize>()
                    + value_block_shape.into_iter().product::<usize>()
            }
        };
        let kv_cache_size_in_bytes = kv_cache_size_elems * dtype.size_in_bytes();

        let mut per_layer_avail = Vec::new();
        // Always add the CPU as a fallback
        for dev in [devices, &[Device::Cpu]].concat() {
            let avail = MemoryUsage.get_memory_available(&dev)?;
            per_layer_avail.push((avail, dev));
        }
        // Reverse so we don't use the CPU first!
        per_layer_avail.reverse();

        // Reverse layer sizes so we can pop
        layer_sizes_in_bytes.reverse();

        let mut device_layers = Vec::new();

        info!("Using automatic device mapping parameters: {params}.");
        if let Some(sub_models) = self.non_mapped_sub_models() {
            let (_, last) = per_layer_avail.last().unwrap();
            info!(
                "The following sub-models will not be device mapped and will be loaded on {}: {}",
                last.device_pretty_repr(),
                sub_models.iter().map(|x| x.to_string()).join(", ")
            );
        }

        let mut current_ordinal = 0;
        let mut current_layer = 0;
        let per_layer_avail_cpy = per_layer_avail.clone();
        let mut mapping_includes_cpu = false;
        while remaining_to_map > 0 && !per_layer_avail.is_empty() {
            let (device_capacity, device) = per_layer_avail
                .pop()
                .context("No more devices to map to. The model does not fit on this system.")?;
            // Allow usage of at most 90% of the device memory.
            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
            let device_capacity = (device_capacity as f64 * 0.90) as usize;

            // The algorithm checks the following:
            // 1) (no mapping) if *everything* fits on the first dev (non-mapped and mapped)
            // 2) if the mapped activations plus the remaining layers fit on the nth device
            // 3) common case: iteratively find the optimal number of layers to put on the nth device
            //   - if this is the first dev: must hold the non-mapped act and non-mapped model
            //   - otherwise, must hold the mapped act
            #[allow(clippy::if_same_then_else)]
            let layers_on_device = if current_ordinal == 0
                && device_capacity
                    >= remaining_to_map
                        + non_mapped_max_act_size_in_bytes.max(mapped_max_act_size_in_bytes)
                        + non_mapped_size_in_bytes
                        + kv_cache_size_in_bytes * (num_layers - current_layer)
            {
                remaining_to_map = 0;

                num_layers - current_layer
            } else if current_ordinal != 0
                && device_capacity
                    >= remaining_to_map
                        + mapped_max_act_size_in_bytes
                        + kv_cache_size_in_bytes * (num_layers - current_layer)
            {
                remaining_to_map = 0;

                num_layers - current_layer
            } else {
                // All devices need to account for the max mapped act size
                let mut used_capacity = mapped_max_act_size_in_bytes;
                let mut used_capacity_no_act = 0;
                let mut layers_on_device = 0;

                // The device with ordinal 0 carries the non-mapped things
                if current_ordinal == 0 {
                    // Ensure the activations are properly handled
                    used_capacity = used_capacity.max(non_mapped_max_act_size_in_bytes);
                    used_capacity += non_mapped_size_in_bytes;
                    used_capacity_no_act += non_mapped_size_in_bytes;
                }

                while let Some(&last) = layer_sizes_in_bytes.last() {
                    let delta = last + kv_cache_size_in_bytes;
                    if used_capacity + delta > device_capacity {
                        break;
                    }
                    let _ = layer_sizes_in_bytes.pop().unwrap();
                    used_capacity += delta;
                    used_capacity_no_act += delta;
                    layers_on_device += 1;
                }

                // Only reduce the amount left to map if this device fit at least one layer;
                // otherwise, warn the user.
                if layers_on_device > 0 {
                    remaining_to_map = remaining_to_map.saturating_sub(used_capacity_no_act);
                } else {
                    warn!(
                        "Device {} can fit 0 layers. Consider reducing auto map params from current: {params} (e.g. reducing max seq len or max num images)",
                        device.device_pretty_repr(),
                    );
                    current_ordinal += 1;
                    continue;
                }
                layers_on_device
            };

            // CPU mappings are automatically handled by the traditional device mapper, so we can
            // just leave them out here.
            if !device.is_cpu() {
                device_layers.push(DeviceLayerMapMetadata {
                    ordinal: current_ordinal,
                    layers: layers_on_device,
                });
                current_ordinal += 1;
            } else {
                mapping_includes_cpu = true;
            }

            current_layer += layers_on_device;
        }
        if remaining_to_map > 0 {
            anyhow::bail!(
                "This model does not fit on the devices {:?}, and exceeds total capacity by {}MB. Auto device mapping params: {params}",
                per_layer_avail_cpy
                    .iter()
                    .rev()
                    .map(|(avail, dev)| format!(
                        "{} (avail: {}MB)",
                        dev.device_pretty_repr(),
                        b_to_mb!(avail),
                    ))
                    .collect::<Vec<_>>(),
                b_to_mb!(remaining_to_map)
            );
        }

        // TODO: PagedAttention is not supported with CPU for now.
        // Recalculate without PagedAttention metadata.
        if paged_attn_config.is_some_and(|_| mapping_includes_cpu) {
            return self.get_device_layers(
                config,
                num_layers,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                devices,
                dtype,
                params,
                prompt_chunksize,
                None,
            );
        }

        Ok(DeviceMapMetadata::from_num_device_layers(device_layers))
    }
}

/// The `Loader` trait abstracts the loading process. The primary entrypoint is the
/// `load_model_from_hf` method.
///
/// # Example
/// ```no_run
/// use mistralrs_core::{Loader, TokenSource, DeviceMapSetting, AutoDeviceMapParams, ModelDType};
/// use candle_core::Device;
///
/// let loader: Box<dyn Loader> = todo!();
/// let pipeline = loader.load_model_from_hf(
///     None,
///     TokenSource::CacheToken,
///     &ModelDType::Auto,
///     &Device::cuda_if_available(0).unwrap(),
///     false,
///     DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
///     None,
///     None,
/// ).unwrap();
/// ```
pub trait Loader: Send + Sync {
    /// If `revision` is `None`, it defaults to `main`.
    /// If `dtype` is `None`, it defaults to the model default (usually BF16).
    /// If the model is not found on HF, this will attempt to resolve it locally.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    /// Load a model from the specified paths.
    /// Also initializes `DEBUG`.
    #[allow(
        clippy::type_complexity,
        clippy::too_many_arguments,
        clippy::borrowed_box
    )]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    fn get_id(&self) -> String;
    fn get_kind(&self) -> ModelKind;
}