mistralrs_core/pipeline/loaders/vision_loaders.rs

use std::any::Any;
use std::sync::Arc;
use std::{fmt::Debug, str::FromStr};

use anyhow::Result;
use candle_core::{DType, Device, Tensor, D};
use candle_nn::Conv2dConfig;
use image::{ColorType, DynamicImage};
use itertools::Itertools;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::ShardedVarBuilder;

#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};

use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
use crate::amoe::AnyMoeBaseModelMixin;
use crate::device_map::DeviceMapper;
use crate::layers::Conv3dConfig;
use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
use crate::pipeline::isq::IsqModelLoader;
use crate::pipeline::loaders::AutoDeviceMapParams;
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::{
    EitherCache, IsqModel, Modalities, MultimodalPromptPrefixer, Processor, ProcessorCreator,
    SupportedModality,
};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::vision_models::clip::ClipConfig;
use crate::vision_models::gemma3::config::Gemma3Config;
use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
use crate::vision_models::idefics2_input_processor::Idefics2Processor;
use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
use crate::vision_models::image_processor::ImagePreProcessor;
use crate::vision_models::inputs_processor::Phi4MMProcessor;
use crate::vision_models::llama4::{
    self, Llama4Config, Llama4ImageProcessor, Llama4Model, Llama4Processor,
};
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::llava15::Model as LLaVA;
use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
use crate::vision_models::llava_next::Model as LLaVANext;
use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
use crate::vision_models::phi3_inputs_processor::Phi3Processor;
use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::qwen2_5_vl::{
    Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
};
use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
use crate::vision_models::{minicpmo, phi4};

pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin {
    // pixel_values and pixel_attention_mask only specified for prompt seqs
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>, // pixel attention mask, or image sizes, or anything else
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
    /// For a prompt without images. Requires batch size of 1!
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
}
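
// Example (a sketch, not this crate's exact call site): a text-only decode step
// drives `VisionModel::forward` with `pixel_values: None` and the model's
// default args. `model`, `seq_len`, and `flash_params` are assumed to come from
// the surrounding pipeline.
//
//     let args = model.default_model_specific_args(&input_ids);
//     let logits = model.forward(
//         &input_ids,
//         None,               // no pixel_values for a text-only prompt
//         &[0],               // seqlen offset per sequence
//         vec![(0, seq_len)], // context lens
//         vec![0],            // position ids
//         args,
//         None,               // no paged-attention metadata
//         &flash_params,
//     )?;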

pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> bool;
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_processor(
        &self,
        model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync>;
    fn supports_paged_attention(&self, config: &str) -> bool;
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        // Default is false, specific model must override.
        false
    }
    fn modalities(&self, config: &str) -> Result<Modalities>;
    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}
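
// Example (sketch): the default `get_device_for_tensor` closure keys off the
// `.layers.N.` component of a tensor name. With a hypothetical `loader`,
// `config`, and `mapper`:
//
//     let dev_fn = loader.get_device_for_tensor(config, &*mapper, false)?;
//     // "model.layers.12.self_attn.q_proj.weight" -> DeviceForLoadTensor::Idx(12)
//     // "model.embed_tokens.weight"               -> DeviceForLoadTensor::Base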

#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
/// The architecture to load the vision model as.
pub enum VisionLoaderType {
    #[serde(rename = "phi3v")]
    Phi3V,
    #[serde(rename = "idefics2")]
    Idefics2,
    #[serde(rename = "llava_next")]
    LLaVANext,
    #[serde(rename = "llava")]
    LLaVA,
    #[serde(rename = "vllama")]
    VLlama,
    #[serde(rename = "qwen2vl")]
    Qwen2VL,
    #[serde(rename = "idefics3")]
    Idefics3,
    #[serde(rename = "minicpmo")]
    MiniCpmO,
    #[serde(rename = "phi4mm")]
    Phi4MM,
    #[serde(rename = "qwen2_5vl")]
    Qwen2_5VL,
    #[serde(rename = "gemma3")]
    Gemma3,
    #[serde(rename = "mistral3")]
    Mistral3,
    #[serde(rename = "llama4")]
    Llama4,
}

// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
impl VisionLoaderType {
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "Phi3VForCausalLM" => Ok(Self::Phi3V),
            "Idefics2ForConditionalGeneration" => Ok(Self::Idefics2),
            "LlavaNextForConditionalGeneration" => Ok(Self::LLaVANext),
            "LlavaForConditionalGeneration" => Ok(Self::LLaVA),
            "MllamaForConditionalGeneration" => Ok(Self::VLlama),
            "Qwen2VLForConditionalGeneration" => Ok(Self::Qwen2VL),
            "Idefics3ForConditionalGeneration" => Ok(Self::Idefics3),
            "MiniCPMO" => Ok(Self::MiniCpmO),
            "Phi4MMForCausalLM" => Ok(Self::Phi4MM),
            "Qwen2_5_VLForConditionalGeneration" => Ok(Self::Qwen2_5VL),
            "Gemma3ForConditionalGeneration" | "Gemma3ForCausalLM" => Ok(Self::Gemma3),
            "Mistral3ForConditionalGeneration" => Ok(Self::Mistral3),
            "Llama4ForConditionalGeneration" => Ok(Self::Llama4),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for VisionLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "phi3v" => Ok(Self::Phi3V),
            "idefics2" => Ok(Self::Idefics2),
            "llava_next" => Ok(Self::LLaVANext),
            "llava" => Ok(Self::LLaVA),
            "vllama" => Ok(Self::VLlama),
            "qwen2vl" => Ok(Self::Qwen2VL),
            "idefics3" => Ok(Self::Idefics3),
            "minicpmo" => Ok(Self::MiniCpmO),
            "phi4mm" => Ok(Self::Phi4MM),
            "qwen2_5vl" => Ok(Self::Qwen2_5VL),
            "gemma3" => Ok(Self::Gemma3),
            "mistral3" => Ok(Self::Mistral3),
            "llama4" => Ok(Self::Llama4),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`, `llama4`.")),
        }
    }
}

impl std::fmt::Display for VisionLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            VisionLoaderType::Phi3V => "phi3v",
            VisionLoaderType::Idefics2 => "idefics2",
            VisionLoaderType::LLaVANext => "llava_next",
            VisionLoaderType::LLaVA => "llava",
            VisionLoaderType::VLlama => "vllama",
            VisionLoaderType::Qwen2VL => "qwen2vl",
            VisionLoaderType::Idefics3 => "idefics3",
            VisionLoaderType::MiniCpmO => "minicpmo",
            VisionLoaderType::Phi4MM => "phi4mm",
            VisionLoaderType::Qwen2_5VL => "qwen2_5vl",
            VisionLoaderType::Gemma3 => "gemma3",
            VisionLoaderType::Mistral3 => "mistral3",
            VisionLoaderType::Llama4 => "llama4",
        };
        write!(f, "{name}")
    }
}
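
// Example: the serde rename, `FromStr`, and `Display` use the same strings, so
// a loader type round-trips through its textual form.
//
//     use std::str::FromStr;
//     let tp = VisionLoaderType::from_str("qwen2_5vl").unwrap();
//     assert_eq!(tp, VisionLoaderType::Qwen2_5VL);
//     assert_eq!(tp.to_string(), "qwen2_5vl");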

#[derive(Deserialize)]
struct AutoVisionLoaderConfig {
    architectures: Vec<String>,
}

/// Automatically selects a VisionModelLoader implementation based on the JSON `architectures` field.
pub struct AutoVisionLoader;

impl AutoVisionLoader {
    fn get_loader(config: &str) -> Result<Box<dyn VisionModelLoader>> {
        let auto_cfg: AutoVisionLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one architecture in config");
        }

        let name = &auto_cfg.architectures[0];
        let tp = VisionLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        // Delegate to the concrete loader
        Ok(match tp {
            VisionLoaderType::Phi3V => Box::new(Phi3VLoader),
            VisionLoaderType::Idefics2 => Box::new(Idefics2Loader),
            VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader),
            VisionLoaderType::LLaVA => Box::new(LLaVALoader),
            VisionLoaderType::VLlama => Box::new(VLlamaLoader),
            VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
            VisionLoaderType::Idefics3 => Box::new(Idefics3Loader),
            VisionLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
            VisionLoaderType::Phi4MM => Box::new(Phi4MMLoader),
            VisionLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
            VisionLoaderType::Gemma3 => Box::new(Gemma3Loader),
            VisionLoaderType::Mistral3 => Box::new(Mistral3Loader),
            VisionLoaderType::Llama4 => Box::new(VLlama4Loader),
        })
    }
}
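
// Example (sketch): only the `architectures` field of `config.json` is needed
// to pick a concrete loader. The JSON below is a minimal, hypothetical fragment:
//
//     let config = r#"{"architectures": ["Qwen2VLForConditionalGeneration"]}"#;
//     let loader = AutoVisionLoader::get_loader(config)?; // selects Qwen2VLLoader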

impl VisionModelLoader for AutoVisionLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }

    fn is_gptx(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader get_loader")
            .is_gptx(config)
    }

    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }

    fn get_processor(
        &self,
        model_config: &str,
        proc_cfg: Option<ProcessorConfig>,
        preproc_cfg: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Self::get_loader(model_config)
            .expect("AutoVisionLoader get_loader")
            .get_processor(model_config, proc_cfg, preproc_cfg, max_edge)
    }

    fn supports_paged_attention(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_paged_attention(config)
    }

    fn modalities(&self, config: &str) -> Result<Modalities> {
        Self::get_loader(config)?.modalities(config)
    }

    fn supports_prefix_cacher(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_prefix_cacher(config)
    }

    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .prefixer(config)
    }

    fn get_device_for_tensor(
        &self,
        config: &str,
        mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        Self::get_loader(config)?.get_device_for_tensor(config, mapper, loading_isq)
    }
}

impl IsqModelLoader for AutoVisionLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
}

impl DeviceMappedModelLoader for AutoVisionLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params, prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(config, dtype, weight_pack_factor)
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(config, dtype, weight_pack_factor)
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}
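
// Example: `bias_if!` adds a bias term to an element count only when the layer
// actually has one (`hidden` and `has_bias` here are hypothetical):
//
//     let linear_elems = hidden * hidden + bias_if!(has_bias, hidden);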

fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
    let pre_layer_norm = cfg.hidden_size;
    let final_layer_norm = cfg.hidden_size;

    let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
    let num_positions = num_patches + 1;

    let class_embedding = cfg.hidden_size;

    let position_ids = num_positions;
    let position_embedding = num_positions * cfg.hidden_size;

    let conv2dconfig = Conv2dConfig {
        stride: cfg.patch_size,
        ..Default::default()
    };
    let patch_embedding =
        cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;

    let encoder_layer_elems = {
        let layer_norm1 = cfg.hidden_size;
        let layer_norm2 = cfg.hidden_size;

        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

        layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
    };

    pre_layer_norm
        + final_layer_norm
        + class_embedding
        + position_ids
        + position_embedding
        + patch_embedding
        + cfg.num_hidden_layers * encoder_layer_elems
}
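
// Worked example (assuming CLIP ViT-L/14 at 336px numbers: hidden_size = 1024,
// intermediate_size = 4096, patch_size = 14, image_size = 336, num_channels = 3,
// num_hidden_layers = 24):
//   num_patches     = (336 / 14)^2 = 576, num_positions = 577
//   patch_embedding = 3 * 1024 * 14 * 14 = 602_112
//   encoder layer   = 2 * 1024 (norms) + 4 * (1024 * 1024 + 1024) (q/k/v/o)
//                     + (1024 * 4096 + 4096) + (4096 * 1024 + 1024) (fc1/fc2)
//                     = 12_594_176 elems
//   total           = 24 * 12_594_176 + embeddings/norms ≈ 303M elems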

// ======================== Phi 3 loader

/// [`VisionLoader`] for a Phi 3 Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Phi3VLoader;

pub struct Phi3VPrefixer;

impl MultimodalPromptPrefixer for Phi3VPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        // Image indexing starts at 0.
        format!(
            "{}{prompt}",
            image_indexes
                .into_iter()
                .map(|image_index| format!("<|image_{}|>", image_index + 1))
                .join("")
        )
    }
}
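
// Example: image tokens are 1-indexed in the rendered prompt.
//
//     let p = Phi3VPrefixer.prefix_image(vec![0, 1], "Describe the images.");
//     assert_eq!(p, "<|image_1|><|image_2|>Describe the images.");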

impl VisionModelLoader for Phi3VLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(Phi3::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi3Processor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Phi3VPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Phi3VLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi3VLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        // NOTE: max_num_images is factored into this estimate even though this model effectively supports one image per prompt.
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }
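
    // Worked example (hypothetical sizes): with max_batch_size = 1, 32 attention
    // heads, max_seq_len = 4096, and one image contributing 577 tokens
    // (24 x 24 patches + 1), the dominant activation is the full B x H x L x L
    // attention matrix:
    //   L = 577 + 4096 = 4673
    //   elems = 1 * 32 * 4673 * 4673 ≈ 6.99e8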

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        // NOTE: max_num_images is factored into this estimate even though this model effectively supports one image per prompt.
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = {
                let projection_cls = cfg
                    .embd_layer
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator =
                    cfg.embd_layer.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
                let image_dim_out = cfg.img_processor.image_dim_out;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);

                proj + glb_gn + sub_gn + clip_vit
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }
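
    // Example (sketch): each entry of `layer_sizes_in_bytes` is one decoder
    // layer; a device mapper can take a running prefix sum to decide how many
    // layers fit in a given memory budget. `loader`, `config`, and `budget`
    // are hypothetical:
    //
    //     let sizes = loader.layer_sizes_in_bytes(config, DType::F16, 1)?;
    //     let mut used = 0usize;
    //     let n_gpu_layers = sizes
    //         .iter()
    //         .take_while(|s| {
    //             used += **s;
    //             used <= budget
    //         })
    //         .count();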

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Idefics 2 loader

/// [`VisionLoader`] for an Idefics 2 Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Idefics2Loader;

pub struct Idefics2Prefixer;

impl MultimodalPromptPrefixer for Idefics2Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        // Chat template does it
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(Idefics2::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics2Processor::new(
            processor_config.unwrap(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Idefics2Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Idefics2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}
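
// Example: `immediate_isq_predicates` anchors the full `model.text_model.`
// path. A text weight such as
// `model.text_model.layers.3.self_attn.q_proj.weight` matches both sets, while
// a vision-tower weight like
// `model.vision_model.encoder.layers.3.self_attn.q_proj.weight` (name assumed
// for illustration) matches only the looser `isq_layer_regexes` patterns.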

impl DeviceMappedModelLoader for Idefics2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            // do_image_splitting = true: each image yields 4 crops plus the full image
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let text_elems = {
            let tie_word_embeddings = cfg.tie_word_embeddings;
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let tcfg = &cfg.text_config;
            let vcfg = &cfg.vision_config;
            let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let down_proj = tcfg.intermediate_size * tcfg.hidden_size;

            let perceiver_elems = {
                let tcfg = &cfg.text_config;
                let pcfg = &cfg.perceiver_config;

                let n_latents = pcfg.resampler_n_latents;
                let hidden_size = tcfg.hidden_size;
                let depth = pcfg.resampler_depth;

                let norm = tcfg.hidden_size;
                let latents = n_latents * hidden_size;

                let layer_elems = {
                    let input_latents_norm = hidden_size;
                    let input_context_norm = hidden_size;
                    let post_attn_norm = hidden_size;

                    let num_heads = pcfg.resampler_n_heads;
                    let head_dim = pcfg.resampler_head_dim;
                    let num_key_value_heads = pcfg.num_key_value_heads;

                    let q_proj = hidden_size * num_heads * head_dim;
                    let k_proj = hidden_size * num_key_value_heads * head_dim;
                    let v_proj = hidden_size * num_key_value_heads * head_dim;
                    let o_proj = num_heads * head_dim * hidden_size;

                    let gate_proj = hidden_size * hidden_size * 4;
                    let up_proj = hidden_size * hidden_size * 4;
                    let down_proj = hidden_size * 4 * hidden_size;

                    input_latents_norm
                        + input_context_norm
                        + post_attn_norm
                        + q_proj
                        + k_proj
                        + v_proj
                        + o_proj
                        + gate_proj
                        + up_proj
                        + down_proj
                };

                norm + latents + layer_elems * depth
            };

            gate_proj + up_proj + down_proj + perceiver_elems
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };
            // Count every encoder layer, not just one
            post_layernorm + patch_embedding + position_embedding + layer_elems * cfg.num_hidden_layers
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}
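
// Example (sketch): a rough total-weight estimate combines the non-mapped bytes
// (vision tower, connector, embeddings) with the mapped per-layer text stack.
// `loader` and `config` are hypothetical:
//
//     let fixed = loader.non_mapped_size_in_bytes(config, DType::F16, 1)?;
//     let mapped: usize = loader
//         .layer_sizes_in_bytes(config, DType::F16, 1)?
//         .iter()
//         .sum();
//     let total_weight_bytes = fixed + mapped;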

// ======================== LLaVANext Loader

/// [`VisionLoader`] for a LLaVANext Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct LLaVANextLoader;

pub struct LLaVANextPrefixer;

impl MultimodalPromptPrefixer for LLaVANextPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVANextLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(LLaVANext::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVANextProcessor::new(model_config))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(LLaVANextPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for LLaVANextLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVANextLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== LLaVA Loader

/// [`VisionLoader`] for a LLaVA Vision model.
1344///
1345/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
1346pub struct LLaVALoader;
1347
1348pub struct LLaVAPrefixer;
1349
1350impl MultimodalPromptPrefixer for LLaVAPrefixer {
1351    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1352        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
1353    }
1354}
1355
1356impl VisionModelLoader for LLaVALoader {
1357    fn load(
1358        &self,
1359        config: &str,
1360        vb: ShardedVarBuilder,
1361        normal_loading_metadata: NormalLoadingMetadata,
1362        attention_mechanism: AttentionImplementation,
1363    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1364        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1365        Ok(Box::new(LLaVA::new(
1366            &cfg,
1367            vb,
1368            self.is_gptx(config),
1369            normal_loading_metadata,
1370            attention_mechanism,
1371        )?))
1372    }
1373    fn is_gptx(&self, _config: &str) -> bool {
1374        false
1375    }
1376    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1377        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1378        Ok(Box::new(cfg))
1379    }
1380    fn get_processor(
1381        &self,
1382        model_config: &str,
1383        _processor_config: Option<ProcessorConfig>,
1384        _preprocessor_config: PreProcessorConfig,
1385        _max_edge: Option<u32>,
1386    ) -> Arc<dyn Processor + Send + Sync> {
1387        Arc::new(LLaVAProcessor::new(model_config))
1388    }
1389    fn supports_paged_attention(&self, _config: &str) -> bool {
1390        true
1391    }
1392    fn supports_prefix_cacher(&self, _config: &str) -> bool {
1393        true
1394    }
1395    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1396        Arc::new(LLaVAPrefixer)
1397    }
1398    fn modalities(&self, _config: &str) -> Result<Modalities> {
1399        Ok(Modalities {
1400            input: vec![SupportedModality::Text, SupportedModality::Vision],
1401            output: vec![SupportedModality::Text],
1402        })
1403    }
1404}
1405
1406impl IsqModelLoader for LLaVALoader {
1407    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1408        Ok(vec![
1409            Regex::new(r"lm_head\.(weight|bias)$")?,
1410            // Attention
1411            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1412            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1413            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1414            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1415            // MLP
1416            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1417            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1418            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1419        ])
1420    }
1421    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1422        Ok(vec![
1423            Regex::new(r"lm_head\.(weight|bias)$")?,
1424            // Attention
1425            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1426            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1427            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1428            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1429            // MLP
1430            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1431            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1432            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1433        ])
1434    }
1435}
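
// For reference: a tensor named e.g. "language_model.model.layers.0.self_attn.q_proj.weight"
// (a hypothetical path) is matched by both lists above, since the unanchored
// `layers\.(\d+)...` patterns match anywhere in the name while the immediate
// predicates spell out the full prefix.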
1436
1437impl DeviceMappedModelLoader for LLaVALoader {
1438    fn mapped_max_act_size_elems(
1439        &self,
1440        config: &str,
1441        params: &AutoDeviceMapParams,
1442        _prompt_chunksize: usize,
1443    ) -> Result<usize> {
1444        let AutoDeviceMapParams::Vision {
1445            max_seq_len,
1446            max_batch_size,
1447            max_image_shape: _,
1448            max_num_images,
1449        } = params
1450        else {
1451            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1452        };
1453
1454        let config: LLaVAConfig = serde_json::from_str(config)?;
1455
1456        let img_seq_len =
1457            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1458        let img_seq_len = img_seq_len * max_num_images;
1459
1460        let max_text_attn = {
1461            let cfg = &config.text_config;
1462            // This model injects the vision information directly into the input embeddings
1463            let max_seq_len = img_seq_len + max_seq_len;
1464
1465            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1466        };
1467
1468        Ok(max_text_attn)
1469    }
1470
1471    fn non_mapped_max_act_size_elems(
1472        &self,
1473        config: &str,
1474        params: &AutoDeviceMapParams,
1475    ) -> Result<usize> {
1476        let AutoDeviceMapParams::Vision {
1477            max_seq_len: _,
1478            max_batch_size,
1479            max_image_shape: _,
1480            max_num_images,
1481        } = params
1482        else {
1483            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1484        };
1485
1486        let config: LLaVAConfig = serde_json::from_str(config)?;
1487
1488        let img_seq_len =
1489            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1490
1491        let max_vision_attn = {
1492            (max_batch_size * max_num_images)
1493                * config.vision_config.num_attention_heads
1494                * img_seq_len
1495                * img_seq_len
1496        };
1497
1498        Ok(max_vision_attn)
1499    }
1500
1501    fn non_mapped_size_in_bytes(
1502        &self,
1503        config: &str,
1504        dtype: DType,
1505        weight_pack_factor: usize,
1506    ) -> Result<usize> {
1507        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1508        let text_elems = {
1509            let cfg = &cfg.text_config;
1510            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1511            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1512            let norm = cfg.hidden_size;
1513            embed_tokens + lm_head + norm
1514        };
1515
1516        let image_newline = cfg.text_config.hidden_size;
1517        let mmproj = {
1518            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
1519                + cfg.text_config.hidden_size;
1520            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
1521                + cfg.text_config.hidden_size;
1522
1523            linear_1 + linear_2
1524        };
1525        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());
1526
1527        let elems = text_elems + image_newline + mmproj + vision_tower;
1528        Ok(elems * dtype.size_in_bytes())
1529    }
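
    // Worked example with hypothetical dims (vision hidden 1024, text hidden 4096):
    // the projector above counts linear_1 = 1024*4096 + 4096 and
    // linear_2 = 4096*4096 + 4096 elements (weight plus bias each); the element
    // total is then scaled by dtype.size_in_bytes() (e.g. 2 for f16).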
1530
1531    fn layer_sizes_in_bytes(
1532        &self,
1533        config: &str,
1534        dtype: DType,
1535        weight_pack_factor: usize,
1536    ) -> Result<Vec<usize>> {
1537        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1538        let per_layer_elems = {
1539            let cfg = &cfg.text_config;
1540            let input_layernorm = cfg.hidden_size;
1541            let post_attention_layernorm = cfg.hidden_size;
1542
1543            let size_in = cfg.hidden_size;
1544            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1545            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1546            let q_proj = size_in * size_q / weight_pack_factor;
1547            let k_proj = size_in * size_kv / weight_pack_factor;
1548            let v_proj = size_in * size_kv / weight_pack_factor;
1549            let o_proj = size_q * size_in / weight_pack_factor;
1550
1551            let h_size = cfg.hidden_size;
1552            let i_size = cfg.intermediate_size;
1553            let gate_proj = h_size * i_size / weight_pack_factor;
1554            let up_proj = h_size * i_size / weight_pack_factor;
1555            let down_proj = i_size * h_size / weight_pack_factor;
1556
1557            input_layernorm
1558                + post_attention_layernorm
1559                + q_proj
1560                + k_proj
1561                + v_proj
1562                + o_proj
1563                + gate_proj
1564                + up_proj
1565                + down_proj
1566        };
1567        Ok(vec![
1568            per_layer_elems * dtype.size_in_bytes();
1569            cfg.text_config.num_hidden_layers
1570        ])
1571    }
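
    // Worked example with hypothetical dims (hidden 4096, 32 attention and 32 KV
    // heads, intermediate 11008, pack factor 1): q/k/v/o contribute 4096*4096 each,
    // the MLP 3 * 4096*11008, plus 2*4096 for the layernorms. Every decoder layer
    // is identical, hence the `vec![per_layer_bytes; num_hidden_layers]` above.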
1572
1573    fn num_layers(&self, config: &str) -> Result<usize> {
1574        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1575        Ok(cfg.text_config.num_hidden_layers)
1576    }
1577
1578    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1579        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1580        let cfg = &cfg.text_config;
1581
1582        let cfg = ModelConfigMetadata {
1583            max_seq_len: cfg.max_position_embeddings,
1584            num_layers: cfg.num_hidden_layers,
1585            hidden_size: cfg.hidden_size,
1586            num_kv_heads: cfg.num_key_value_heads,
1587            num_attn_heads: cfg.num_attention_heads,
1588            sliding_window: cfg.sliding_window,
1589            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1590            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1591        };
1592
1593        Ok(Box::new(cfg))
1594    }
1595
1596    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1597        Some(vec![NonMappedSubModel::Vision])
1598    }
1599}
1600
1601// ======================== MLlama Loader
1602
1603/// [`VisionLoader`] for a Llama Vision model.
1604///
1605/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
1606pub struct VLlamaLoader;
1607
1608pub struct VLlamaPrefixer;
1609
1610impl MultimodalPromptPrefixer for VLlamaPrefixer {
1611    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1612        format!("{}{prompt}", "<|image|>".repeat(image_indexes.len()))
1613    }
1614}
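
// A minimal usage sketch mirroring the LLaVA example above; MLlama uses the
// "<|image|>" placeholder instead. The prompt string is illustrative only.
#[cfg(test)]
mod vllama_prefixer_example {
    use super::*;

    #[test]
    fn repeats_image_placeholder_per_image() {
        let prefixer = VLlamaPrefixer;
        assert_eq!(
            prefixer.prefix_image(vec![0], "What is shown?"),
            "<|image|>What is shown?"
        );
    }
}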
1615
1616impl VisionModelLoader for VLlamaLoader {
1617    fn load(
1618        &self,
1619        config: &str,
1620        vb: ShardedVarBuilder,
1621        normal_loading_metadata: NormalLoadingMetadata,
1622        attention_mechanism: AttentionImplementation,
1623    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1624        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1625        Ok(Box::new(MLlamaModel::new(
1626            &cfg,
1627            vb,
1628            self.is_gptx(config),
1629            normal_loading_metadata,
1630            attention_mechanism,
1631        )?))
1632    }
1633    fn is_gptx(&self, _config: &str) -> bool {
1634        true
1635    }
1636    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1637        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1638        Ok(Box::new(cfg))
1639    }
1640    fn get_processor(
1641        &self,
1642        _model_config: &str,
1643        _processor_config: Option<ProcessorConfig>,
1644        _preprocessor_config: PreProcessorConfig,
1645        _max_edge: Option<u32>,
1646    ) -> Arc<dyn Processor + Send + Sync> {
1647        Arc::new(MLlamaProcessor::new())
1648    }
1649    fn supports_paged_attention(&self, _config: &str) -> bool {
1650        false
1651    }
1652    fn supports_prefix_cacher(&self, _config: &str) -> bool {
1653        true
1654    }
1655    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1656        Arc::new(VLlamaPrefixer)
1657    }
1658    fn modalities(&self, _config: &str) -> Result<Modalities> {
1659        Ok(Modalities {
1660            input: vec![SupportedModality::Text, SupportedModality::Vision],
1661            output: vec![SupportedModality::Text],
1662        })
1663    }
1664}
1665
1666impl IsqModelLoader for VLlamaLoader {
1667    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
1668        let config: MLlamaConfig = serde_json::from_str(config)?;
1669        let cross_attn_layers = &config.text_config.cross_attention_layers;
1670        let transformer_layers =
1671            (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
1672        let mut text_regexes = Vec::new();
1673        for layer in transformer_layers {
1674            text_regexes.extend(vec![
1675                // Attention text
1676                Regex::new(&format!(
1677                    r"language_model\.model\.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
1678                ))?,
1679                Regex::new(&format!(
1680                    r"language_model\.model\.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
1681                ))?,
1682                Regex::new(&format!(
1683                    r"language_model\.model\.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
1684                ))?,
1685                Regex::new(&format!(
1686                    r"language_model\.model\.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
1687                ))?,
1688                // MLP text
1689                Regex::new(&format!(
1690                    r"language_model\.model\.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
1691                ))?,
1692                Regex::new(&format!(
1693                    r"language_model\.model\.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
1694                ))?,
1695                Regex::new(&format!(
1696                    r"language_model\.model\.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
1697                ))?,
1698            ]);
1699        }
1700        let vision_regexes = vec![
1701            // Vision attention (transformer)
1702            Regex::new(
1703                r"vision_model\.transformer\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1704            )?,
1705            Regex::new(
1706                r"vision_model\.transformer\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1707            )?,
1708            Regex::new(
1709                r"vision_model\.transformer\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1710            )?,
1711            Regex::new(
1712                r"vision_model\.transformer\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1713            )?,
1714            // Vision attention (global transformer)
1715            Regex::new(
1716                r"vision_model\.global_transformer\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1717            )?,
1718            Regex::new(
1719                r"vision_model\.global_transformer\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1720            )?,
1721            Regex::new(
1722                r"vision_model\.global_transformer\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1723            )?,
1724            Regex::new(
1725                r"vision_model\.global_transformer\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1726            )?,
1727            // MLP vision
1728            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1729            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1730        ];
1731
1732        Ok([text_regexes, vision_regexes].concat())
1733    }
1734    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1735        self.isq_layer_regexes(config)
1736    }
1737}
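
// For intuition (hypothetical values): with cross_attention_layers = [3, 8] and
// 10 text layers, per-layer regexes are emitted only for layers 0-2, 4-7, and 9.
// Cross-attention layers stay unquantized, matching layer_sizes_in_bytes below,
// which resets their pack factor to 1.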
1738
1739impl DeviceMappedModelLoader for VLlamaLoader {
1740    fn mapped_max_act_size_elems(
1741        &self,
1742        config: &str,
1743        params: &AutoDeviceMapParams,
1744        _prompt_chunksize: usize,
1745    ) -> Result<usize> {
1746        let AutoDeviceMapParams::Vision {
1747            max_seq_len,
1748            max_batch_size,
1749            max_image_shape: _,
1750            max_num_images,
1751        } = params
1752        else {
1753            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1754        };
1755
1756        let config: MLlamaConfig = serde_json::from_str(config)?;
1757
1758        let img_seq_len = {
1759            let cfg = &config.vision_config;
1760            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1761            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1762            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1763        };
1764        let img_seq_len = img_seq_len * max_num_images;
1765
1766        let max_cross_text_attn = {
1767            let cfg = &config.text_config;
1768            max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
1769        };
1770
1771        let max_self_text_attn = {
1772            let cfg = &config.text_config;
1773            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1774        };
1775
1776        Ok(max_self_text_attn.max(max_cross_text_attn))
1777    }
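
    // Worked example with hypothetical vision dims (image_size 560, patch_size 14,
    // max_num_tiles 4): num_patches = (560/14)^2 + 1 = 1601, padded by 7 up to the
    // next multiple of 8 (1608), so each image occupies 4 * 1608 = 6432 positions
    // before the max_num_images factor.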
1778
1779    fn non_mapped_max_act_size_elems(
1780        &self,
1781        config: &str,
1782        params: &AutoDeviceMapParams,
1783    ) -> Result<usize> {
1784        let AutoDeviceMapParams::Vision {
1785            max_seq_len: _,
1786            max_batch_size,
1787            max_image_shape: _,
1788            max_num_images,
1789        } = params
1790        else {
1791            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1792        };
1793
1794        let config: MLlamaConfig = serde_json::from_str(config)?;
1795
1796        let img_seq_len = {
1797            let cfg = &config.vision_config;
1798            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1799            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1800            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1801        };
1802        let max_vision_attn = {
1803            let cfg = &config.vision_config;
1804            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
1805        };
1806
1807        Ok(max_vision_attn)
1808    }
1809
1810    fn non_mapped_size_in_bytes(
1811        &self,
1812        config: &str,
1813        dtype: DType,
1814        weight_pack_factor: usize,
1815    ) -> Result<usize> {
1816        let config: MLlamaConfig = serde_json::from_str(config)?;
1817        let text_elems = {
1818            let cfg = &config.text_config;
1819            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1820            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1821            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1822                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1823            } else {
1824                0
1825            };
1826            let norm = cfg.hidden_size;
1827            embed_tokens + lm_head + norm
1828        };
1829
1830        let vision_elems = {
1831            let cfg = &config.vision_config;
1832
1833            let conv_cfg = Conv2dConfig {
1834                stride: cfg.patch_size,
1835                ..Default::default()
1836            };
1837            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
1838                * cfg.patch_size
1839                * cfg.patch_size;
1840
1841            let class_embedding = cfg.hidden_size;
1842
1843            let gated_positional_embedding = {
1844                let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1845                let embedding = num_patches * cfg.hidden_size;
1846                let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
1847                    * (cfg.max_num_tiles * num_patches * cfg.hidden_size);
1848
1849                embedding + tile_embedding
1850            };
1851
1852            let pre_tile_positional_embedding =
1853                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1854            let post_tile_positional_embedding =
1855                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1856
1857            let layernorm_pre = cfg.hidden_size;
1858            let layernorm_post = cfg.hidden_size;
1859
1860            let encoder_layer = {
1861                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1862                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
1863
1864                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1865                let q_proj =
1866                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1867                let k_proj =
1868                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1869                let v_proj =
1870                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1871                let o_proj =
1872                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1873
1874                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
1875                    + cfg.intermediate_size;
1876                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
1877                    + cfg.hidden_size;
1878
1879                input_layernorm
1880                    + post_attention_layernorm
1881                    + q_proj
1882                    + k_proj
1883                    + v_proj
1884                    + o_proj
1885                    + fc1
1886                    + fc2
1887            };
1888
1889            patch_embedding
1890                + class_embedding
1891                + gated_positional_embedding
1892                + pre_tile_positional_embedding
1893                + post_tile_positional_embedding
1894                + layernorm_pre
1895                + layernorm_post
1896                + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
1897        };
1898
1899        let elems = text_elems + vision_elems;
1900        Ok(elems * dtype.size_in_bytes())
1901    }
1902
1903    fn layer_sizes_in_bytes(
1904        &self,
1905        config: &str,
1906        dtype: DType,
1907        weight_pack_factor: usize,
1908    ) -> Result<Vec<usize>> {
1909        let config: MLlamaConfig = serde_json::from_str(config)?;
1910        let cfg = &config.text_config;
1911
1912        let mut layer_sizes = Vec::new();
1913
1914        for i in 0..cfg.num_hidden_layers {
1915            let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
1916                // Cross-attention layers are skipped by ISQ, so no pack factor applies
1917                1
1918            } else {
1919                weight_pack_factor
1920            };
1921
1922            let per_layer_elems = {
1923                let input_layernorm = cfg.hidden_size;
1924                let post_attention_layernorm = cfg.hidden_size;
1925
1926                let size_in = cfg.hidden_size;
1927                let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1928                let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1929                let q_proj = size_in * size_q / weight_pack_factor;
1930                let k_proj = size_in * size_kv / weight_pack_factor;
1931                let v_proj = size_in * size_kv / weight_pack_factor;
1932                let o_proj = size_q * size_in / weight_pack_factor;
1933
1934                let h_size = cfg.hidden_size;
1935                let i_size = cfg.intermediate_size;
1936                let gate_proj = h_size * i_size / weight_pack_factor;
1937                let up_proj = h_size * i_size / weight_pack_factor;
1938                let down_proj = i_size * h_size / weight_pack_factor;
1939
1940                input_layernorm
1941                    + post_attention_layernorm
1942                    + q_proj
1943                    + k_proj
1944                    + v_proj
1945                    + o_proj
1946                    + gate_proj
1947                    + up_proj
1948                    + down_proj
1949            };
1950
1951            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
1952        }
1953
1954        Ok(layer_sizes)
1955    }
1956
1957    fn num_layers(&self, config: &str) -> Result<usize> {
1958        let config: MLlamaConfig = serde_json::from_str(config)?;
1959        Ok(config.text_config.num_hidden_layers)
1960    }
1961
1962    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1963        let cfg: MLlamaConfig = serde_json::from_str(config)?;
1964        let cfg = &cfg.text_config;
1965
1966        let cfg = ModelConfigMetadata {
1967            max_seq_len: cfg.max_position_embeddings,
1968            num_layers: cfg.num_hidden_layers,
1969            hidden_size: cfg.hidden_size,
1970            num_kv_heads: cfg.num_key_value_heads,
1971            num_attn_heads: cfg.num_attention_heads,
1972            sliding_window: None,
1973            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1974            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1975        };
1976
1977        Ok(Box::new(cfg))
1978    }
1979
1980    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1981        Some(vec![NonMappedSubModel::Vision])
1982    }
1983}
1984
1985// ======================== Qwen2VL Loader
1986
1987/// [`VisionLoader`] for a Qwen2-VL model.
1988///
1989/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
1990pub struct Qwen2VLLoader;
1991
1992pub struct Qwen2VLPrefixer;
1993
1994impl MultimodalPromptPrefixer for Qwen2VLPrefixer {
1995    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1996        format!(
1997            "{}{prompt}",
1998            format!(
1999                "{}{}{}",
2000                Qwen2VLProcessor::VISION_START,
2001                Qwen2VLProcessor::IMAGE_PAD,
2002                Qwen2VLProcessor::VISION_END
2003            )
2004            .repeat(image_indexes.len())
2005        )
2006    }
2007}
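
// A minimal usage sketch: each image contributes one VISION_START / IMAGE_PAD /
// VISION_END triple ahead of the prompt. Written against the processor constants
// so it holds for whatever the actual token strings are.
#[cfg(test)]
mod qwen2vl_prefixer_example {
    use super::*;

    #[test]
    fn wraps_image_pad_in_vision_markers() {
        let prefixer = Qwen2VLPrefixer;
        let marker = format!(
            "{}{}{}",
            Qwen2VLProcessor::VISION_START,
            Qwen2VLProcessor::IMAGE_PAD,
            Qwen2VLProcessor::VISION_END
        );
        assert_eq!(prefixer.prefix_image(vec![0], "Hi"), format!("{marker}Hi"));
    }
}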
2008
2009impl VisionModelLoader for Qwen2VLLoader {
2010    fn load(
2011        &self,
2012        config: &str,
2013        vb: ShardedVarBuilder,
2014        normal_loading_metadata: NormalLoadingMetadata,
2015        attention_mechanism: AttentionImplementation,
2016    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2017        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2018        Ok(Box::new(Qwen2VLModel::new(
2019            &cfg,
2020            vb,
2021            self.is_gptx(config),
2022            normal_loading_metadata,
2023            attention_mechanism,
2024        )?))
2025    }
2026    fn is_gptx(&self, _config: &str) -> bool {
2027        true
2028    }
2029    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2030        let config: Qwen2VLConfig = serde_json::from_str(config)?;
2031        Ok(Box::new(config))
2032    }
2033    fn get_processor(
2034        &self,
2035        _model_config: &str,
2036        _processor_config: Option<ProcessorConfig>,
2037        _preprocessor_config: PreProcessorConfig,
2038        max_edge: Option<u32>,
2039    ) -> Arc<dyn Processor + Send + Sync> {
2040        Arc::new(Qwen2VLProcessor::new(max_edge))
2041    }
2042    fn supports_paged_attention(&self, _config: &str) -> bool {
2043        false
2044    }
2045    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2046        Arc::new(Qwen2VLPrefixer)
2047    }
2048    fn modalities(&self, _config: &str) -> Result<Modalities> {
2049        Ok(Modalities {
2050            input: vec![SupportedModality::Text, SupportedModality::Vision],
2051            output: vec![SupportedModality::Text],
2052        })
2053    }
2054}
2055
2056impl IsqModelLoader for Qwen2VLLoader {
2057    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2058        Ok(vec![
2059            Regex::new(r"lm_head\.(weight|bias)$")?,
2060            // Attention
2061            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2062            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2063            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2064            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2065            // MLP
2066            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2067            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2068            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2069        ])
2070    }
2071    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2072        self.isq_layer_regexes(config)
2073    }
2074}
2075
2076impl DeviceMappedModelLoader for Qwen2VLLoader {
2077    fn mapped_max_act_size_elems(
2078        &self,
2079        config: &str,
2080        params: &AutoDeviceMapParams,
2081        _prompt_chunksize: usize,
2082    ) -> Result<usize> {
2083        let AutoDeviceMapParams::Vision {
2084            max_seq_len,
2085            max_batch_size,
2086            max_image_shape,
2087            max_num_images,
2088        } = params
2089        else {
2090            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2091        };
2092
2093        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2094
2095        let img_seq_len = {
2096            let cfg = &cfg.vision_config;
2097            let grid_t = max_num_images / cfg.temporal_patch_size;
2098            let grid_h = max_image_shape.0 / cfg.patch_size;
2099            let grid_w = max_image_shape.1 / cfg.patch_size;
2100            grid_t * grid_h * grid_w
2101        };
2102        let img_seq_len = img_seq_len * max_num_images;
2103
2104        let max_text_attn = {
2105            // This model injects the vision information directly into the input embeddings
2106            let max_seq_len = img_seq_len + max_seq_len;
2107            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
2108        };
2109
2110        Ok(max_text_attn)
2111    }
2112
2113    fn non_mapped_max_act_size_elems(
2114        &self,
2115        config: &str,
2116        params: &AutoDeviceMapParams,
2117    ) -> Result<usize> {
2118        let AutoDeviceMapParams::Vision {
2119            max_seq_len: _,
2120            max_batch_size,
2121            max_image_shape,
2122            max_num_images,
2123        } = params
2124        else {
2125            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2126        };
2127
2128        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2129
2130        let img_seq_len = {
2131            let cfg = &cfg.vision_config;
2132            let grid_t = max_num_images / cfg.temporal_patch_size;
2133            let grid_h = max_image_shape.0 / cfg.patch_size;
2134            let grid_w = max_image_shape.1 / cfg.patch_size;
2135            grid_t * grid_h * grid_w
2136        };
2137
2138        let max_vision_attn = {
2139            let cfg = &cfg.vision_config;
2140            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
2141        };
2142
2143        Ok(max_vision_attn)
2144    }
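
    // Worked example with hypothetical values (patch_size 14, temporal_patch_size 2,
    // a 1120x1120 max image, 2 images): grid_h = grid_w = 1120/14 = 80 and
    // grid_t = 2/2 = 1, so img_seq_len = 1 * 80 * 80 = 6400 positions per image in
    // the vision attention bound above.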
2145
2146    fn non_mapped_size_in_bytes(
2147        &self,
2148        config: &str,
2149        dtype: DType,
2150        weight_pack_factor: usize,
2151    ) -> Result<usize> {
2152        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2153        let text_elems = {
2154            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2155            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2156            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2157                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2158            } else {
2159                0
2160            };
2161            let norm = cfg.hidden_size;
2162            embed_tokens + lm_head + norm
2163        };
2164
2165        let patch_merger = {
2166            let cfg = &cfg.vision_config;
2167            let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);
2168
2169            let mlp0 = hidden_size * hidden_size + hidden_size;
2170            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
2171
2172            let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2173
2174            mlp0 + mlp2 + ln_q
2175        };
2176
2177        let patch_embed = {
2178            let cfg = &cfg.vision_config;
2179            let conv_cfg = Conv3dConfig {
2180                stride: cfg.patch_size,
2181                ..Default::default()
2182            };
2183            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
2184            cfg.in_channels * cfg.embed_dim / conv_cfg.groups
2185                * kernel_sizes[0]
2186                * kernel_sizes[1]
2187                * kernel_sizes[2]
2188        };
2189
2190        let encoder_layer = {
2191            let cfg = &cfg.vision_config;
2192            let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2193            let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2194
2195            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2196            let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
2197            let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
2198            let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;
2199
2200            let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
2201            let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;
2202
2203            norm1 + norm2 + fc1 + fc2 + qkv + out
2204        };
2205
2206        let elems =
2207            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
2208
2209        Ok(elems * dtype.size_in_bytes())
2210    }
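
    // For the patch embedding above: the Conv3d weight holds
    // in_channels * embed_dim * t * h * w elements with groups = 1 by default,
    // e.g. hypothetically 3 * 1280 * 2 * 14 * 14. The merger and per-block terms
    // are counted as weight plus bias, matching the checkpoint tensors.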
2211
2212    fn layer_sizes_in_bytes(
2213        &self,
2214        config: &str,
2215        dtype: DType,
2216        weight_pack_factor: usize,
2217    ) -> Result<Vec<usize>> {
2218        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2219        let per_layer_elems = {
2220            let input_layernorm = cfg.hidden_size;
2221            let post_attention_layernorm = cfg.hidden_size;
2222
2223            let size_in = cfg.hidden_size;
2224            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2225            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2226            let q_proj = size_in * size_q / weight_pack_factor + size_q;
2227            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
2228            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
2229            let o_proj = size_q * size_in / weight_pack_factor;
2230
2231            let h_size = cfg.hidden_size;
2232            let i_size = cfg.intermediate_size;
2233            let gate_proj = h_size * i_size / weight_pack_factor;
2234            let up_proj = h_size * i_size / weight_pack_factor;
2235            let down_proj = i_size * h_size / weight_pack_factor;
2236
2237            input_layernorm
2238                + post_attention_layernorm
2239                + q_proj
2240                + k_proj
2241                + v_proj
2242                + o_proj
2243                + gate_proj
2244                + up_proj
2245                + down_proj
2246        };
2247        Ok(vec![
2248            per_layer_elems * dtype.size_in_bytes();
2249            cfg.num_hidden_layers
2250        ])
2251    }
2252
2253    fn num_layers(&self, config: &str) -> Result<usize> {
2254        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2255        Ok(cfg.num_hidden_layers)
2256    }
2257
2258    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2259        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2260
2261        let cfg = ModelConfigMetadata {
2262            max_seq_len: cfg.max_position_embeddings,
2263            num_layers: cfg.num_hidden_layers,
2264            hidden_size: cfg.hidden_size,
2265            num_kv_heads: cfg.num_key_value_heads,
2266            num_attn_heads: cfg.num_attention_heads,
2267            sliding_window: cfg.sliding_window,
2268            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2269            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2270        };
2271
2272        Ok(Box::new(cfg))
2273    }
2274
2275    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2276        Some(vec![NonMappedSubModel::Vision])
2277    }
2278}
2279
2280// ======================== Idefics 3 Loader
2281
2282/// [`VisionLoader`] for an Idefics 3 Vision model.
2283///
2284/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2285pub struct Idefics3Loader;
2286
2287pub struct Idefics3Prefixer;
2288
2289impl MultimodalPromptPrefixer for Idefics3Prefixer {
2290    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
2291        // The chat template inserts the image placeholders, so no prefixing is needed here
2292        prompt.to_string()
2293    }
2294}
2295
2296impl VisionModelLoader for Idefics3Loader {
2297    fn load(
2298        &self,
2299        config: &str,
2300        vb: ShardedVarBuilder,
2301        normal_loading_metadata: NormalLoadingMetadata,
2302        attention_mechanism: AttentionImplementation,
2303    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2304        let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2305        Ok(Box::new(Idefics3Model::new(
2306            &cfg,
2307            vb,
2308            self.is_gptx(config),
2309            normal_loading_metadata,
2310            attention_mechanism,
2311        )?))
2312    }
2313    fn is_gptx(&self, _config: &str) -> bool {
2314        true
2315    }
2316    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2317        let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2318        Ok(Box::new(cfg))
2319    }
2320    fn get_processor(
2321        &self,
2322        _model_config: &str,
2323        processor_config: Option<ProcessorConfig>,
2324        preprocessor_config: PreProcessorConfig,
2325        max_edge: Option<u32>,
2326    ) -> Arc<dyn Processor + Send + Sync> {
2327        Arc::new(Idefics3Processor::new(
2328            processor_config.unwrap_or_default(),
2329            preprocessor_config,
2330            max_edge,
2331        ))
2332    }
2333    fn supports_paged_attention(&self, _config: &str) -> bool {
2334        true
2335    }
2336    fn supports_prefix_cacher(&self, _config: &str) -> bool {
2337        true
2338    }
2339    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2340        Arc::new(Idefics3Prefixer)
2341    }
2342    fn modalities(&self, _config: &str) -> Result<Modalities> {
2343        Ok(Modalities {
2344            input: vec![SupportedModality::Text, SupportedModality::Vision],
2345            output: vec![SupportedModality::Text],
2346        })
2347    }
2348}
2349
2350impl IsqModelLoader for Idefics3Loader {
2351    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2352        Ok(vec![
2353            Regex::new(r"lm_head\.(weight|bias)$")?,
2354            // Attention
2355            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2356            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2357            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2358            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2359            // MLP
2360            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2361            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2362            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2363        ])
2364    }
2365    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
2366        Ok(vec![
2367            Regex::new(r"lm_head\.(weight|bias)$")?,
2368            // Attention
2369            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2370            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2371            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2372            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2373            // MLP
2374            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2375            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2376            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2377            // // Attention (vision)
2378            // Regex::new(
2379            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2380            // )?,
2381            // Regex::new(
2382            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
2383            // )?,
2384            // Regex::new(
2385            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
2386            // )?,
2387            // Regex::new(
2388            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.(weight|bias)$",
2389            // )?,
2390            // // MLP (vision)
2391            // Regex::new(r"model\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2392            // Regex::new(r"model\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
2393        ])
2394    }
2395}
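
// For reference: a tensor named e.g. "model.text_model.layers.0.mlp.gate_proj.weight"
// (a hypothetical path) satisfies both lists above; the vision-tower patterns are
// commented out, so the vision encoder is never matched for ISQ here.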
2396
2397impl DeviceMappedModelLoader for Idefics3Loader {
2398    fn mapped_max_act_size_elems(
2399        &self,
2400        config: &str,
2401        params: &AutoDeviceMapParams,
2402        _prompt_chunksize: usize,
2403    ) -> Result<usize> {
2404        let AutoDeviceMapParams::Vision {
2405            max_seq_len,
2406            max_batch_size,
2407            max_image_shape: _,
2408            max_num_images,
2409        } = params
2410        else {
2411            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2412        };
2413
2414        let cfg: Idefics3Config = serde_json::from_str(config)?;
2415
2416        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2417        let img_seq_len = (num_patches + 1) * max_num_images;
2418
2419        let max_text_attn = {
2420            // This model injects the vision information directly into the input embeddings
2421            let max_seq_len = img_seq_len + max_seq_len;
2422            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2423        };
2424
2425        Ok(max_text_attn)
2426    }
2427
2428    fn non_mapped_max_act_size_elems(
2429        &self,
2430        config: &str,
2431        params: &AutoDeviceMapParams,
2432    ) -> Result<usize> {
2433        let AutoDeviceMapParams::Vision {
2434            max_seq_len: _,
2435            max_batch_size,
2436            max_image_shape: _,
2437            max_num_images,
2438        } = params
2439        else {
2440            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2441        };
2442
2443        let cfg: Idefics3Config = serde_json::from_str(config)?;
2444
2445        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2446        let img_seq_len = num_patches + 1;
2447
2448        let max_vision_attn = {
2449            // do_image_splitting = true: assume up to 4 sub-images plus the global view per image
2450            let images_factor = 5;
2451
2452            (max_batch_size * images_factor * max_num_images)
2453                * cfg.vision_config.num_attention_heads
2454                * img_seq_len
2455                * img_seq_len
2456        };
2457
2458        Ok(max_vision_attn)
2459    }
2460
2461    fn non_mapped_size_in_bytes(
2462        &self,
2463        config: &str,
2464        dtype: DType,
2465        weight_pack_factor: usize,
2466    ) -> Result<usize> {
2467        let cfg: Idefics3Config = serde_json::from_str(config)?;
2468        let text_elems = {
2469            let cfg = &cfg.text_config;
2470
2471            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2472            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2473            let norm = cfg.hidden_size;
2474            embed_tokens + lm_head + norm
2475        };
2476
2477        let connector_elems = {
2478            let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
2479            let out_dim = cfg.text_config.hidden_size;
2480
2481            in_dim * out_dim
2482        };
2483
2484        let vision_transformer = {
2485            let cfg = &cfg.vision_config;
2486
2487            let post_layernorm = cfg.hidden_size;
2488
2489            let conv_config = Conv2dConfig {
2490                stride: cfg.patch_size,
2491                ..Default::default()
2492            };
2493            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2494                * cfg.patch_size
2495                * cfg.patch_size;
2496
2497            let num_patches_per_side = cfg.image_size / cfg.patch_size;
2498            let num_patches = num_patches_per_side.pow(2);
2499            let position_embedding = num_patches * cfg.hidden_size;
2500
2501            let layer_elems = {
2502                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2503                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2504
2505                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2506                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2507
2508                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2509                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2510                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2511                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2512
2513                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2514            };
2515
2516            post_layernorm
2517                + patch_embedding
2518                + position_embedding
2519                + layer_elems * cfg.num_hidden_layers
2520        };
2521
2522        let elems = text_elems + connector_elems + vision_transformer;
2523
2524        Ok(elems * dtype.size_in_bytes())
2525    }
2526
2527    fn layer_sizes_in_bytes(
2528        &self,
2529        config: &str,
2530        dtype: DType,
2531        weight_pack_factor: usize,
2532    ) -> Result<Vec<usize>> {
2533        let cfg: Idefics3Config = serde_json::from_str(config)?;
2534        let cfg = cfg.text_config;
2535        let per_layer_elems = {
2536            let input_layernorm = cfg.hidden_size;
2537            let post_attention_layernorm = cfg.hidden_size;
2538
2539            let size_in = cfg.hidden_size;
2540            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2541            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2542            let q_proj = size_in * size_q / weight_pack_factor;
2543            let k_proj = size_in * size_kv / weight_pack_factor;
2544            let v_proj = size_in * size_kv / weight_pack_factor;
2545            let o_proj = size_q * size_in / weight_pack_factor;
2546
2547            let h_size = cfg.hidden_size;
2548            let i_size = cfg.intermediate_size;
2549            let gate_proj = h_size * i_size / weight_pack_factor;
2550            let up_proj = h_size * i_size / weight_pack_factor;
2551            let down_proj = i_size * h_size / weight_pack_factor;
2552
2553            input_layernorm
2554                + post_attention_layernorm
2555                + q_proj
2556                + k_proj
2557                + v_proj
2558                + o_proj
2559                + gate_proj
2560                + up_proj
2561                + down_proj
2562        };
2563        Ok(vec![
2564            per_layer_elems * dtype.size_in_bytes();
2565            cfg.num_hidden_layers
2566        ])
2567    }
2568
2569    fn num_layers(&self, config: &str) -> Result<usize> {
2570        let cfg: Idefics3Config = serde_json::from_str(config)?;
2571        Ok(cfg.text_config.num_hidden_layers)
2572    }
2573    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2574        let cfg: Idefics3Config = serde_json::from_str(config)?;
2575        let cfg = &cfg.text_config;
2576
2577        let cfg = ModelConfigMetadata {
2578            max_seq_len: cfg.max_position_embeddings,
2579            num_layers: cfg.num_hidden_layers,
2580            hidden_size: cfg.hidden_size,
2581            num_kv_heads: cfg.num_key_value_heads,
2582            num_attn_heads: cfg.num_attention_heads,
2583            sliding_window: None,
2584            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2585            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2586        };
2587
2588        Ok(Box::new(cfg))
2589    }
2590
2591    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2592        Some(vec![NonMappedSubModel::Vision])
2593    }
2594}
2595
2596// ======================== MiniCpm-O Loader
2597
2598/// [`VisionLoader`] for a MiniCpm-O model.
2599///
2600/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2601pub struct MiniCpmOLoader;
2602
2603pub struct MiniCpmOPrefixer;
2604
2605impl MultimodalPromptPrefixer for MiniCpmOPrefixer {
2606    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2607        format!(
2608            "{}{prompt}",
2609            "(<image>./</image>)".repeat(image_indexes.len())
2610        )
2611    }
2612}
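
// A minimal usage sketch: MiniCpm-O wraps each image placeholder as
// "(<image>./</image>)". The prompt string is illustrative only.
#[cfg(test)]
mod minicpmo_prefixer_example {
    use super::*;

    #[test]
    fn repeats_wrapped_placeholder_per_image() {
        let prefixer = MiniCpmOPrefixer;
        assert_eq!(
            prefixer.prefix_image(vec![0], "Caption this"),
            "(<image>./</image>)Caption this"
        );
    }
}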
2613
2614impl VisionModelLoader for MiniCpmOLoader {
2615    fn load(
2616        &self,
2617        config: &str,
2618        vb: ShardedVarBuilder,
2619        normal_loading_metadata: NormalLoadingMetadata,
2620        attention_mechanism: AttentionImplementation,
2621    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2622        let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2623        Ok(Box::new(MiniCpmOModel::new(
2624            &cfg,
2625            vb,
2626            self.is_gptx(config),
2627            normal_loading_metadata,
2628            attention_mechanism,
2629        )?))
2630    }
2631    fn is_gptx(&self, _config: &str) -> bool {
2632        true
2633    }
2634    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2635        let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2636        Ok(Box::new(cfg))
2637    }
2638    fn get_processor(
2639        &self,
2640        _model_config: &str,
2641        processor_config: Option<ProcessorConfig>,
2642        preprocessor_config: PreProcessorConfig,
2643        max_edge: Option<u32>,
2644    ) -> Arc<dyn Processor + Send + Sync> {
2645        Arc::new(MiniCpmOProcessor::new(
2646            processor_config.unwrap_or_default(),
2647            preprocessor_config,
2648            max_edge,
2649        ))
2650    }
2651    fn supports_paged_attention(&self, _config: &str) -> bool {
2652        true
2653    }
2654    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2655        Arc::new(MiniCpmOPrefixer)
2656    }
2657    fn modalities(&self, _config: &str) -> Result<Modalities> {
2658        Ok(Modalities {
2659            input: vec![SupportedModality::Text, SupportedModality::Vision],
2660            output: vec![SupportedModality::Text],
2661        })
2662    }
2663}
2664
2665impl IsqModelLoader for MiniCpmOLoader {
2666    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2667        Ok(vec![
2668            Regex::new(r"llm\.lm_head\.(weight|bias)$")?,
2669            // Attention
2670            Regex::new(r"llm\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2671            Regex::new(r"llm\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2672            Regex::new(r"llm\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2673            Regex::new(r"llm\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2674            // MLP
2675            Regex::new(r"llm\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2676            Regex::new(r"llm\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2677            Regex::new(r"llm\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2678        ])
2679    }
2680    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2681        self.isq_layer_regexes(config)
2682    }
2683}
2684
2685impl DeviceMappedModelLoader for MiniCpmOLoader {
2686    fn mapped_max_act_size_elems(
2687        &self,
2688        config: &str,
2689        params: &AutoDeviceMapParams,
2690        _prompt_chunksize: usize,
2691    ) -> Result<usize> {
2692        let AutoDeviceMapParams::Vision {
2693            max_seq_len,
2694            max_batch_size,
2695            max_image_shape: _,
2696            max_num_images,
2697        } = params
2698        else {
2699            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2700        };
2701
2702        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2703
2704        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2705        let img_seq_len = (num_patches + 1) * max_num_images;
2706
2707        let max_text_attn = {
2708            // This model injects the vision information directly into the input embeddings
2709            let max_seq_len = img_seq_len + max_seq_len;
2710            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2711        };
2712
2713        Ok(max_text_attn)
2714    }
2715
2716    fn non_mapped_max_act_size_elems(
2717        &self,
2718        config: &str,
2719        params: &AutoDeviceMapParams,
2720    ) -> Result<usize> {
2721        let AutoDeviceMapParams::Vision {
2722            max_seq_len: _,
2723            max_batch_size,
2724            max_image_shape: _,
2725            max_num_images,
2726        } = params
2727        else {
2728            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2729        };
2730
2731        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2732
2733        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2734        let img_seq_len = num_patches + 1;
2735
2736        let max_vision_attn = {
2737            // do_image_splitting = true: assume up to 4 sub-images plus the global view per image
2738            let images_factor = 5;
2739
2740            (max_batch_size * images_factor * max_num_images)
2741                * cfg.vision_config.num_attention_heads
2742                * img_seq_len
2743                * img_seq_len
2744        };
2745
2746        Ok(max_vision_attn)
2747    }
2748
2749    fn non_mapped_size_in_bytes(
2750        &self,
2751        config: &str,
2752        dtype: DType,
2753        weight_pack_factor: usize,
2754    ) -> Result<usize> {
2755        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2756        let text_elems = {
2757            let cfg = &cfg.text_config;
2758
2759            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2760            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2761            let norm = cfg.hidden_size;
2762            embed_tokens + lm_head + norm
2763        };
2764
2765        let vision_transformer = {
2766            let cfg = &cfg.vision_config;
2767
2768            let post_layernorm = cfg.hidden_size;
2769
2770            let conv_config = Conv2dConfig {
2771                stride: cfg.patch_size,
2772                ..Default::default()
2773            };
2774            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2775                * cfg.patch_size
2776                * cfg.patch_size;
2777
2778            let num_patches_per_side = cfg.image_size / cfg.patch_size;
2779            let num_patches = num_patches_per_side.pow(2);
2780            let position_embedding = num_patches * cfg.hidden_size;
2781
2782            let layer_elems = {
2783                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2784                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2785
2786                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2787                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2788
2789                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2790                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2791                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2792                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2793
2794                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2795            };
2796
2797            post_layernorm
2798                + patch_embedding
2799                + position_embedding
2800                + layer_elems * cfg.num_hidden_layers
2801        };
2802
2803        let elems = text_elems + vision_transformer;
2804
2805        Ok(elems * dtype.size_in_bytes())
2806    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

// ======================== Phi 4MM loader

/// [`VisionLoader`] for a Phi 4MM Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Phi4MMLoader;

pub struct Phi4MMPrefixer;

impl MultimodalPromptPrefixer for Phi4MMPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        // Image indexing starts at 0, but the <|image_N|> tags are 1-based.

        format!(
            "{}{prompt}",
            image_indexes
                .into_iter()
                .map(|image_index| format!("<|image_{}|>", image_index + 1))
                .join("")
        )
    }
    fn prefix_audio(&self, audio_indexes: Vec<usize>, prompt: &str) -> String {
        // Audio indexing starts at 0, but the <|audio_N|> tags are 1-based.

        format!(
            "{}{prompt}",
            audio_indexes
                .into_iter()
                .map(|audio_index| format!("<|audio_{}|>", audio_index + 1))
                .join("")
        )
    }
}
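
// As a hedged example of the prefixer above: for image_indexes = [0, 1] the
// prompt "Describe these." becomes "<|image_1|><|image_2|>Describe these.",
// i.e. the zero-based indexes map onto 1-based <|image_N|> tags.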

impl VisionModelLoader for Phi4MMLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
        Ok(Box::new(Phi4MMModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Phi4MMPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![
                SupportedModality::Text,
                SupportedModality::Vision,
                SupportedModality::Audio,
            ],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Phi4MMLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi4MMLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        // NOTE: max_image_shape is ignored here; max_num_images (in practice only one for this model) still scales the image sequence length below.
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi4MMConfig = serde_json::from_str(config)?;

        let vcfg = &PHI4_MM_VISION_CFG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }
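
    // A hedged worked example of the bound above (hypothetical numbers, not
    // read from PHI4_MM_VISION_CFG): with image_size = 448 and patch_size = 14,
    // one image contributes img_seq_len = 32^2 + 1 = 1025 tokens. For
    // max_batch_size = 1, 32 attention heads and max_seq_len = 4096, the bound
    // is 32 * (1025 + 4096)^2 ≈ 8.4e8 elements, roughly 1.6 GiB of f16
    // attention scores.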

    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let vcfg = &PHI4_MM_VISION_CFG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_batch_size = max_batch_size
            * (max_image_shape
                .0
                .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
                * max_image_shape
                    .1
                    .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
                + 1);

        let max_vision_attn = (max_batch_size * max_num_images)
            * vcfg.num_attention_heads
            * img_seq_len
            * img_seq_len;
        let max_qkv = 3
            * (max_batch_size
                * vcfg.num_attention_heads
                * img_seq_len
                * (vcfg.hidden_size / vcfg.num_attention_heads));

        Ok(max_vision_attn + max_qkv)
    }
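
    // A hedged sketch of the dynamic-HD scaling above: if DYHD_BASE_RESOLUTION
    // were 448 (an illustrative assumption, not the constant's actual value),
    // a 1344x896 image would tile into ceil(1344/448) * ceil(896/448) = 3 * 2
    // = 6 crops plus one global view, so each image multiplies the effective
    // vision batch by 7.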

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
                let projection_cls = img_embed
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
                let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let vision_transformer = {
                    let cfg = &PHI4_MM_VISION_CFG;

                    let post_layernorm = cfg.hidden_size;

                    let conv_config = Conv2dConfig {
                        stride: cfg.patch_size,
                        ..Default::default()
                    };
                    let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                        * cfg.patch_size
                        * cfg.patch_size;

                    let num_patches_per_side = cfg.image_size / cfg.patch_size;
                    let num_patches = num_patches_per_side.pow(2);
                    let position_embedding = num_patches * cfg.hidden_size;

                    let layer_elems = {
                        let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                        let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                        layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
                    };

                    post_layernorm
                        + patch_embedding
                        + position_embedding
                        + layer_elems * cfg.num_hidden_layers
                };

                proj + glb_gn + sub_gn + vision_transformer
            } else {
                0
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }
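
    // A hedged example of how weight_pack_factor enters the math above
    // (hypothetical dimensions): with hidden_size = 3072, 32 query heads,
    // 8 KV heads and head_dim = 96, op_size = 32 * 96 + 2 * 8 * 96 = 4608 and
    // the fused qkv_proj holds 3072 * 4608 = 14_155_776 logical weights; a
    // hypothetical pack factor of 8 reduces that to 1_769_472 stored elements
    // before the dtype size is applied.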

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads(),
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Qwen2_5VL Loader

/// [`VisionLoader`] for a Qwen2_5VL model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Qwen2_5VLLoader;

pub struct Qwen2_5VLPrefixer;

impl MultimodalPromptPrefixer for Qwen2_5VLPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            format!(
                "{}{}{}",
                Qwen2_5VLProcessor::VISION_START,
                Qwen2_5VLProcessor::IMAGE_PAD,
                Qwen2_5VLProcessor::VISION_END
            )
            .repeat(image_indexes.len())
        )
    }
}
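
// Unlike the Phi 4MM prefixer, the prefixer above does not number its images:
// every image contributes the identical VISION_START/IMAGE_PAD/VISION_END
// triple, so two images prefix the prompt with that triple twice and ordering
// alone associates placeholders with images.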

impl VisionModelLoader for Qwen2_5VLLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen2_5VLModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen2_5VLProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        false
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Qwen2_5VLPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Qwen2_5VLLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Qwen2_5VLLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }
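
    // A hedged worked example of the grid arithmetic above (hypothetical
    // values, not from a real Qwen2.5-VL config): with patch_size = 14,
    // temporal_patch_size = 2, max_num_images = 2 and a 1008x1008
    // max_image_shape, grid_t = 2 / 2 = 1 and grid_h = grid_w = 1008 / 14 = 72,
    // so img_seq_len = 1 * 72 * 72 = 5184, and 10_368 after the multiplication
    // by max_num_images.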

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        let text_elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;

            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Gemma 3 Loader

/// [`VisionLoader`] for a Gemma 3 model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Gemma3Loader;

pub struct Gemma3Prefixer;

impl MultimodalPromptPrefixer for Gemma3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Gemma3Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;
        Ok(Box::new(Gemma3Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Gemma3Config = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        let config: Gemma3Config = serde_json::from_str(config).unwrap();
        // Handle the Gemma 3 1b case here
        Arc::new(Gemma3Processor::new(
            processor_config.unwrap_or_default(),
            matches!(config, Gemma3Config::WithVision { .. }),
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Gemma3Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Gemma3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Gemma3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3Config = serde_json::from_str(config)?;

        match cfg {
            Gemma3Config::Text(text_config) => Ok(max_batch_size
                * text_config.num_attention_heads
                * prompt_chunksize
                * prompt_chunksize),
            Gemma3Config::WithVision {
                text_config,
                vision_config,
                ..
            } => {
                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
                let img_seq_len = (num_patches + 1) * max_num_images;

                let max_text_attn = {
                    // This model injects the vision information directly into the input embeddings
                    let max_seq_len = img_seq_len + *max_seq_len;
                    max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
                };
                Ok(max_text_attn)
            }
        }
    }
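
    // A hedged reading of the two branches above: text-only Gemma 3 configs
    // (e.g. the 1b model) bound activations by the prompt chunk alone, while
    // vision configs add (num_patches + 1) image tokens per image. With
    // hypothetical values image_size = 896 and patch_size = 14, that is
    // (896 / 14)^2 + 1 = 4097 extra tokens per image in the attention bound.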

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3Config = serde_json::from_str(config)?;

        match cfg {
            Gemma3Config::WithVision { vision_config, .. } => {
                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
                let img_seq_len = num_patches + 1;

                let max_vision_attn = {
                    (max_batch_size * max_num_images)
                        * vision_config.num_attention_heads
                        * img_seq_len
                        * img_seq_len
                };

                Ok(max_vision_attn)
            }
            Gemma3Config::Text(_) => Ok(0),
        }
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let text_elems = {
            let cfg = match &cfg {
                Gemma3Config::Text(cfg) => cfg,
                Gemma3Config::WithVision { text_config, .. } => text_config,
            };
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_transformer = if let Gemma3Config::WithVision {
            vision_config: cfg, ..
        } = &cfg
        {
            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        } else {
            0
        };

        let elems = text_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let txt_cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };
        let per_layer_elems = {
            let cfg = txt_cfg;

            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            txt_cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let txt_cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };

        Ok(txt_cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None, // None to be more forgiving; some configs do not set one
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Mistral 3 Loader

/// [`VisionLoader`] for a Mistral 3 model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Mistral3Loader;

pub struct Mistral3Prefixer;

impl MultimodalPromptPrefixer for Mistral3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Mistral3Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
        Ok(Box::new(Mistral3Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Mistral3Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Mistral3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
impl DeviceMappedModelLoader for Mistral3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let vcfg = &cfg.vision_config;
        let tcfg = &cfg.text_config;

        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: (mut height, mut width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let img_seq_len = {
            // Reshaping algorithm

            // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/preprocessor_config.json#L29
            let (max_height, max_width) = (1540, 1540);
            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
            if ratio > 1. {
                height = (height as f64 / ratio).floor() as usize;
                width = (width as f64 / ratio).floor() as usize;
            }

            let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
            let num_width_tokens = (width - 1) / vcfg.patch_size + 1;

            height = num_height_tokens * vcfg.patch_size;
            width = num_width_tokens * vcfg.patch_size;

            let num_height_tokens = height / vcfg.patch_size;
            let num_width_tokens = width / vcfg.patch_size;

            (num_width_tokens + 1) * num_height_tokens
        };

        // This model injects the vision information directly into the input embeddings
        let max_seq_len = img_seq_len * max_num_images + *max_seq_len;
        Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
    }
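
    // A hedged trace of the reshape above (hypothetical patch_size = 14): a
    // 2000x1000 input has ratio = 2000 / 1540 ≈ 1.299 and is scaled down to
    // 1540x770, which snaps onto a 110 x 55 grid of 14-pixel patches, so
    // img_seq_len = (55 + 1) * 110 = 6160; the +1 per row presumably accounts
    // for an image-break token.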

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.vision_config;

        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: (mut height, mut width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let img_seq_len = {
            // Reshaping algorithm

            // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/preprocessor_config.json#L29
            let (max_height, max_width) = (1540, 1540);
            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
            if ratio > 1. {
                height = (height as f64 / ratio).floor() as usize;
                width = (width as f64 / ratio).floor() as usize;
            }

            let num_height_tokens = (height - 1) / cfg.patch_size + 1;
            let num_width_tokens = (width - 1) / cfg.patch_size + 1;

            height = num_height_tokens * cfg.patch_size;
            width = num_width_tokens * cfg.patch_size;

            let num_height_tokens = height / cfg.patch_size;
            let num_width_tokens = width / cfg.patch_size;

            (num_width_tokens + 1) * num_height_tokens
        };

        Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;

        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &cfg.vision_config;

            let patch_embed = {
                let conv_cfg = Conv2dConfig {
                    stride: cfg.patch_size,
                    ..Default::default()
                };
                // Conv2d weight: out_channels x (in_channels / groups) x patch_size x patch_size
                cfg.num_channels * cfg.hidden_size / conv_cfg.groups
                    * cfg.patch_size
                    * cfg.patch_size
            };
            let ln_pre = cfg.hidden_size;
            let vision_layer = {
                let attn_norm = cfg.hidden_size;
                let ffn_norm = cfg.hidden_size;

                let gate = cfg.hidden_size * cfg.intermediate_size;
                let up = cfg.hidden_size * cfg.intermediate_size;
                let down = cfg.hidden_size * cfg.intermediate_size;

                let q = cfg.hidden_size * cfg.hidden_size;
                let k = cfg.hidden_size * cfg.hidden_size;
                let v = cfg.hidden_size * cfg.hidden_size;
                let o = cfg.hidden_size * cfg.hidden_size;

                attn_norm + ffn_norm + gate + up + down + q + k + v + o
            };

            patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_elems;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Llama 4 Loader

/// [`VisionLoader`] for a Llama 4 vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct VLlama4Loader;

pub struct VLlama4Prefixer;

impl MultimodalPromptPrefixer for VLlama4Prefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            llama4::IMAGE_TOKEN.repeat(image_indexes.len())
        )
    }
}
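
// Like the Qwen2.5-VL prefixer, this one is positional rather than indexed:
// llama4::IMAGE_TOKEN is repeated once per image, so three images prepend the
// token three times ahead of the prompt text.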

impl VisionModelLoader for VLlama4Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
        Ok(Box::new(Llama4Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Llama4Processor::new(&processor_config.unwrap()))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(VLlama4Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for VLlama4Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // FF MoE
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.router\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4252            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4253            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4254            // FF MLP
4255            Regex::new(r"layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$")?,
4256            Regex::new(r"layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$")?,
4257            Regex::new(r"layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$")?,
4258        ])
4259    }
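    // The immediate-ISQ predicates mirror the patterns above but match the
    // fully prefixed names (`language_model.model.…`) seen at load time, and
    // address each routed expert index (`experts.N`) individually.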
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // FF MoE
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.down_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.router\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.gate_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.down_proj\.(weight|bias)$",
            )?,
            // FF MLP
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$",
            )?,
        ])
    }
}

impl VLlama4Loader {
    /// Runs the image processor on a dummy image to determine worst-case
    /// shapes; the returned counts already incorporate `max_batch_size`.
    /// Returns `(pixels max batch size, num text image tokens)`.
    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
    fn run_dummy_processing(
        &self,
        cfg: &Llama4Config,
        height: usize,
        width: usize,
        max_num_images: usize,
        max_batch_size: usize,
    ) -> Result<(usize, usize)> {
        let cfg = &cfg.vision_config;

        let img_processor =
            Llama4ImageProcessor::new(Some(cfg.patch_size), Some(cfg.pixel_shuffle_ratio));
        let image = DynamicImage::new(width as u32, height as u32, ColorType::Rgb8);
        let res = img_processor.preprocess(
            vec![image; max_num_images],
            vec![],
            &PreProcessorConfig::default(),
            &Device::Cpu,
            (max_batch_size, max_num_images),
        )?;

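        // The processor may split each image into several chunks/tiles, so the
        // leading dimension of `pixel_values` can exceed `max_num_images`.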
        let pixels_batch_size = res.pixel_values.dim(0)?;
        let pixels_max_batch_size = pixels_batch_size * max_batch_size;

        let (image_h, image_w) = (
            res.pixel_values.dim(D::Minus2)?,
            res.pixel_values.dim(D::Minus1)?,
        );
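        // Each chunk contributes (h/patch) * (w/patch) / downsample image tokens.
        // Illustration (hypothetical numbers): a 336x336 chunk with patch_size 14
        // and downsample_ratio 4 gives (336/14) * (336/14) / 4 = 144 tokens.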
        let num_patches_per_chunk = (image_h / img_processor.patch_size)
            * (image_w / img_processor.patch_size)
            / img_processor.downsample_ratio;

        Ok((
            pixels_max_batch_size,
            num_patches_per_chunk * pixels_max_batch_size,
        ))
    }
}

impl DeviceMappedModelLoader for VLlama4Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: (height, width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Llama4Config = serde_json::from_str(config)?;

        let (_pixels_batch_size, num_text_image_toks) =
            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;

        let max_seq_len = max_seq_len + num_text_image_toks;

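        // The worst-case mapped activation is the attention score matrix on the
        // text stack: one (seq x seq) map per head, per batch element.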
        Ok(max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: (height, width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Llama4Config = serde_json::from_str(config)?;

        let (pixels_batch_size, _num_text_image_toks) =
            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
        let max_seq_len = cfg.vision_config.num_patches();

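        // Vision-tower attention scores: every image chunk in the batch attends
        // over all of its patches.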
        Ok((max_batch_size * pixels_batch_size)
            * cfg.vision_config.num_attention_heads
            * max_seq_len
            * max_seq_len)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let tcfg = &cfg.text_config;

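        // Non-mapped parameters: token embeddings, final norm, the (possibly
        // tied) lm_head, plus the entire vision tower, all counted in elements.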
        let text_elems = {
            let embed_tokens = tcfg.hidden_size * tcfg.vocab_size / weight_pack_factor;
            let lm_head = if !tcfg.tie_word_embeddings {
                tcfg.hidden_size * tcfg.vocab_size
            } else {
                0
            };
            let norm = tcfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &cfg.vision_config;

            let num_patches = cfg.num_patches();

            let unfold_elems =
                (cfg.num_channels * cfg.patch_size * cfg.patch_size) * cfg.hidden_size;
            let class_embedding_elems = cfg.hidden_size;
            let positional_embedding_vlm_elems = num_patches * cfg.hidden_size;
            let layernorm_pre_elems = cfg.hidden_size;
            let layernorm_post_elems = cfg.hidden_size;

            let pixel_shuffle_elems = cfg.intermediate_size * cfg.projector_input_dim
                / weight_pack_factor
                + cfg.projector_input_dim * cfg.projector_output_dim / weight_pack_factor;

            let encoder_layer = {
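                // Each LayerNorm carries a weight and a bias vector, hence the
                // doubled hidden_size terms below.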
                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
                let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let k_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let v_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let o_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;

                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
                    + cfg.intermediate_size;
                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
                    + cfg.hidden_size;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + fc1
                    + fc2
            };

            unfold_elems
                + class_embedding_elems
                + positional_embedding_vlm_elems
                + layernorm_post_elems
                + layernorm_pre_elems
                + pixel_shuffle_elems
                + encoder_layer * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_elems;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let tcfg = &cfg.text_config;

        let mut per_layer_elems = Vec::new();

        for layer_idx in 0..tcfg.num_hidden_layers {
            let input_layernorm = tcfg.hidden_size;
            let post_attention_layernorm = tcfg.hidden_size;

            let size_in = tcfg.hidden_size;
            let size_q = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_attention_heads;
            let size_kv = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

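            // Layers listed by `moe_layers()` use the routed-expert block
            // (num_local_experts x gate/up/down); all others use a dense MLP
            // sized by `intermediate_size_mlp`.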
            let use_moe = tcfg.moe_layers().contains(&layer_idx);
            let moe_block = if use_moe {
                let h_size = tcfg.hidden_size;
                let i_size = tcfg.intermediate_size;
                let gate_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
                let up_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
                let down_proj = tcfg.num_local_experts * i_size * h_size / weight_pack_factor;

                gate_proj + up_proj + down_proj
            } else {
                let h_size = tcfg.hidden_size;
                let i_size = tcfg.intermediate_size_mlp;
                let gate_proj = h_size * i_size / weight_pack_factor;
                let up_proj = h_size * i_size / weight_pack_factor;
                let down_proj = i_size * h_size / weight_pack_factor;

                gate_proj + up_proj + down_proj
            };

            per_layer_elems.push(
                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + moe_block,
            );
        }

        Ok(per_layer_elems
            .into_iter()
            .map(|x| x * dtype.size_in_bytes())
            .collect())
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

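        // Metadata consumed by KV-cache / paged-attention sizing; head dims are
        // derived as hidden_size / num_attention_heads.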
        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}