mistralrs_core/pipeline/loaders/vision_loaders.rs

use std::any::Any;
use std::sync::Arc;
use std::{fmt::Debug, str::FromStr};

use anyhow::Result;
use candle_core::{DType, Device, Tensor, D};
use candle_nn::Conv2dConfig;
use image::{ColorType, DynamicImage};
use itertools::Itertools;
use mistralrs_quant::ShardedVarBuilder;

#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};

use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
use crate::amoe::AnyMoeBaseModelMixin;
use crate::device_map::DeviceMapper;
use crate::layers::Conv3dConfig;
use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
use crate::pipeline::isq::IsqModelLoader;
use crate::pipeline::loaders::AutoDeviceMapParams;
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::{EitherCache, IsqModel, Processor, ProcessorCreator, VisionPromptPrefixer};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::vision_models::clip::ClipConfig;
use crate::vision_models::gemma3::config::Gemma3Config;
use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
use crate::vision_models::idefics2_input_processor::Idefics2Processor;
use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
use crate::vision_models::image_processor::ImagePreProcessor;
use crate::vision_models::inputs_processor::Phi4MMProcessor;
use crate::vision_models::llama4::{
    self, Llama4Config, Llama4ImageProcessor, Llama4Model, Llama4Processor,
};
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::llava15::Model as LLaVA;
use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
use crate::vision_models::llava_next::Model as LLaVANext;
use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
use crate::vision_models::phi3_inputs_processor::Phi3Processor;
use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::qwen2_5_vl::{
    Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
};
use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
use crate::vision_models::{minicpmo, phi4};

pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin {
    // pixel_values and pixel_attention_mask only specified for prompt seqs
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>, // pixel attention mask, or image sizes, or anything else
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn has_conv2d(&self) -> bool;
    fn config(&self) -> &ModelConfigMetadata;
    /// For a prompt without images. Requires batch size of 1!
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
}

pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>>;
    fn is_gptx(&self) -> bool;
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>>;
    fn get_processor(
        &self,
        model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync>;
    fn supports_paged_attention(&self) -> bool;
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}
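
// For example (a sketch of the default `get_device_for_tensor` behavior above):
// a tensor named "model.layers.12.self_attn.q_proj.weight" is captured by the
// `\.layers\.(\d+)\.` regex, so it is loaded to the device holding layer 12
// (clamped to the layer count), while names without a ".layers.N." segment fall
// back to the base device.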

#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
/// The architecture to load the vision model as.
pub enum VisionLoaderType {
    #[serde(rename = "phi3v")]
    Phi3V,
    #[serde(rename = "idefics2")]
    Idefics2,
    #[serde(rename = "llava_next")]
    LLaVANext,
    #[serde(rename = "llava")]
    LLaVA,
    #[serde(rename = "vllama")]
    VLlama,
    #[serde(rename = "qwen2vl")]
    Qwen2VL,
    #[serde(rename = "idefics3")]
    Idefics3,
    #[serde(rename = "minicpmo")]
    MiniCpmO,
    #[serde(rename = "phi4mm")]
    Phi4MM,
    #[serde(rename = "qwen2_5vl")]
    Qwen2_5VL,
    #[serde(rename = "gemma3")]
    Gemma3,
    #[serde(rename = "mistral3")]
    Mistral3,
    #[serde(rename = "llama4")]
    Llama4,
}

impl FromStr for VisionLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "phi3v" => Ok(Self::Phi3V),
            "idefics2" => Ok(Self::Idefics2),
            "llava_next" => Ok(Self::LLaVANext),
            "llava" => Ok(Self::LLaVA),
            "vllama" => Ok(Self::VLlama),
            "qwen2vl" => Ok(Self::Qwen2VL),
            "idefics3" => Ok(Self::Idefics3),
            "minicpmo" => Ok(Self::MiniCpmO),
            "phi4mm" => Ok(Self::Phi4MM),
            "qwen2_5vl" => Ok(Self::Qwen2_5VL),
            "gemma3" => Ok(Self::Gemma3),
            "mistral3" => Ok(Self::Mistral3),
            "llama4" => Ok(Self::Llama4),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`, `llama4`.")),
        }
    }
}
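
// A minimal sketch of the expected parsing behavior; the string names used here
// must stay in sync with both the `FromStr` impl and the serde renames above.
#[cfg(test)]
mod vision_loader_type_tests {
    use super::VisionLoaderType;

    #[test]
    fn parses_known_architectures() {
        assert_eq!(
            "gemma3".parse::<VisionLoaderType>(),
            Ok(VisionLoaderType::Gemma3)
        );
        assert_eq!(
            "qwen2_5vl".parse::<VisionLoaderType>(),
            Ok(VisionLoaderType::Qwen2_5VL)
        );
        // Unknown names surface the full list of supported architectures.
        assert!("not-a-model".parse::<VisionLoaderType>().is_err());
    }
}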

macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}
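
// For example, `bias_if!(use_bias, cfg.hidden_size)` (with a hypothetical
// `use_bias` flag) counts `cfg.hidden_size` bias elements when the flag is set
// and 0 otherwise, keeping the element-count sums below compact.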

fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
    let pre_layer_norm = cfg.hidden_size;
    let final_layer_norm = cfg.hidden_size;

    let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
    let num_positions = num_patches + 1;

    let class_embedding = cfg.hidden_size;

    let position_ids = num_positions;
    let position_embedding = num_positions * cfg.hidden_size;

    let conv2dconfig = Conv2dConfig {
        stride: cfg.patch_size,
        ..Default::default()
    };
    let patch_embedding =
        cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;

    let encoder_layer_elems = {
        let layer_norm1 = cfg.hidden_size;
        let layer_norm2 = cfg.hidden_size;

        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

        layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
    };

    pre_layer_norm
        + final_layer_norm
        + class_embedding
        + position_ids
        + position_embedding
        + patch_embedding
        + cfg.num_hidden_layers * encoder_layer_elems
}
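
// As a worked example (a sketch, assuming ViT-B/32-like numbers: hidden_size = 768,
// image_size = 224, patch_size = 32, num_channels = 3, intermediate_size = 3072,
// num_hidden_layers = 12): num_patches = (224 / 32)^2 = 49 and num_positions = 50,
// so the embeddings contribute 50 position ids plus 50 * 768 position-embedding
// elements, and each of the 12 encoder layers adds four attention projections,
// two MLP projections, and two layer norms (all with biases) to the total.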

// ======================== Phi 3 loader

/// [`VisionLoader`] for a Phi 3 Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Phi3VLoader;

pub struct Phi3VPrefixer;

impl VisionPromptPrefixer for Phi3VPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        // Image indexing starts at 0, but Phi-3 image tags are 1-indexed.
        format!(
            "{}{prompt}",
            image_indexes
                .into_iter()
                .map(|image_index| format!("<|image_{}|>", image_index + 1))
                .join("")
        )
    }
}

impl VisionModelLoader for Phi3VLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: Phi3Config = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(Phi3::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: Phi3Config = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi3Processor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Phi3VPrefixer)
    }
}

impl IsqModelLoader for Phi3VLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}
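
// A hedged sanity check (a sketch) for the ISQ regexes above: per-layer text
// projections and the LM head should be selected for in-situ quantization,
// while vision-tower tensors are left untouched. The tensor names below are
// illustrative, not exhaustive.
#[cfg(test)]
mod phi3v_isq_regex_tests {
    use super::*;

    #[test]
    fn selects_text_projections_only() {
        // `isq_layer_regexes` ignores its config argument for this loader.
        let regexes = Phi3VLoader
            .isq_layer_regexes("{}")
            .expect("regexes are static");
        let hit = |name: &str| regexes.iter().any(|re| re.is_match(name));
        assert!(hit("model.layers.3.self_attn.qkv_proj.weight"));
        assert!(hit("model.layers.3.mlp.gate_up_proj.weight"));
        assert!(hit("lm_head.weight"));
        assert!(!hit("model.vision_embed_tokens.img_projection.0.weight"));
    }
}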

impl DeviceMappedModelLoader for Phi3VLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        // NOTE: we ignore max_num_images although it can only be one...
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }
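
    // Worked example for `mapped_max_act_size_elems` above (a sketch, assuming
    // PHI3V_CLIP_CONFIG is ViT-L/14 at 336px: num_patches = (336 / 14)^2 = 576,
    // so img_seq_len = 577 per image): with max_batch_size = 1, 32 attention
    // heads (an assumption), max_seq_len = 4096, and one image, the dominant
    // text-attention activation is 1 * 32 * (577 + 4096)^2 elements.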

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        // NOTE: we ignore max_num_images although it can only be one...
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = {
                let projection_cls = cfg
                    .embd_layer
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator =
                    cfg.embd_layer.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
                let image_dim_out = cfg.img_processor.image_dim_out;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);

                proj + glb_gn + sub_gn + clip_vit
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}
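
// Returning `NonMappedSubModel::Vision` signals that the vision tower is not
// split by the automatic device mapper: it stays whole on a single device, which
// is why its weights and activations are accounted for via the `non_mapped_*`
// methods above while only the text layers are sized per-layer.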

// ======================== Idefics 2 loader

/// [`VisionLoader`] for an Idefics 2 Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Idefics2Loader;

pub struct Idefics2Prefixer;

impl VisionPromptPrefixer for Idefics2Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        // Chat template does it
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics2Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: Idefics2Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(Idefics2::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: Idefics2Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics2Processor::new(
            processor_config.unwrap(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Idefics2Prefixer)
    }
}

impl IsqModelLoader for Idefics2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Idefics2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            // do_image_splitting = true: 4 sub-images plus the original
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let text_elems = {
            let tie_word_embeddings = cfg.tie_word_embeddings;
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let tcfg = &cfg.text_config;
            let vcfg = &cfg.vision_config;
            let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let down_proj = tcfg.intermediate_size * tcfg.hidden_size;

            let perceiver_elems = {
                let tcfg = &cfg.text_config;
                let pcfg = &cfg.perceiver_config;

                let n_latents = pcfg.resampler_n_latents;
                let hidden_size = tcfg.hidden_size;
                let depth = pcfg.resampler_depth;

                let norm = tcfg.hidden_size;
                let latents = n_latents * hidden_size;

                let layer_elems = {
                    let input_latents_norm = hidden_size;
                    let input_context_norm = hidden_size;
                    let post_attn_norm = hidden_size;

                    let num_heads = pcfg.resampler_n_heads;
                    let head_dim = pcfg.resampler_head_dim;
                    let num_key_value_heads = pcfg.num_key_value_heads;

                    let q_proj = hidden_size * num_heads * head_dim;
                    let k_proj = hidden_size * num_key_value_heads * head_dim;
                    let v_proj = hidden_size * num_key_value_heads * head_dim;
                    let o_proj = num_heads * head_dim * hidden_size;

                    let gate_proj = hidden_size * hidden_size * 4;
                    let up_proj = hidden_size * hidden_size * 4;
                    let down_proj = hidden_size * 4 * hidden_size;

                    input_latents_norm
                        + input_context_norm
                        + post_attn_norm
                        + q_proj
                        + k_proj
                        + v_proj
                        + o_proj
                        + gate_proj
                        + up_proj
                        + down_proj
                };

                norm + latents + layer_elems * depth
            };

            gate_proj + up_proj + down_proj + perceiver_elems
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm + patch_embedding + position_embedding + layer_elems
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== LLaVANext Loader

/// [`VisionLoader`] for a LLaVANext Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct LLaVANextLoader;

pub struct LLaVANextPrefixer;

impl VisionPromptPrefixer for LLaVANextPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVANextLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: LLaVAConfig = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(LLaVANext::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: LLaVAConfig = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVANextProcessor::new(model_config))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(LLaVANextPrefixer)
    }
}

impl IsqModelLoader for LLaVANextLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVANextLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== LLaVA Loader

/// [`VisionLoader`] for a LLaVA Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct LLaVALoader;

pub struct LLaVAPrefixer;

impl VisionPromptPrefixer for LLaVAPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVALoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: LLaVAConfig = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(LLaVA::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: LLaVAConfig = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVAProcessor::new(model_config))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(LLaVAPrefixer)
    }
}

impl IsqModelLoader for LLaVALoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVALoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        let img_seq_len =
            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        let img_seq_len =
            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== MLlama Loader

/// [`VisionLoader`] for a Llama Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct VLlamaLoader;

pub struct VLlamaPrefixer;

impl VisionPromptPrefixer for VLlamaPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<|image|>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for VLlamaLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: MLlamaConfig = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(MLlamaModel::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: MLlamaConfig = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(MLlamaProcessor::new())
    }
    fn supports_paged_attention(&self) -> bool {
        false
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(VLlamaPrefixer)
    }
}

impl IsqModelLoader for VLlamaLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let cross_attn_layers = &config.text_config.cross_attention_layers;
        let transformer_layers =
            (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
        let mut text_regexes = Vec::new();
        for layer in transformer_layers {
            text_regexes.extend(vec![
                // Attention text
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
                ))?,
                // MLP text
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
                ))?,
            ]);
        }
        let vision_regexes = vec![
            // Vision attention (transformer)
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
            )?,
            // Vision attention (global transformer)
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
            )?,
            // MLP vision
            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
        ];

        Ok([text_regexes, vision_regexes].concat())
    }
}
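
// Note on the selection above: layers listed in `cross_attention_layers` are
// filtered out of `transformer_layers`, so their cross-attention projections
// are not selected for ISQ; only the plain self-attention text layers, the LM
// head, and the vision towers are quantized here.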
1453
1454impl DeviceMappedModelLoader for VLlamaLoader {
1455    fn mapped_max_act_size_elems(
1456        &self,
1457        config: &str,
1458        params: &AutoDeviceMapParams,
1459        _prompt_chunksize: usize,
1460    ) -> Result<usize> {
1461        let AutoDeviceMapParams::Vision {
1462            max_seq_len,
1463            max_batch_size,
1464            max_image_shape: _,
1465            max_num_images,
1466        } = params
1467        else {
1468            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1469        };
1470
1471        let config: MLlamaConfig = serde_json::from_str(config)?;
1472
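        // Per-image sequence length: a patch grid plus one class token, padded up to a
        // multiple of 8, for each tile. As a sketch with hypothetical config values
        // image_size=448 and patch_size=14: 32^2 + 1 = 1025 patches, padded to 1032,
        // times max_num_tiles.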
        let img_seq_len = {
            let cfg = &config.vision_config;
            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
        };
        let img_seq_len = img_seq_len * max_num_images;

        let max_cross_text_attn = {
            let cfg = &config.text_config;
            max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        let max_self_text_attn = {
            let cfg = &config.text_config;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_self_text_attn.max(max_cross_text_attn))
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: MLlamaConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &config.vision_config;
            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
        };
        let max_vision_attn = {
            let cfg = &config.vision_config;
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &config.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &config.vision_config;

            let conv_cfg = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
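            // Conv2d weight elements: out_channels * (in_channels / groups) * k_h * k_w,
            // with kernel size equal to the stride (patch_size) for non-overlapping patches.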
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
                * cfg.patch_size
                * cfg.patch_size;

            let class_embedding = cfg.hidden_size;

            let gated_positional_embedding = {
                let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
                let embedding = num_patches * cfg.hidden_size;
                let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
                    * (cfg.max_num_tiles * num_patches * cfg.hidden_size);

                embedding + tile_embedding
            };

            let pre_tile_positional_embedding =
                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
            let post_tile_positional_embedding =
                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);

            let layernorm_pre = cfg.hidden_size;
            let layernorm_post = cfg.hidden_size;

            let encoder_layer = {
                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
                let q_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let k_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let v_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let o_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;

                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
                    + cfg.intermediate_size;
                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
                    + cfg.hidden_size;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + fc1
                    + fc2
            };

            patch_embedding
                + class_embedding
                + gated_positional_embedding
                + pre_tile_positional_embedding
                + post_tile_positional_embedding
                + layernorm_pre
                + layernorm_post
                + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
        };

        let elems = text_elems + vision_elems;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let cfg = &config.text_config;

        let mut layer_sizes = Vec::new();

        for i in 0..cfg.num_hidden_layers {
            let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
                // No isq for cross attention
                1
            } else {
                weight_pack_factor
            };

            let per_layer_elems = {
                let input_layernorm = cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size;

                let size_in = cfg.hidden_size;
                let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
                let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
                let q_proj = size_in * size_q / weight_pack_factor;
                let k_proj = size_in * size_kv / weight_pack_factor;
                let v_proj = size_in * size_kv / weight_pack_factor;
                let o_proj = size_q * size_in / weight_pack_factor;

                let h_size = cfg.hidden_size;
                let i_size = cfg.intermediate_size;
                let gate_proj = h_size * i_size / weight_pack_factor;
                let up_proj = h_size * i_size / weight_pack_factor;
                let down_proj = i_size * h_size / weight_pack_factor;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + gate_proj
                    + up_proj
                    + down_proj
            };

            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
        }

        Ok(layer_sizes)
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        Ok(config.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: MLlamaConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Qwen2VL Loader

/// [`VisionLoader`] for a Qwen2-VL model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Qwen2VLLoader;

pub struct Qwen2VLPrefixer;

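// Prepends one vision-start/image-pad/vision-end marker group per image. As a
// sketch (assuming the usual Qwen2-VL special tokens for these constants), two
// images turn "Describe them" into
// "<|vision_start|><|image_pad|><|vision_end|><|vision_start|><|image_pad|><|vision_end|>Describe them".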
impl VisionPromptPrefixer for Qwen2VLPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            format!(
                "{}{}{}",
                Qwen2VLProcessor::VISION_START,
                Qwen2VLProcessor::IMAGE_PAD,
                Qwen2VLProcessor::VISION_END
            )
            .repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for Qwen2VLLoader {
    fn load(
        &self,
        config: &str,
        _use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let config: Qwen2VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen2VLModel::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let config: Qwen2VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen2VLProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self) -> bool {
        false
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Qwen2VLPrefixer)
    }
}

impl IsqModelLoader for Qwen2VLLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Qwen2VLLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;

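        // Vision-token count from the 3D patch grid: temporal (frames /
        // temporal_patch_size) x height x width patches. Treating max_num_images as
        // the temporal extent is a coarse upper bound, which is all device mapping needs.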
        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        let text_elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

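        // The patch merger folds a spatial_merge_size^2 block of vision patches into
        // one token: an MLP from embed_dim * spatial_merge_size^2 up to the text
        // hidden size, preceded by a layer norm over embed_dim.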
        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_channels * cfg.embed_dim / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
            let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);

            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
            let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
            let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
            let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;

            let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
            let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Idefics 3 Loader

/// [`VisionLoader`] for an Idefics 3 Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Idefics3Loader;

pub struct Idefics3Prefixer;

impl VisionPromptPrefixer for Idefics3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        // Chat template does it
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics3Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: Idefics3Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(Idefics3Model::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: Idefics3Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics3Processor::new(
            processor_config.unwrap_or_default(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Idefics3Prefixer)
    }
}

impl IsqModelLoader for Idefics3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Idefics3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics3Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
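        // The +1 below is one extra position on top of the patch grid (e.g. a class
        // token in CLIP-style towers); it is kept as a conservative bound.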
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics3Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            // do_image_splitting = true
            let images_factor = 5;
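            // (an assumed 2x2 grid of crops plus the one global image when
            // do_image_splitting is enabled)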

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
            let out_dim = cfg.text_config.hidden_size;

            in_dim * out_dim
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== MiniCpm-O Loader

/// [`VisionLoader`] for a MiniCpm-O model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct MiniCpmOLoader;

pub struct MiniCpmOPrefixer;

impl VisionPromptPrefixer for MiniCpmOPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            "(<image>./</image>)".repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for MiniCpmOLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: MiniCpmOConfig = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(MiniCpmOModel::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: MiniCpmOConfig = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(MiniCpmOProcessor::new(
            processor_config.unwrap_or_default(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(MiniCpmOPrefixer)
    }
}

impl IsqModelLoader for MiniCpmOLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"llm.lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"llm.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"llm.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"llm.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"llm.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"llm.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"llm.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"llm.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for MiniCpmOLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            // do_image_splitting = true
            let images_factor = 5;
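            // (as with Idefics 3 above: an assumed four crops plus one global image)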

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

// ======================== Phi 4MM Loader

/// [`VisionLoader`] for a Phi 4MM Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Phi4MMLoader;

pub struct Phi4MMPrefixer;

impl VisionPromptPrefixer for Phi4MMPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        // Image indexes are 0-based, but Phi-4MM's `<|image_N|>` tags are 1-based.

        format!(
            "{}{prompt}",
            image_indexes
                .into_iter()
                .map(|image_index| format!("<|image_{}|>", image_index + 1))
                .join("")
        )
    }
}

impl VisionModelLoader for Phi4MMLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: Phi4MMConfig = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(Phi4MMModel::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: Phi4MMConfig = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Phi4MMPrefixer)
    }
}

impl IsqModelLoader for Phi4MMLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Phi4MMLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        // NOTE: max_num_images is used as a simple multiplier here, even though this model
        // currently handles only one image at a time.
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi4MMConfig = serde_json::from_str(config)?;

        let vcfg = &PHI4_MM_VISION_CFG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let vcfg = &PHI4_MM_VISION_CFG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

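        // Dynamic HD preprocessing tiles the image into crops of
        // DYHD_BASE_RESOLUTION pixels per side (div_ceil over height and width),
        // plus one global view, so the effective batch grows by that factor.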
        let max_batch_size = max_batch_size
            * (max_image_shape
                .0
                .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
                * max_image_shape
                    .1
                    .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
                + 1);

        let max_vision_attn = (max_batch_size * max_num_images)
            * vcfg.num_attention_heads
            * img_seq_len
            * img_seq_len;
        let max_qkv = 3
            * (max_batch_size
                * vcfg.num_attention_heads
                * img_seq_len
                * (vcfg.hidden_size / vcfg.num_attention_heads));

        Ok(max_vision_attn + max_qkv)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
                let projection_cls = img_embed
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
                let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let vision_transformer = {
                    let cfg = &PHI4_MM_VISION_CFG;

                    let post_layernorm = cfg.hidden_size;

                    let conv_config = Conv2dConfig {
                        stride: cfg.patch_size,
                        ..Default::default()
                    };
                    let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                        * cfg.patch_size
                        * cfg.patch_size;

                    let num_patches_per_side = cfg.image_size / cfg.patch_size;
                    let num_patches = num_patches_per_side.pow(2);
                    let position_embedding = num_patches * cfg.hidden_size;

                    let layer_elems = {
                        let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                        let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                        layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
                    };

                    post_layernorm
                        + patch_embedding
                        + position_embedding
                        + layer_elems * cfg.num_hidden_layers
                };

                proj + glb_gn + sub_gn + vision_transformer
            } else {
                0
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
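            // Fused QKV projection width: queries take num_attention_heads * head_dim,
            // keys and values num_key_value_heads * head_dim each.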
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads(),
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Qwen2_5VL Loader

/// [`VisionLoader`] for a Qwen2_5VL model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Qwen2_5VLLoader;

pub struct Qwen2_5VLPrefixer;

impl VisionPromptPrefixer for Qwen2_5VLPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            format!(
                "{}{}{}",
                Qwen2_5VLProcessor::VISION_START,
                Qwen2_5VLProcessor::IMAGE_PAD,
                Qwen2_5VLProcessor::VISION_END
            )
            .repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for Qwen2_5VLLoader {
    fn load(
        &self,
        config: &str,
        _use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen2_5VLModel::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen2_5VLProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self) -> bool {
        false
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Qwen2_5VLPrefixer)
    }
}

impl IsqModelLoader for Qwen2_5VLLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2928            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
2929            // MLP
2930            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2931            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2932            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2933        ])
2934    }
2935}
2936
2937impl DeviceMappedModelLoader for Qwen2_5VLLoader {
2938    fn mapped_max_act_size_elems(
2939        &self,
2940        config: &str,
2941        params: &AutoDeviceMapParams,
2942        _prompt_chunksize: usize,
2943    ) -> Result<usize> {
2944        let AutoDeviceMapParams::Vision {
2945            max_seq_len,
2946            max_batch_size,
2947            max_image_shape,
2948            max_num_images,
2949        } = params
2950        else {
2951            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2952        };
2953
2954        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
2955
2956        let img_seq_len = {
2957            let cfg = &cfg.vision_config;
2958            let grid_t = max_num_images / cfg.temporal_patch_size;
2959            let grid_h = max_image_shape.0 / cfg.patch_size;
2960            let grid_w = max_image_shape.1 / cfg.patch_size;
2961            grid_t * grid_h * grid_w
2962        };
2963        let img_seq_len = img_seq_len * max_num_images;
2964
2965        let max_text_attn = {
2966            // This model injects the vision information directly into the input embeddings
2967            let max_seq_len = img_seq_len + max_seq_len;
2968            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
2969        };
2970
2971        Ok(max_text_attn)
2972    }
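
    // Worked example (illustrative values, not from a specific checkpoint):
    // with patch_size = 14, temporal_patch_size = 2,
    // max_image_shape = (1540, 1540) and max_num_images = 2:
    //
    //     grid_t = 2 / 2 = 1, grid_h = grid_w = 1540 / 14 = 110
    //     img_seq_len = 1 * 110 * 110 = 12_100, then * 2 images = 24_200
    //
    // With max_seq_len = 4096, the attention activation bound becomes
    // max_batch_size * num_attention_heads * (24_200 + 4096)^2 elements.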

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        let text_elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;

            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }
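
    // Worked example for `patch_embed` (illustrative config values, hedged):
    // with in_chans = 3, hidden_size = 1280, temporal_patch_size = 2 and
    // patch_size = 14, the Conv3d weight has
    //     3 * 1280 * 2 * 14 * 14 = 1_505_280 elements,
    // i.e. (in_channels * out_channels / groups) times the product of the
    // three kernel dimensions.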

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Gemma 3 Loader

/// [`VisionLoader`] for a Gemma 3 model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Gemma3Loader;

pub struct Gemma3Prefixer;

impl VisionPromptPrefixer for Gemma3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Gemma3Loader {
    fn load(
        &self,
        config: &str,
        _use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let config: Gemma3Config = serde_json::from_str(config)?;
        Ok(Box::new(Gemma3Model::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let config: Gemma3Config = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        let config: Gemma3Config = serde_json::from_str(config).unwrap();
        // Gemma 3 1B is text-only: only enable image processing when the
        // config actually carries a vision tower.
        Arc::new(Gemma3Processor::new(
            processor_config.unwrap_or_default(),
            matches!(config, Gemma3Config::WithVision { .. }),
        ))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Gemma3Prefixer)
    }
}

impl IsqModelLoader for Gemma3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Gemma3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3Config = serde_json::from_str(config)?;

        match cfg {
            Gemma3Config::Text(text_config) => Ok(max_batch_size
                * text_config.num_attention_heads
                * prompt_chunksize
                * prompt_chunksize),
            Gemma3Config::WithVision {
                text_config,
                vision_config,
                ..
            } => {
                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
                let img_seq_len = (num_patches + 1) * max_num_images;

                let max_text_attn = {
                    // This model injects the vision information directly into the input embeddings
                    let max_seq_len = img_seq_len + *max_seq_len;
                    max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
                };
                Ok(max_text_attn)
            }
        }
    }
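
    // Worked example for the vision branch (illustrative values, hedged): with
    // image_size = 896 and patch_size = 14, num_patches = (896 / 14)^2 = 4096,
    // so each image contributes 4096 + 1 = 4097 embedded tokens. With one
    // image and max_seq_len = 4096, the bound is
    //     max_batch_size * num_attention_heads * (4097 + 4096)^2
    // activation elements.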

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3Config = serde_json::from_str(config)?;

        match cfg {
            Gemma3Config::WithVision { vision_config, .. } => {
                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
                let img_seq_len = num_patches + 1;

                let max_vision_attn = {
                    (max_batch_size * max_num_images)
                        * vision_config.num_attention_heads
                        * img_seq_len
                        * img_seq_len
                };

                Ok(max_vision_attn)
            }
            Gemma3Config::Text(_) => Ok(0),
        }
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let text_elems = {
            let cfg = match &cfg {
                Gemma3Config::Text(cfg) => cfg,
                Gemma3Config::WithVision { text_config, .. } => text_config,
            };
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_transformer = if let Gemma3Config::WithVision {
            vision_config: cfg, ..
        } = &cfg
        {
            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        } else {
            0
        };

        let elems = text_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let txt_cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };
        let per_layer_elems = {
            let cfg = txt_cfg;

            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            txt_cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let txt_cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };

        Ok(txt_cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None, // None to be more forgiving; some configs do not set it
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Mistral 3 Loader

/// [`VisionLoader`] for a Mistral 3 model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Mistral3Loader;

pub struct Mistral3Prefixer;

impl VisionPromptPrefixer for Mistral3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Mistral3Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: Mistral3Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(Mistral3Model::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let config: Mistral3Config = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Mistral3Prefixer)
    }
}

impl IsqModelLoader for Mistral3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
impl DeviceMappedModelLoader for Mistral3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let vcfg = &cfg.vision_config;
        let tcfg = &cfg.text_config;

        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: (mut height, mut width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let img_seq_len = {
            // Reshaping algorithm

            // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/preprocessor_config.json#L29
            let (max_height, max_width) = (1540, 1540);
            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
            if ratio > 1. {
                height = (height as f64 / ratio).floor() as usize;
                width = (width as f64 / ratio).floor() as usize;
            }

            let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
            let num_width_tokens = (width - 1) / vcfg.patch_size + 1;

            height = num_height_tokens * vcfg.patch_size;
            width = num_width_tokens * vcfg.patch_size;

            let num_height_tokens = height / vcfg.patch_size;
            let num_width_tokens = width / vcfg.patch_size;

            (num_width_tokens + 1) * num_height_tokens
        };

        // This model injects the vision information directly into the input embeddings
        let max_seq_len = img_seq_len * max_num_images + *max_seq_len;
        Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
    }
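
    // Worked example of the reshape (patch_size = 14 shown for illustration):
    // a 2048x2048 input has ratio = 2048 / 1540 ≈ 1.33, so it is scaled down
    // to 1540x1540. Then num_height_tokens = num_width_tokens = 1540 / 14 = 110
    // and img_seq_len = (110 + 1) * 110 = 12_210 tokens per image; the extra
    // "+ 1" per row accounts for a row-separator token (hedged: inferred from
    // the formula above, not re-derived from the upstream preprocessor).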

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.vision_config;

        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: (mut height, mut width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let img_seq_len = {
            // Reshaping algorithm

            // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/preprocessor_config.json#L29
            let (max_height, max_width) = (1540, 1540);
            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
            if ratio > 1. {
                height = (height as f64 / ratio).floor() as usize;
                width = (width as f64 / ratio).floor() as usize;
            }

            let num_height_tokens = (height - 1) / cfg.patch_size + 1;
            let num_width_tokens = (width - 1) / cfg.patch_size + 1;

            height = num_height_tokens * cfg.patch_size;
            width = num_width_tokens * cfg.patch_size;

            let num_height_tokens = height / cfg.patch_size;
            let num_width_tokens = width / cfg.patch_size;

            (num_width_tokens + 1) * num_height_tokens
        };

        Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;

        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &cfg.vision_config;

            let patch_embed = {
                let conv_cfg = Conv2dConfig {
                    stride: cfg.patch_size,
                    ..Default::default()
                };
                // A Conv2d patch embedding has two spatial kernel dimensions.
                cfg.num_channels * cfg.hidden_size / conv_cfg.groups
                    * cfg.patch_size
                    * cfg.patch_size
            };
            let ln_pre = cfg.hidden_size;
            let vision_layer = {
                let attn_norm = cfg.hidden_size;
                let ffn_norm = cfg.hidden_size;

                let gate = cfg.hidden_size * cfg.intermediate_size;
                let up = cfg.hidden_size * cfg.intermediate_size;
                let down = cfg.hidden_size * cfg.intermediate_size;

                let q = cfg.hidden_size * cfg.hidden_size;
                let k = cfg.hidden_size * cfg.hidden_size;
                let v = cfg.hidden_size * cfg.hidden_size;
                let o = cfg.hidden_size * cfg.hidden_size;

                attn_norm + ffn_norm + gate + up + down + q + k + v + o
            };

            patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_elems;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }
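
    // Worked example (hypothetical config, hedged): hidden_size = 4096,
    // num_attention_heads = 32, num_key_value_heads = 8,
    // intermediate_size = 14336, weight_pack_factor = 1, dtype = BF16:
    //
    //     q/o: 4096 * 4096 = 16_777_216 each; k/v: 4096 * 1024 = 4_194_304 each
    //     gate/up/down: 4096 * 14336 = 58_720_256 each; norms: 2 * 4096
    //     total ≈ 218_112_000 elements * 2 bytes ≈ 416 MiB per decoder layer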

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Llama 4 Loader

/// [`VisionLoader`] for a Llama 4 vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct VLlama4Loader;

pub struct VLlama4Prefixer;

impl VisionPromptPrefixer for VLlama4Prefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            llama4::IMAGE_TOKEN.repeat(image_indexes.len())
        )
    }
}
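
// Hedged sketch: the prefixer repeats `llama4::IMAGE_TOKEN` once per image
// index before the prompt, e.g. for two images:
//
//     let p = VLlama4Prefixer.prefix_image(vec![0, 1], "Compare the images.");
//     // p == format!("{t}{t}Compare the images.", t = llama4::IMAGE_TOKEN)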

impl VisionModelLoader for VLlama4Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: Llama4Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(Llama4Model::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: Llama4Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Llama4Processor::new(&processor_config.unwrap()))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(VLlama4Prefixer)
    }
}

impl IsqModelLoader for VLlama4Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // FF MoE
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.router\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.down_proj\.(weight|bias)$")?,
            // FF MLP
            Regex::new(r"layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl VLlama4Loader {
    /// Runs the image processor on dummy images of the given size to measure
    /// the true token counts; the results incorporate the max batch size.
    /// Returns `(pixels max batch size, num text image tokens)`.
    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
    fn run_dummy_processing(
        &self,
        cfg: &Llama4Config,
        height: usize,
        width: usize,
        max_num_images: usize,
        max_batch_size: usize,
    ) -> Result<(usize, usize)> {
        let cfg = &cfg.vision_config;

        let img_processor =
            Llama4ImageProcessor::new(Some(cfg.patch_size), Some(cfg.pixel_shuffle_ratio));
        let image = DynamicImage::new(width as u32, height as u32, ColorType::Rgb8);
        let res = img_processor.preprocess(
            vec![image; max_num_images],
            vec![],
            &PreProcessorConfig::default(),
            &Device::Cpu,
            (max_batch_size, max_num_images),
        )?;

        let pixels_batch_size = res.pixel_values.dim(0)?;
        let pixels_max_batch_size = pixels_batch_size * max_batch_size;

        let (image_h, image_w) = (
            res.pixel_values.dim(D::Minus2).unwrap(),
            res.pixel_values.dim(D::Minus1).unwrap(),
        );
        let num_patches_per_chunk = (image_h / img_processor.patch_size)
            * (image_w / img_processor.patch_size)
            / img_processor.downsample_ratio;

        Ok((
            pixels_max_batch_size,
            num_patches_per_chunk * pixels_max_batch_size,
        ))
    }
}
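
// Why dummy processing: Llama 4's processor tiles each image into a
// resolution-dependent number of chunks, so the per-image token count cannot
// be read off the config in closed form. Running the real preprocessor on
// blank RGB images of the worst-case size recovers the exact chunk and patch
// counts that the device mapper needs (hedged reading of the code above).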

impl DeviceMappedModelLoader for VLlama4Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: (height, width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Llama4Config = serde_json::from_str(config)?;

        let (_pixels_batch_size, num_text_image_toks) =
            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;

        let max_seq_len = max_seq_len + num_text_image_toks;

        Ok(max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: (height, width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Llama4Config = serde_json::from_str(config)?;

        let (pixels_batch_size, _num_text_image_toks) =
            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
        let max_seq_len = cfg.vision_config.num_patches();

        Ok((max_batch_size * pixels_batch_size)
            * cfg.vision_config.num_attention_heads
            * max_seq_len
            * max_seq_len)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let tcfg = &cfg.text_config;

        let text_elems = {
            let embed_tokens = tcfg.hidden_size * tcfg.vocab_size / weight_pack_factor;
            let lm_head = if !tcfg.tie_word_embeddings {
                tcfg.hidden_size * tcfg.vocab_size
            } else {
                0
            };
            let norm = tcfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &cfg.vision_config;

            let num_patches = cfg.num_patches();

            let unfold_elems =
                (cfg.num_channels * cfg.patch_size * cfg.patch_size) * cfg.hidden_size;
            let class_embedding_elems = cfg.hidden_size;
            let positional_embedding_vlm_elems = num_patches * cfg.hidden_size;
            let layernorm_pre_elems = cfg.hidden_size;
            let layernorm_post_elems = cfg.hidden_size;

            let pixel_shuffle_elems = cfg.intermediate_size * cfg.projector_input_dim
                / weight_pack_factor
                + cfg.projector_input_dim * cfg.projector_output_dim / weight_pack_factor;

            let encoder_layer = {
                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
                let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let k_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let v_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let o_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;

                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
                    + cfg.intermediate_size;
                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
                    + cfg.hidden_size;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + fc1
                    + fc2
            };

            unfold_elems
                + class_embedding_elems
                + positional_embedding_vlm_elems
                + layernorm_post_elems
                + layernorm_pre_elems
                + pixel_shuffle_elems
                + encoder_layer * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_elems;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let tcfg = &cfg.text_config;

        let mut per_layer_elems = Vec::new();

        for layer_idx in 0..tcfg.num_hidden_layers {
            let input_layernorm = tcfg.hidden_size;
            let post_attention_layernorm = tcfg.hidden_size;

            let size_in = tcfg.hidden_size;
            let size_q = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_attention_heads;
            let size_kv = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let use_moe = tcfg.moe_layers().contains(&layer_idx);
            let moe_block = if use_moe {
                let h_size = tcfg.hidden_size;
                let i_size = tcfg.intermediate_size;
                let gate_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
                let up_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
                let down_proj = tcfg.num_local_experts * i_size * h_size / weight_pack_factor;

                gate_proj + up_proj + down_proj
            } else {
                let h_size = tcfg.hidden_size;
                let i_size = tcfg.intermediate_size_mlp;
                let gate_proj = h_size * i_size / weight_pack_factor;
                let up_proj = h_size * i_size / weight_pack_factor;
                let down_proj = i_size * h_size / weight_pack_factor;

                gate_proj + up_proj + down_proj
            };

            per_layer_elems.push(
                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + moe_block,
            );
        }

        Ok(per_layer_elems
            .into_iter()
            .map(|x| x * dtype.size_in_bytes())
            .collect())
    }
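
    // Unlike the dense loaders above, per-layer sizes differ here: layers in
    // `moe_layers()` scale the expert MLP by `num_local_experts`. Illustrative
    // (hypothetical) numbers: h_size = 5120, i_size = 8192 and 16 experts give
    //     3 * 16 * 5120 * 8192 ≈ 2.0e9 MLP elements for a MoE layer,
    // versus 3 * 5120 * intermediate_size_mlp for a dense layer.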

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}