mistralrs_core/pipeline/loaders/vision_loaders.rs

use std::any::Any;
use std::sync::Arc;
use std::{fmt::Debug, str::FromStr};

use anyhow::Result;
use candle_core::{DType, Device, Tensor, D};
use candle_nn::Conv2dConfig;
use image::{ColorType, DynamicImage};
use itertools::Itertools;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::ShardedVarBuilder;

#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};

use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
use crate::amoe::AnyMoeBaseModelMixin;
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::DeviceMapper;
use crate::layers::Conv3dConfig;
use crate::matformer::MatformerSliceConfig;
use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
use crate::pipeline::isq::IsqModelLoader;
use crate::pipeline::loaders::AutoDeviceMapParams;
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::{
    EitherCache, IsqModel, Modalities, MultimodalPromptPrefixer, Processor, ProcessorCreator,
    SupportedModality,
};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::vision_models::clip::ClipConfig;
use crate::vision_models::gemma3::config::Gemma3Config;
use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
use crate::vision_models::gemma3n::config::{Gemma3nConfig, IntermediateSize};
use crate::vision_models::gemma3n::{Gemma3nModel, Gemma3nProcessor};
use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
use crate::vision_models::idefics2_input_processor::Idefics2Processor;
use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
use crate::vision_models::image_processor::ImagePreProcessor;
use crate::vision_models::inputs_processor::Phi4MMProcessor;
use crate::vision_models::llama4::{
    self, Llama4Config, Llama4ImageProcessor, Llama4Model, Llama4Processor,
};
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::llava15::Model as LLaVA;
use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
use crate::vision_models::llava_next::Model as LLaVANext;
use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
use crate::vision_models::phi3_inputs_processor::Phi3Processor;
use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::qwen2_5_vl::{
    Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
};
use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
use crate::vision_models::qwen3_vl::{Config as Qwen3VLConfig, Qwen3VLModel, Qwen3VLProcessor};
use crate::vision_models::qwen3_vl_moe::{
    Config as Qwen3VLMoEConfig, Qwen3VLMoEModel, Qwen3VLMoEProcessor,
};
use crate::vision_models::{minicpmo, phi4};

pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin {
    // pixel_values and pixel_attention_mask only specified for prompt seqs
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>, // pixel attention mask, or image sizes, or anything else
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
    /// For a prompt without images. Requires batch size of 1!
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
}

pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> bool;
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_processor(
        &self,
        model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync>;
    fn supports_paged_attention(&self, config: &str) -> bool;
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        // Default is false, specific model must override.
        false
    }
    fn modalities(&self, config: &str) -> Result<Modalities>;
    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}
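
// Minimal sketch (standalone, not exercised by the loaders themselves) of the
// tensor-name -> device mapping produced by the default `get_device_for_tensor` above:
// names containing `.layers.<n>.` resolve to `Idx(n)` (clamped to the layer count),
// everything else falls back to `Base`. The layer count and tensor names below are
// hypothetical examples.
#[cfg(test)]
mod device_for_tensor_examples {
    use crate::utils::varbuilder_utils::DeviceForLoadTensor;
    use regex::Regex;

    #[test]
    fn maps_layer_tensors_to_layer_indices() {
        let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
        let num_layers = 32usize;
        // Re-derives the closure body from `get_device_for_tensor` for illustration.
        let device_for = |name: &str| {
            re.captures(name)
                .and_then(|captures| captures.get(1))
                .and_then(|m| m.as_str().parse::<usize>().ok())
                .map(|l| DeviceForLoadTensor::Idx(l.min(num_layers)))
                .unwrap_or(DeviceForLoadTensor::Base)
        };
        assert!(matches!(
            device_for("model.layers.12.self_attn.q_proj.weight"),
            DeviceForLoadTensor::Idx(12)
        ));
        assert!(matches!(
            device_for("model.embed_tokens.weight"),
            DeviceForLoadTensor::Base
        ));
    }
}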

#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
/// The architecture to load the vision model as.
pub enum VisionLoaderType {
    #[serde(rename = "phi3v")]
    Phi3V,
    #[serde(rename = "idefics2")]
    Idefics2,
    #[serde(rename = "llava_next")]
    LLaVANext,
    #[serde(rename = "llava")]
    LLaVA,
    #[serde(rename = "vllama")]
    VLlama,
    #[serde(rename = "qwen2vl")]
    Qwen2VL,
    #[serde(rename = "idefics3")]
    Idefics3,
    #[serde(rename = "minicpmo")]
    MiniCpmO,
    #[serde(rename = "phi4mm")]
    Phi4MM,
    #[serde(rename = "qwen2_5vl")]
    Qwen2_5VL,
    #[serde(rename = "gemma3")]
    Gemma3,
    #[serde(rename = "mistral3")]
    Mistral3,
    #[serde(rename = "llama4")]
    Llama4,
    #[serde(rename = "gemma3n")]
    Gemma3n,
    #[serde(rename = "qwen3vl")]
    Qwen3VL,
    #[serde(rename = "qwen3vlmoe")]
    Qwen3VLMoE,
}

// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
impl VisionLoaderType {
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "Phi3VForCausalLM" => Ok(Self::Phi3V),
            "Idefics2ForConditionalGeneration" => Ok(Self::Idefics2),
            "LlavaNextForConditionalGeneration" => Ok(Self::LLaVANext),
            "LlavaForConditionalGeneration" => Ok(Self::LLaVA),
            "MllamaForConditionalGeneration" => Ok(Self::VLlama),
            "Qwen2VLForConditionalGeneration" => Ok(Self::Qwen2VL),
            "Idefics3ForConditionalGeneration" => Ok(Self::Idefics3),
            "MiniCPMO" => Ok(Self::MiniCpmO),
            "Phi4MMForCausalLM" => Ok(Self::Phi4MM),
            "Qwen2_5_VLForConditionalGeneration" => Ok(Self::Qwen2_5VL),
            "Gemma3ForConditionalGeneration" | "Gemma3ForCausalLM" => Ok(Self::Gemma3),
            "Mistral3ForConditionalGeneration" => Ok(Self::Mistral3),
            "Llama4ForConditionalGeneration" => Ok(Self::Llama4),
            "Gemma3nForConditionalGeneration" => Ok(Self::Gemma3n),
            "Qwen3VLForConditionalGeneration" => Ok(Self::Qwen3VL),
            "Qwen3VLMoeForConditionalGeneration" => Ok(Self::Qwen3VLMoE),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers -CausalLM model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for VisionLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "phi3v" => Ok(Self::Phi3V),
            "idefics2" => Ok(Self::Idefics2),
            "llava_next" => Ok(Self::LLaVANext),
            "llava" => Ok(Self::LLaVA),
            "vllama" => Ok(Self::VLlama),
            "qwen2vl" => Ok(Self::Qwen2VL),
            "idefics3" => Ok(Self::Idefics3),
            "minicpmo" => Ok(Self::MiniCpmO),
            "phi4mm" => Ok(Self::Phi4MM),
            "qwen2_5vl" => Ok(Self::Qwen2_5VL),
            "gemma3" => Ok(Self::Gemma3),
            "mistral3" => Ok(Self::Mistral3),
            "llama4" => Ok(Self::Llama4),
            "gemma3n" => Ok(Self::Gemma3n),
            "qwen3vl" => Ok(Self::Qwen3VL),
            "qwen3vlmoe" => Ok(Self::Qwen3VLMoE),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`, `llama4`, `gemma3n`, `qwen3vl`, `qwen3vlmoe`.")),
        }
    }
}

impl std::fmt::Display for VisionLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            VisionLoaderType::Phi3V => "phi3v",
            VisionLoaderType::Idefics2 => "idefics2",
            VisionLoaderType::LLaVANext => "llava_next",
            VisionLoaderType::LLaVA => "llava",
            VisionLoaderType::VLlama => "vllama",
            VisionLoaderType::Qwen2VL => "qwen2vl",
            VisionLoaderType::Idefics3 => "idefics3",
            VisionLoaderType::MiniCpmO => "minicpmo",
            VisionLoaderType::Phi4MM => "phi4mm",
            VisionLoaderType::Qwen2_5VL => "qwen2_5vl",
            VisionLoaderType::Gemma3 => "gemma3",
            VisionLoaderType::Mistral3 => "mistral3",
            VisionLoaderType::Llama4 => "llama4",
            VisionLoaderType::Gemma3n => "gemma3n",
            VisionLoaderType::Qwen3VL => "qwen3vl",
            VisionLoaderType::Qwen3VLMoE => "qwen3vlmoe",
        };
        write!(f, "{name}")
    }
}
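
// Illustrative sketch of how the three conversions above compose: an HF `architectures`
// class name and a CLI-style short name both resolve to the same `VisionLoaderType`, and
// `Display` gives back the short name. The literals below are examples, not an exhaustive
// list.
#[cfg(test)]
mod vision_loader_type_examples {
    use super::VisionLoaderType;
    use std::str::FromStr;

    #[test]
    fn resolves_from_class_name_and_short_name() {
        let from_class =
            VisionLoaderType::from_causal_lm_name("Qwen2_5_VLForConditionalGeneration").unwrap();
        let from_short = VisionLoaderType::from_str("qwen2_5vl").unwrap();
        assert_eq!(from_class, from_short);
        assert_eq!(from_short.to_string(), "qwen2_5vl");
    }
}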

#[derive(Deserialize)]
struct AutoVisionLoaderConfig {
    architectures: Vec<String>,
}

/// Automatically selects a VisionModelLoader implementation based on the JSON `architectures` field.
pub struct AutoVisionLoader;

impl AutoVisionLoader {
    fn get_loader(config: &str) -> Result<Box<dyn VisionModelLoader>> {
        let auto_cfg: AutoVisionLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one architecture in config");
        }

        let name = &auto_cfg.architectures[0];
        let tp = VisionLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        // Delegate to the concrete loader
        Ok(match tp {
            VisionLoaderType::Phi3V => Box::new(Phi3VLoader),
            VisionLoaderType::Idefics2 => Box::new(Idefics2Loader),
            VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader),
            VisionLoaderType::LLaVA => Box::new(LLaVALoader),
            VisionLoaderType::VLlama => Box::new(VLlamaLoader),
            VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
            VisionLoaderType::Idefics3 => Box::new(Idefics3Loader),
            VisionLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
            VisionLoaderType::Phi4MM => Box::new(Phi4MMLoader),
            VisionLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
            VisionLoaderType::Gemma3 => Box::new(Gemma3Loader),
            VisionLoaderType::Mistral3 => Box::new(Mistral3Loader),
            VisionLoaderType::Llama4 => Box::new(VLlama4Loader),
            VisionLoaderType::Gemma3n => Box::new(Gemma3nLoader),
            VisionLoaderType::Qwen3VL => Box::new(Qwen3VLLoader),
            VisionLoaderType::Qwen3VLMoE => Box::new(Qwen3VLMoELoader),
        })
    }
}
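
// Minimal sketch of the first step of `get_loader`: only the `architectures` field of the
// model's `config.json` is needed to pick a loader; unknown fields are ignored by serde.
// The JSON fragment below is a hypothetical example, not a complete config.
#[cfg(test)]
mod auto_vision_loader_examples {
    use super::AutoVisionLoaderConfig;

    #[test]
    fn reads_architectures_from_config_json() {
        let config =
            r#"{"architectures": ["Gemma3ForConditionalGeneration"], "hidden_size": 2560}"#;
        let auto_cfg: AutoVisionLoaderConfig = serde_json::from_str(config).unwrap();
        assert_eq!(auto_cfg.architectures, vec!["Gemma3ForConditionalGeneration"]);
    }
}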

impl VisionModelLoader for AutoVisionLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }

    fn is_gptx(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader get_loader")
            .is_gptx(config)
    }

    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }

    fn get_processor(
        &self,
        model_config: &str,
        proc_cfg: Option<ProcessorConfig>,
        preproc_cfg: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Self::get_loader(model_config)
            .expect("AutoVisionLoader get_loader")
            .get_processor(model_config, proc_cfg, preproc_cfg, max_edge)
    }

    fn supports_paged_attention(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_paged_attention(config)
    }

    fn modalities(&self, config: &str) -> Result<Modalities> {
        Self::get_loader(config)?.modalities(config)
    }

    fn supports_prefix_cacher(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_prefix_cacher(config)
    }

    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .prefixer(config)
    }

    fn get_device_for_tensor(
        &self,
        config: &str,
        mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        Self::get_loader(config)?.get_device_for_tensor(config, mapper, loading_isq)
    }
}

impl IsqModelLoader for AutoVisionLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
}

impl DeviceMappedModelLoader for AutoVisionLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}

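// Note on the estimators below: they count parameters in *elements*, multiplying by
// `dtype.size_in_bytes()` where a byte figure is returned. Matrix-shaped weights are divided
// by `weight_pack_factor` to approximate packed (quantized) storage, while vector-shaped
// parameters such as norms and biases are counted at full width. These are capacity
// estimates for automatic device mapping, not exact allocation sizes.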
fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
    let pre_layer_norm = cfg.hidden_size;
    let final_layer_norm = cfg.hidden_size;

    let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
    let num_positions = num_patches + 1;

    let class_embedding = cfg.hidden_size;

    let position_ids = num_positions;
    let position_embedding = num_positions * cfg.hidden_size;

    let conv2dconfig = Conv2dConfig {
        stride: cfg.patch_size,
        ..Default::default()
    };
    let patch_embedding =
        cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;

    let encoder_layer_elems = {
        let layer_norm1 = cfg.hidden_size;
        let layer_norm2 = cfg.hidden_size;

        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

        layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
    };

    pre_layer_norm
        + final_layer_norm
        + class_embedding
        + position_ids
        + position_embedding
        + patch_embedding
        + cfg.num_hidden_layers * encoder_layer_elems
}
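
// Worked example for `get_clip_vit_num_elems` (illustrative; assumes a CLIP ViT-L/14-style
// tower with hidden_size h = 1024, intermediate_size i = 4096, image_size = 336,
// patch_size = 14, num_channels = 3 and 24 hidden layers):
//   num_patches = (336 / 14)^2 = 576, num_positions = 577
//   per encoder layer: 2h + 4(h^2 + h) + (h*i + i) + (i*h + h) ≈ 12.59M elements
//   embeddings and norms ≈ 1.20M elements
//   total ≈ 24 * 12.59M + 1.20M ≈ 303M elements, times the dtype width for a byte estimate.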

// ======================== Phi 3 loader

/// [`VisionLoader`] for a Phi 3 Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Phi3VLoader;

pub struct Phi3VPrefixer;

impl MultimodalPromptPrefixer for Phi3VPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        // Image indexing starts at 0.
        format!(
            "{}{prompt}",
            image_indexes
                .into_iter()
                .map(|image_index| format!("<|image_{}|>", image_index + 1))
                .join("")
        )
    }
}

impl VisionModelLoader for Phi3VLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(Phi3::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi3Processor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Phi3VPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Phi3VLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}
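
// Note on the ISQ patterns above: they are matched against fully qualified tensor names and
// are only anchored at the end, so a decoder tensor such as
// `model.layers.3.self_attn.qkv_proj.weight` (hypothetical name) is selected, while vision
// tower tensors, which use different projection names (CLIP-style `q_proj`/`fc1`), are left
// unquantized.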

impl DeviceMappedModelLoader for Phi3VLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        // NOTE: we ignore max_num_images although it can only be one...
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }
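
    // Rough scale of the estimate above (illustrative numbers, not taken from a real
    // config): with max_batch_size = 1, 32 attention heads, one image (24^2 + 1 = 577 image
    // tokens for a 336px/14px CLIP grid) and a 4096-token text budget (ignoring the
    // ATTENTION_CHUNK_SIZE clamp), this is 1 * 32 * (577 + 4096)^2 ≈ 7.0e8 activation
    // elements for the attention map alone.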

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        // NOTE: we ignore max_num_images although it can only be one...
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = {
                let projection_cls = cfg
                    .embd_layer
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator =
                    cfg.embd_layer.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
                let image_dim_out = cfg.img_processor.image_dim_out;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);

                proj + glb_gn + sub_gn + clip_vit
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Idefics 2 loader

/// [`VisionLoader`] for an Idefics 2 Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Idefics2Loader;

pub struct Idefics2Prefixer;

impl MultimodalPromptPrefixer for Idefics2Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        // Chat template does it
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(Idefics2::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics2Processor::new(
            processor_config.unwrap(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Idefics2Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Idefics2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Idefics2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            // do_image_splitting = true
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let text_elems = {
            let tie_word_embeddings = cfg.tie_word_embeddings;
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let tcfg = &cfg.text_config;
            let vcfg = &cfg.vision_config;
            let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let down_proj = tcfg.intermediate_size * tcfg.hidden_size;

            let perceiver_elems = {
                let tcfg = &cfg.text_config;
                let pcfg = &cfg.perceiver_config;

                let n_latents = pcfg.resampler_n_latents;
                let hidden_size = tcfg.hidden_size;
                let depth = pcfg.resampler_depth;

                let norm = tcfg.hidden_size;
                let latents = n_latents * hidden_size;

                let layer_elems = {
                    let input_latents_norm = hidden_size;
                    let input_context_norm = hidden_size;
                    let post_attn_norm = hidden_size;

                    let num_heads = pcfg.resampler_n_heads;
                    let head_dim = pcfg.resampler_head_dim;
                    let num_key_value_heads = pcfg.num_key_value_heads;

                    let q_proj = hidden_size * num_heads * head_dim;
                    let k_proj = hidden_size * num_key_value_heads * head_dim;
                    let v_proj = hidden_size * num_key_value_heads * head_dim;
                    let o_proj = num_heads * head_dim * hidden_size;

                    let gate_proj = hidden_size * hidden_size * 4;
                    let up_proj = hidden_size * hidden_size * 4;
                    let down_proj = hidden_size * 4 * hidden_size;

                    input_latents_norm
                        + input_context_norm
                        + post_attn_norm
                        + q_proj
                        + k_proj
                        + v_proj
                        + o_proj
                        + gate_proj
                        + up_proj
                        + down_proj
                };

                norm + latents + layer_elems * depth
            };

            gate_proj + up_proj + down_proj + perceiver_elems
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm + patch_embedding + position_embedding + layer_elems
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== LLaVANext Loader

/// [`VisionLoader`] for an LLaVANext Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct LLaVANextLoader;

pub struct LLaVANextPrefixer;

impl MultimodalPromptPrefixer for LLaVANextPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVANextLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(LLaVANext::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVANextProcessor::new(model_config))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(LLaVANextPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for LLaVANextLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVANextLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== LLaVA Loader

/// [`VisionLoader`] for an LLaVA Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct LLaVALoader;

pub struct LLaVAPrefixer;

impl MultimodalPromptPrefixer for LLaVAPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVALoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(LLaVA::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVAProcessor::new(model_config))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(LLaVAPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for LLaVALoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVALoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        let img_seq_len =
            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1497        let img_seq_len = img_seq_len * max_num_images;
1498
1499        let max_text_attn = {
1500            let cfg = &config.text_config;
1501            // This model injects the vision information directly into the input embeddings
1502            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
1503
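            // Peak activation is the attention score tensor: [batch, heads, seq, seq].
            // For a hypothetical 32-head model with an effective (text + image) length
            // of 4096 at batch size 1, that is 1 * 32 * 4096 * 4096 ≈ 5.4e8 elements.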
1504            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1505        };
1506
1507        Ok(max_text_attn)
1508    }
1509
1510    fn non_mapped_max_act_size_elems(
1511        &self,
1512        config: &str,
1513        params: &AutoDeviceMapParams,
1514    ) -> Result<usize> {
1515        let AutoDeviceMapParams::Vision {
1516            max_seq_len: _,
1517            max_batch_size,
1518            max_image_shape: _,
1519            max_num_images,
1520        } = params
1521        else {
1522            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1523        };
1524
1525        let config: LLaVAConfig = serde_json::from_str(config)?;
1526
1527        let img_seq_len =
1528            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1529
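        // Vision tower worst case: one [heads, img_seq_len, img_seq_len] score matrix
        // per image for every batch element.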
1530        let max_vision_attn = {
1531            (max_batch_size * max_num_images)
1532                * config.vision_config.num_attention_heads
1533                * img_seq_len
1534                * img_seq_len
1535        };
1536
1537        Ok(max_vision_attn)
1538    }
1539
1540    fn non_mapped_size_in_bytes(
1541        &self,
1542        config: &str,
1543        dtype: DType,
1544        weight_pack_factor: usize,
1545        _matformer_config: Option<&MatformerSliceConfig>,
1546    ) -> Result<usize> {
1547        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1548        let text_elems = {
1549            let cfg = &cfg.text_config;
1550            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1551            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1552            let norm = cfg.hidden_size;
1553            embed_tokens + lm_head + norm
1554        };
1555
1556        let image_newline = cfg.text_config.hidden_size;
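        // Multimodal projector: two biased linear layers mapping the vision hidden size
        // into the text hidden size.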
1557        let mmproj = {
1558            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
1559                + cfg.text_config.hidden_size;
1560            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
1561                + cfg.text_config.hidden_size;
1562
1563            linear_1 + linear_2
1564        };
1565        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());
1566
1567        let elems = text_elems + image_newline + mmproj + vision_tower;
1568        Ok(elems * dtype.size_in_bytes())
1569    }
1570
1571    fn layer_sizes_in_bytes(
1572        &self,
1573        config: &str,
1574        dtype: DType,
1575        weight_pack_factor: usize,
1576        _matformer_config: Option<&MatformerSliceConfig>,
1577    ) -> Result<Vec<usize>> {
1578        let cfg: LLaVAConfig = serde_json::from_str(config)?;
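        // Per-decoder-layer parameters: two norms plus the attention and gated-MLP
        // projections, with packable weights divided by `weight_pack_factor`.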
1579        let per_layer_elems = {
1580            let cfg = &cfg.text_config;
1581            let input_layernorm = cfg.hidden_size;
1582            let post_attention_layernorm = cfg.hidden_size;
1583
1584            let size_in = cfg.hidden_size;
1585            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1586            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1587            let q_proj = size_in * size_q / weight_pack_factor;
1588            let k_proj = size_in * size_kv / weight_pack_factor;
1589            let v_proj = size_in * size_kv / weight_pack_factor;
1590            let o_proj = size_q * size_in / weight_pack_factor;
1591
1592            let h_size = cfg.hidden_size;
1593            let i_size = cfg.intermediate_size;
1594            let gate_proj = h_size * i_size / weight_pack_factor;
1595            let up_proj = h_size * i_size / weight_pack_factor;
1596            let down_proj = i_size * h_size / weight_pack_factor;
1597
1598            input_layernorm
1599                + post_attention_layernorm
1600                + q_proj
1601                + k_proj
1602                + v_proj
1603                + o_proj
1604                + gate_proj
1605                + up_proj
1606                + down_proj
1607        };
1608        Ok(vec![
1609            per_layer_elems * dtype.size_in_bytes();
1610            cfg.text_config.num_hidden_layers
1611        ])
1612    }
1613
1614    fn num_layers(&self, config: &str) -> Result<usize> {
1615        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1616        Ok(cfg.text_config.num_hidden_layers)
1617    }
1618
1619    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1620        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1621        let cfg = &cfg.text_config;
1622
1623        let cfg = ModelConfigMetadata {
1624            max_seq_len: cfg.max_position_embeddings,
1625            num_layers: cfg.num_hidden_layers,
1626            hidden_size: cfg.hidden_size,
1627            num_kv_heads: cfg.num_key_value_heads,
1628            num_attn_heads: cfg.num_attention_heads,
1629            sliding_window: cfg.sliding_window,
1630            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1631            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1632        };
1633
1634        Ok(Box::new(cfg))
1635    }
1636
1637    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1638        Some(vec![NonMappedSubModel::Vision])
1639    }
1640}
1641
1642// ======================== MLlama Loader
1643
1644/// [`VisionLoader`] for a Llama Vision model.
1645///
1646/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
1647pub struct VLlamaLoader;
1648
1649pub struct VLlamaPrefixer;
1650
1651impl MultimodalPromptPrefixer for VLlamaPrefixer {
1652    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1653        format!("{}{prompt}", "<|image|>".repeat(image_indexes.len()))
1654    }
1655}
1656
1657impl VisionModelLoader for VLlamaLoader {
1658    fn load(
1659        &self,
1660        config: &str,
1661        vb: ShardedVarBuilder,
1662        normal_loading_metadata: NormalLoadingMetadata,
1663        attention_mechanism: AttentionImplementation,
1664    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1665        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1666        Ok(Box::new(MLlamaModel::new(
1667            &cfg,
1668            vb,
1669            self.is_gptx(config),
1670            normal_loading_metadata,
1671            attention_mechanism,
1672        )?))
1673    }
1674    fn is_gptx(&self, _config: &str) -> bool {
1675        true
1676    }
1677    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1678        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1679        Ok(Box::new(cfg))
1680    }
1681    fn get_processor(
1682        &self,
1683        _model_config: &str,
1684        _processor_config: Option<ProcessorConfig>,
1685        _preprocessor_config: PreProcessorConfig,
1686        _max_edge: Option<u32>,
1687    ) -> Arc<dyn Processor + Send + Sync> {
1688        Arc::new(MLlamaProcessor::new())
1689    }
1690    fn supports_paged_attention(&self, _config: &str) -> bool {
1691        false
1692    }
1693    fn supports_prefix_cacher(&self, _config: &str) -> bool {
1694        true
1695    }
1696    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1697        Arc::new(VLlamaPrefixer)
1698    }
1699    fn modalities(&self, _config: &str) -> Result<Modalities> {
1700        Ok(Modalities {
1701            input: vec![SupportedModality::Text, SupportedModality::Vision],
1702            output: vec![SupportedModality::Text],
1703        })
1704    }
1705}
1706
1707impl IsqModelLoader for VLlamaLoader {
1708    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
1709        let config: MLlamaConfig = serde_json::from_str(config)?;
1710        let cross_attn_layers = &config.text_config.cross_attention_layers;
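        // Cross-attention layers are skipped for ISQ, so build per-layer regexes that
        // cover only the ordinary self-attention decoder layers.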
1711        let transformer_layers =
1712            (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
1713        let mut text_regexes = Vec::new();
1714        for layer in transformer_layers {
1715            text_regexes.extend(vec![
1716                // Attention text
1717                Regex::new(&format!(
1718                    r"language_model.model.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
1719                ))?,
1720                Regex::new(&format!(
1721                    r"language_model.model.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
1722                ))?,
1723                Regex::new(&format!(
1724                    r"language_model.model.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
1725                ))?,
1726                Regex::new(&format!(
1727                    r"language_model.model.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
1728                ))?,
1729                // MLP text
1730                Regex::new(&format!(
1731                    r"language_model.model.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
1732                ))?,
1733                Regex::new(&format!(
1734                    r"language_model.model.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
1735                ))?,
1736                Regex::new(&format!(
1737                    r"language_model.model.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
1738                ))?,
1739            ]);
1740        }
1741        let vision_regexes = vec![
1742            // Vision attention (transformer)
1743            Regex::new(
1744                r"vision_model.transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1745            )?,
1746            Regex::new(
1747                r"vision_model.transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1748            )?,
1749            Regex::new(
1750                r"vision_model.transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1751            )?,
1752            Regex::new(
1753                r"vision_model.transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1754            )?,
1755            // Vision attention (global transformer)
1756            Regex::new(
1757                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1758            )?,
1759            Regex::new(
1760                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1761            )?,
1762            Regex::new(
1763                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1764            )?,
1765            Regex::new(
1766                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1767            )?,
1768            // MLP vision
1769            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1770            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1771        ];
1772
1773        Ok([text_regexes, vision_regexes].concat())
1774    }
1775    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1776        self.isq_layer_regexes(config)
1777    }
1778}
1779
1780impl DeviceMappedModelLoader for VLlamaLoader {
1781    fn mapped_max_act_size_elems(
1782        &self,
1783        config: &str,
1784        params: &AutoDeviceMapParams,
1785    ) -> Result<usize> {
1786        let AutoDeviceMapParams::Vision {
1787            max_seq_len,
1788            max_batch_size,
1789            max_image_shape: _,
1790            max_num_images,
1791        } = params
1792        else {
1793            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1794        };
1795
1796        let config: MLlamaConfig = serde_json::from_str(config)?;
1797
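        // Image sequence length per image: patches per tile (plus the class token),
        // padded up to a multiple of 8, times the maximum number of tiles.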
1798        let img_seq_len = {
1799            let cfg = &config.vision_config;
1800            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1801            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1802            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1803        };
1804        let img_seq_len = img_seq_len * max_num_images;
1805
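        // Cross-attention worst case: bound both the query and key lengths by the full
        // image sequence length.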
1806        let max_cross_text_attn = {
1807            let cfg = &config.text_config;
1808            max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
1809        };
1810
1811        let max_self_text_attn = {
1812            let cfg = &config.text_config;
1813            max_batch_size * cfg.num_attention_heads * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)
1814        };
1815
1816        Ok(max_self_text_attn.max(max_cross_text_attn))
1817    }
1818
1819    fn non_mapped_max_act_size_elems(
1820        &self,
1821        config: &str,
1822        params: &AutoDeviceMapParams,
1823    ) -> Result<usize> {
1824        let AutoDeviceMapParams::Vision {
1825            max_seq_len: _,
1826            max_batch_size,
1827            max_image_shape: _,
1828            max_num_images,
1829        } = params
1830        else {
1831            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1832        };
1833
1834        let config: MLlamaConfig = serde_json::from_str(config)?;
1835
1836        let img_seq_len = {
1837            let cfg = &config.vision_config;
1838            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1839            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1840            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1841        };
1842        let max_vision_attn = {
1843            let cfg = &config.vision_config;
1844            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
1845        };
1846
1847        Ok(max_vision_attn)
1848    }
1849
1850    fn non_mapped_size_in_bytes(
1851        &self,
1852        config: &str,
1853        dtype: DType,
1854        weight_pack_factor: usize,
1855        _matformer_config: Option<&MatformerSliceConfig>,
1856    ) -> Result<usize> {
1857        let config: MLlamaConfig = serde_json::from_str(config)?;
1858        let text_elems = {
1859            let cfg = &config.text_config;
1860            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1861            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1862            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1863                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1864            } else {
1865                0
1866            };
1867            let norm = cfg.hidden_size;
1868            embed_tokens + lm_head + norm
1869        };
1870
1871        let vision_elems = {
1872            let cfg = &config.vision_config;
1873
1874            let conv_cfg = Conv2dConfig {
1875                stride: cfg.patch_size,
1876                ..Default::default()
1877            };
1878            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
1879                * cfg.patch_size
1880                * cfg.patch_size;
1881
1882            let class_embedding = cfg.hidden_size;
1883
1884            let gated_positional_embedding = {
1885                let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1886                let embedding = num_patches * cfg.hidden_size;
1887                let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
1888                    * (cfg.max_num_tiles * num_patches * cfg.hidden_size);
1889
1890                embedding + tile_embedding
1891            };
1892
1893            let pre_tile_positional_embedding =
1894                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1895            let post_tile_positional_embedding =
1896                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1897
1898            let layernorm_pre = cfg.hidden_size;
1899            let layernorm_post = cfg.hidden_size;
1900
1901            let encoder_layer = {
1902                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1903                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
1904
1905                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1906                let q_proj =
1907                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1908                let k_proj =
1909                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1910                let v_proj =
1911                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1912                let o_proj =
1913                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1914
1915                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
1916                    + cfg.intermediate_size;
1917                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
1918                    + cfg.hidden_size;
1919
1920                input_layernorm
1921                    + post_attention_layernorm
1922                    + q_proj
1923                    + k_proj
1924                    + v_proj
1925                    + o_proj
1926                    + fc1
1927                    + fc2
1928            };
1929
1930            patch_embedding
1931                + class_embedding
1932                + gated_positional_embedding
1933                + pre_tile_positional_embedding
1934                + post_tile_positional_embedding
1935                + layernorm_pre
1936                + layernorm_post
1937                + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
1938        };
1939
1940        let elems = text_elems + vision_elems;
1941        Ok(elems * dtype.size_in_bytes())
1942    }
1943
1944    fn layer_sizes_in_bytes(
1945        &self,
1946        config: &str,
1947        dtype: DType,
1948        weight_pack_factor: usize,
1949        _matformer_config: Option<&MatformerSliceConfig>,
1950    ) -> Result<Vec<usize>> {
1951        let config: MLlamaConfig = serde_json::from_str(config)?;
1952        let cfg = &config.text_config;
1953
1954        let mut layer_sizes = Vec::new();
1955
1956        for i in 0..cfg.num_hidden_layers {
1957            let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
1958                // Cross-attention layers are not ISQ-quantized, so do not divide by the pack factor
1959                1
1960            } else {
1961                weight_pack_factor
1962            };
1963
1964            let per_layer_elems = {
1965                let input_layernorm = cfg.hidden_size;
1966                let post_attention_layernorm = cfg.hidden_size;
1967
1968                let size_in = cfg.hidden_size;
1969                let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1970                let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1971                let q_proj = size_in * size_q / weight_pack_factor;
1972                let k_proj = size_in * size_kv / weight_pack_factor;
1973                let v_proj = size_in * size_kv / weight_pack_factor;
1974                let o_proj = size_q * size_in / weight_pack_factor;
1975
1976                let h_size = cfg.hidden_size;
1977                let i_size = cfg.intermediate_size;
1978                let gate_proj = h_size * i_size / weight_pack_factor;
1979                let up_proj = h_size * i_size / weight_pack_factor;
1980                let down_proj = i_size * h_size / weight_pack_factor;
1981
1982                input_layernorm
1983                    + post_attention_layernorm
1984                    + q_proj
1985                    + k_proj
1986                    + v_proj
1987                    + o_proj
1988                    + gate_proj
1989                    + up_proj
1990                    + down_proj
1991            };
1992
1993            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
1994        }
1995
1996        Ok(layer_sizes)
1997    }
1998
1999    fn num_layers(&self, config: &str) -> Result<usize> {
2000        let config: MLlamaConfig = serde_json::from_str(config)?;
2001        Ok(config.text_config.num_hidden_layers)
2002    }
2003
2004    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2005        let cfg: MLlamaConfig = serde_json::from_str(config)?;
2006        let cfg = &cfg.text_config;
2007
2008        let cfg = ModelConfigMetadata {
2009            max_seq_len: cfg.max_position_embeddings,
2010            num_layers: cfg.num_hidden_layers,
2011            hidden_size: cfg.hidden_size,
2012            num_kv_heads: cfg.num_key_value_heads,
2013            num_attn_heads: cfg.num_attention_heads,
2014            sliding_window: None,
2015            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2016            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2017        };
2018
2019        Ok(Box::new(cfg))
2020    }
2021
2022    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2023        Some(vec![NonMappedSubModel::Vision])
2024    }
2025}
2026
2027// ======================== Qwen2VL Loader
2028
2029/// [`VisionLoader`] for a Qwen2-VL model.
2030///
2031/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2032pub struct Qwen2VLLoader;
2033
2034pub struct Qwen2VLPrefixer;
2035
2036impl MultimodalPromptPrefixer for Qwen2VLPrefixer {
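    // Wrap one vision_start/image_pad/vision_end marker group per image ahead of the
    // prompt; the Qwen2-VL processor later expands each pad token to the real image
    // token count.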
2037    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2038        format!(
2039            "{}{prompt}",
2040            format!(
2041                "{}{}{}",
2042                Qwen2VLProcessor::VISION_START,
2043                Qwen2VLProcessor::IMAGE_PAD,
2044                Qwen2VLProcessor::VISION_END
2045            )
2046            .repeat(image_indexes.len())
2047        )
2048    }
2049}
2050
2051impl VisionModelLoader for Qwen2VLLoader {
2052    fn load(
2053        &self,
2054        config: &str,
2055        vb: ShardedVarBuilder,
2056        normal_loading_metadata: NormalLoadingMetadata,
2057        attention_mechanism: AttentionImplementation,
2058    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2059        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2060        Ok(Box::new(Qwen2VLModel::new(
2061            &cfg,
2062            vb,
2063            self.is_gptx(config),
2064            normal_loading_metadata,
2065            attention_mechanism,
2066        )?))
2067    }
2068    fn is_gptx(&self, _config: &str) -> bool {
2069        true
2070    }
2071    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2072        let config: Qwen2VLConfig = serde_json::from_str(config)?;
2073        Ok(Box::new(config))
2074    }
2075    fn get_processor(
2076        &self,
2077        _model_config: &str,
2078        _processor_config: Option<ProcessorConfig>,
2079        _preprocessor_config: PreProcessorConfig,
2080        max_edge: Option<u32>,
2081    ) -> Arc<dyn Processor + Send + Sync> {
2082        Arc::new(Qwen2VLProcessor::new(max_edge))
2083    }
2084    fn supports_paged_attention(&self, _config: &str) -> bool {
2085        false
2086    }
2087    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2088        Arc::new(Qwen2VLPrefixer)
2089    }
2090    fn modalities(&self, _config: &str) -> Result<Modalities> {
2091        Ok(Modalities {
2092            input: vec![SupportedModality::Text, SupportedModality::Vision],
2093            output: vec![SupportedModality::Text],
2094        })
2095    }
2096}
2097
2098impl IsqModelLoader for Qwen2VLLoader {
2099    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2100        Ok(vec![
2101            Regex::new(r"lm_head\.(weight|bias)$")?,
2102            // Attention
2103            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2104            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2105            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2106            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2107            // MLP
2108            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2109            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2110            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2111        ])
2112    }
2113    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2114        self.isq_layer_regexes(config)
2115    }
2116}
2117
2118impl DeviceMappedModelLoader for Qwen2VLLoader {
2119    fn mapped_max_act_size_elems(
2120        &self,
2121        config: &str,
2122        params: &AutoDeviceMapParams,
2123    ) -> Result<usize> {
2124        let AutoDeviceMapParams::Vision {
2125            max_seq_len,
2126            max_batch_size,
2127            max_image_shape,
2128            max_num_images,
2129        } = params
2130        else {
2131            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2132        };
2133
2134        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2135
2136        // For images, grid_t=1. After spatial merging, grid_h and grid_w are reduced.
2137        let img_seq_len = {
2138            let cfg = &cfg.vision_config;
2139            // grid_t is 1 for images (temporal dimension is for video only)
2140            let grid_t = 1;
2141            // After patch embedding and spatial merge, the effective grid dimensions are reduced
2142            let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
2143            let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
2144            grid_t * grid_h * grid_w * max_num_images
2145        };
2146
2147        let max_text_attn = {
2148            // This model injects the vision information directly into the input embeddings
2149            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2150            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
2151        };
2152
2153        Ok(max_text_attn)
2154    }
2155
2156    fn non_mapped_max_act_size_elems(
2157        &self,
2158        config: &str,
2159        params: &AutoDeviceMapParams,
2160    ) -> Result<usize> {
2161        let AutoDeviceMapParams::Vision {
2162            max_seq_len: _,
2163            max_batch_size,
2164            max_image_shape,
2165            max_num_images,
2166        } = params
2167        else {
2168            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2169        };
2170
2171        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2172
2173        // For the vision encoder, before spatial merging
2174        let img_seq_len = {
2175            let cfg = &cfg.vision_config;
2176            // grid_t is 1 for images
2177            let grid_t = 1;
2178            let grid_h = max_image_shape.0 / cfg.patch_size;
2179            let grid_w = max_image_shape.1 / cfg.patch_size;
2180            grid_t * grid_h * grid_w
2181        };
2182
2183        let max_vision_attn = {
2184            let cfg = &cfg.vision_config;
2185            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
2186        };
2187
2188        Ok(max_vision_attn)
2189    }
2190
2191    fn non_mapped_size_in_bytes(
2192        &self,
2193        config: &str,
2194        dtype: DType,
2195        weight_pack_factor: usize,
2196        _matformer_config: Option<&MatformerSliceConfig>,
2197    ) -> Result<usize> {
2198        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2199        let text_elems = {
2200            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2201            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2202            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2203                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2204            } else {
2205                0
2206            };
2207            let norm = cfg.hidden_size;
2208            embed_tokens + lm_head + norm
2209        };
2210
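        // Patch merger: spatial_merge_size^2 vision patches are concatenated and passed
        // through a two-layer MLP (plus a layer norm) into the text hidden size.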
2211        let patch_merger = {
2212            let cfg = &cfg.vision_config;
2213            let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);
2214
2215            let mlp0 = hidden_size * hidden_size + hidden_size;
2216            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
2217
2218            let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2219
2220            mlp0 + mlp2 + ln_q
2221        };
2222
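        // Patch embedding: a 3D convolution over (temporal, height, width) patches with
        // no bias term.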
2223        let patch_embed = {
2224            let cfg = &cfg.vision_config;
2225            let conv_cfg = Conv3dConfig {
2226                stride: cfg.patch_size,
2227                ..Default::default()
2228            };
2229            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
2230            cfg.in_channels * cfg.embed_dim / conv_cfg.groups
2231                * kernel_sizes[0]
2232                * kernel_sizes[1]
2233                * kernel_sizes[2]
2234        };
2235
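        // Each vision encoder block: two norms, a fused QKV projection, an output
        // projection, and an MLP whose hidden width is embed_dim * mlp_ratio.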
2236        let encoder_layer = {
2237            let cfg = &cfg.vision_config;
2238            let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2239            let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2240
2241            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2242            let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
2243            let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
2244            let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;
2245
2246            let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
2247            let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;
2248
2249            norm1 + norm2 + fc1 + fc2 + qkv + out
2250        };
2251
2252        let elems =
2253            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
2254
2255        Ok(elems * dtype.size_in_bytes())
2256    }
2257
2258    fn layer_sizes_in_bytes(
2259        &self,
2260        config: &str,
2261        dtype: DType,
2262        weight_pack_factor: usize,
2263        _matformer_config: Option<&MatformerSliceConfig>,
2264    ) -> Result<Vec<usize>> {
2265        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2266        let per_layer_elems = {
2267            let input_layernorm = cfg.hidden_size;
2268            let post_attention_layernorm = cfg.hidden_size;
2269
2270            let size_in = cfg.hidden_size;
2271            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2272            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
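            // Qwen2-style attention carries biases on the q/k/v projections, hence the
            // extra size_q / size_kv terms.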
2273            let q_proj = size_in * size_q / weight_pack_factor + size_q;
2274            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
2275            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
2276            let o_proj = size_q * size_in / weight_pack_factor;
2277
2278            let h_size = cfg.hidden_size;
2279            let i_size = cfg.intermediate_size;
2280            let gate_proj = h_size * i_size / weight_pack_factor;
2281            let up_proj = h_size * i_size / weight_pack_factor;
2282            let down_proj = i_size * h_size / weight_pack_factor;
2283
2284            input_layernorm
2285                + post_attention_layernorm
2286                + q_proj
2287                + k_proj
2288                + v_proj
2289                + o_proj
2290                + gate_proj
2291                + up_proj
2292                + down_proj
2293        };
2294        Ok(vec![
2295            per_layer_elems * dtype.size_in_bytes();
2296            cfg.num_hidden_layers
2297        ])
2298    }
2299
2300    fn num_layers(&self, config: &str) -> Result<usize> {
2301        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2302        Ok(cfg.num_hidden_layers)
2303    }
2304
2305    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2306        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2307
2308        let cfg = ModelConfigMetadata {
2309            max_seq_len: cfg.max_position_embeddings,
2310            num_layers: cfg.num_hidden_layers,
2311            hidden_size: cfg.hidden_size,
2312            num_kv_heads: cfg.num_key_value_heads,
2313            num_attn_heads: cfg.num_attention_heads,
2314            sliding_window: cfg.sliding_window,
2315            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2316            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2317        };
2318
2319        Ok(Box::new(cfg))
2320    }
2321
2322    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2323        Some(vec![NonMappedSubModel::Vision])
2324    }
2325}
2326
2327// ======================== Idefics 3 Loader
2328
2329/// [`VisionLoader`] for an Idefics 3 Vision model.
2330///
2331/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2332pub struct Idefics3Loader;
2333
2334pub struct Idefics3Prefixer;
2335
2336impl MultimodalPromptPrefixer for Idefics3Prefixer {
2337    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
2338        // The chat template already inserts the image placeholders, so return the prompt unchanged
2339        prompt.to_string()
2340    }
2341}
2342
2343impl VisionModelLoader for Idefics3Loader {
2344    fn load(
2345        &self,
2346        config: &str,
2347        vb: ShardedVarBuilder,
2348        normal_loading_metadata: NormalLoadingMetadata,
2349        attention_mechanism: AttentionImplementation,
2350    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2351        let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2352        Ok(Box::new(Idefics3Model::new(
2353            &cfg,
2354            vb,
2355            self.is_gptx(config),
2356            normal_loading_metadata,
2357            attention_mechanism,
2358        )?))
2359    }
2360    fn is_gptx(&self, _config: &str) -> bool {
2361        true
2362    }
2363    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2364        let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2365        Ok(Box::new(cfg))
2366    }
2367    fn get_processor(
2368        &self,
2369        _model_config: &str,
2370        processor_config: Option<ProcessorConfig>,
2371        preprocessor_config: PreProcessorConfig,
2372        max_edge: Option<u32>,
2373    ) -> Arc<dyn Processor + Send + Sync> {
2374        Arc::new(Idefics3Processor::new(
2375            processor_config.unwrap_or_default(),
2376            preprocessor_config,
2377            max_edge,
2378        ))
2379    }
2380    fn supports_paged_attention(&self, _config: &str) -> bool {
2381        true
2382    }
2383    fn supports_prefix_cacher(&self, _config: &str) -> bool {
2384        true
2385    }
2386    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2387        Arc::new(Idefics3Prefixer)
2388    }
2389    fn modalities(&self, _config: &str) -> Result<Modalities> {
2390        Ok(Modalities {
2391            input: vec![SupportedModality::Text, SupportedModality::Vision],
2392            output: vec![SupportedModality::Text],
2393        })
2394    }
2395}
2396
2397impl IsqModelLoader for Idefics3Loader {
2398    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2399        Ok(vec![
2400            Regex::new(r"lm_head\.(weight|bias)$")?,
2401            // Attention
2402            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2403            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2404            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2405            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2406            // MLP
2407            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2408            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2409            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2410        ])
2411    }
2412    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
2413        Ok(vec![
2414            Regex::new(r"lm_head\.(weight|bias)$")?,
2415            // Attention
2416            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2417            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2418            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2419            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2420            // MLP
2421            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2422            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2423            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2424            // Attention (vision)
2425            // Regex::new(
2426            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2427            // )?,
2428            // Regex::new(
2429            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
2430            // )?,
2431            // Regex::new(
2432            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
2433            // )?,
2434            // Regex::new(
2435            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.(weight|bias)$",
2436            // )?,
2437            // MLP (vision)
2438            // Regex::new(r"model\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2439            // Regex::new(r"model\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
2440        ])
2441    }
2442}
2443
2444impl DeviceMappedModelLoader for Idefics3Loader {
2445    fn mapped_max_act_size_elems(
2446        &self,
2447        config: &str,
2448        params: &AutoDeviceMapParams,
2449    ) -> Result<usize> {
2450        let AutoDeviceMapParams::Vision {
2451            max_seq_len,
2452            max_batch_size,
2453            max_image_shape: _,
2454            max_num_images,
2455        } = params
2456        else {
2457            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2458        };
2459
2460        let cfg: Idefics3Config = serde_json::from_str(config)?;
2461
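        // Worst-case image token count: one token per ViT patch (plus one) for every image.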
2462        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2463        let img_seq_len = (num_patches + 1) * max_num_images;
2464
2465        let max_text_attn = {
2466            // This model injects the vision information directly into the input embeddings
2467            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2468            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2469        };
2470
2471        Ok(max_text_attn)
2472    }
2473
2474    fn non_mapped_max_act_size_elems(
2475        &self,
2476        config: &str,
2477        params: &AutoDeviceMapParams,
2478    ) -> Result<usize> {
2479        let AutoDeviceMapParams::Vision {
2480            max_seq_len: _,
2481            max_batch_size,
2482            max_image_shape: _,
2483            max_num_images,
2484        } = params
2485        else {
2486            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2487        };
2488
2489        let cfg: Idefics3Config = serde_json::from_str(config)?;
2490
2491        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2492        let img_seq_len = num_patches + 1;
2493
2494        let max_vision_attn = {
2495            // do_image_splitting = true
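            // Budget roughly 5 sub-images per input image when splitting is enabled.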
2496            let images_factor = 5;
2497
2498            (max_batch_size * images_factor * max_num_images)
2499                * cfg.vision_config.num_attention_heads
2500                * img_seq_len
2501                * img_seq_len
2502        };
2503
2504        Ok(max_vision_attn)
2505    }
2506
2507    fn non_mapped_size_in_bytes(
2508        &self,
2509        config: &str,
2510        dtype: DType,
2511        weight_pack_factor: usize,
2512        _matformer_config: Option<&MatformerSliceConfig>,
2513    ) -> Result<usize> {
2514        let cfg: Idefics3Config = serde_json::from_str(config)?;
2515        let text_elems = {
2516            let cfg = &cfg.text_config;
2517
2518            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2519            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2520            let norm = cfg.hidden_size;
2521            embed_tokens + lm_head + norm
2522        };
2523
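        // Connector: pixel-shuffled vision features (hidden_size * scale_factor^2) are
        // projected into the text hidden size by a single bias-free linear layer.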
2524        let connector_elems = {
2525            let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
2526            let out_dim = cfg.text_config.hidden_size;
2527
2528            in_dim * out_dim
2529        };
2530
2531        let vision_transformer = {
2532            let cfg = &cfg.vision_config;
2533
2534            let post_layernorm = cfg.hidden_size;
2535
2536            let conv_config = Conv2dConfig {
2537                stride: cfg.patch_size,
2538                ..Default::default()
2539            };
2540            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2541                * cfg.patch_size
2542                * cfg.patch_size;
2543
2544            let num_patches_per_side = cfg.image_size / cfg.patch_size;
2545            let num_patches = num_patches_per_side.pow(2);
2546            let position_embedding = num_patches * cfg.hidden_size;
2547
2548            let layer_elems = {
2549                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2550                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2551
2552                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2553                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2554
2555                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2556                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2557                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2558                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2559
2560                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2561            };
2562
2563            post_layernorm
2564                + patch_embedding
2565                + position_embedding
2566                + layer_elems * cfg.num_hidden_layers
2567        };
2568
2569        let elems = text_elems + connector_elems + vision_transformer;
2570
2571        Ok(elems * dtype.size_in_bytes())
2572    }
2573
2574    fn layer_sizes_in_bytes(
2575        &self,
2576        config: &str,
2577        dtype: DType,
2578        weight_pack_factor: usize,
2579        _matformer_config: Option<&MatformerSliceConfig>,
2580    ) -> Result<Vec<usize>> {
2581        let cfg: Idefics3Config = serde_json::from_str(config)?;
2582        let cfg = cfg.text_config;
2583        let per_layer_elems = {
2584            let input_layernorm = cfg.hidden_size;
2585            let post_attention_layernorm = cfg.hidden_size;
2586
2587            let size_in = cfg.hidden_size;
2588            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2589            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2590            let q_proj = size_in * size_q / weight_pack_factor;
2591            let k_proj = size_in * size_kv / weight_pack_factor;
2592            let v_proj = size_in * size_kv / weight_pack_factor;
2593            let o_proj = size_q * size_in / weight_pack_factor;
2594
2595            let h_size = cfg.hidden_size;
2596            let i_size = cfg.intermediate_size;
2597            let gate_proj = h_size * i_size / weight_pack_factor;
2598            let up_proj = h_size * i_size / weight_pack_factor;
2599            let down_proj = i_size * h_size / weight_pack_factor;
2600
2601            input_layernorm
2602                + post_attention_layernorm
2603                + q_proj
2604                + k_proj
2605                + v_proj
2606                + o_proj
2607                + gate_proj
2608                + up_proj
2609                + down_proj
2610        };
2611        Ok(vec![
2612            per_layer_elems * dtype.size_in_bytes();
2613            cfg.num_hidden_layers
2614        ])
2615    }
2616
2617    fn num_layers(&self, config: &str) -> Result<usize> {
2618        let cfg: Idefics3Config = serde_json::from_str(config)?;
2619        Ok(cfg.text_config.num_hidden_layers)
2620    }
2621    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2622        let cfg: Idefics3Config = serde_json::from_str(config)?;
2623        let cfg = &cfg.text_config;
2624
2625        let cfg = ModelConfigMetadata {
2626            max_seq_len: cfg.max_position_embeddings,
2627            num_layers: cfg.num_hidden_layers,
2628            hidden_size: cfg.hidden_size,
2629            num_kv_heads: cfg.num_key_value_heads,
2630            num_attn_heads: cfg.num_attention_heads,
2631            sliding_window: None,
2632            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2633            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2634        };
2635
2636        Ok(Box::new(cfg))
2637    }
2638
2639    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2640        Some(vec![NonMappedSubModel::Vision])
2641    }
2642}
2643
2644// ======================== MiniCpm-O Loader
2645
2646/// [`VisionLoader`] for a MiniCpm-O model.
2647///
2648/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2649pub struct MiniCpmOLoader;
2650
2651pub struct MiniCpmOPrefixer;
2652
2653impl MultimodalPromptPrefixer for MiniCpmOPrefixer {
2654    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2655        format!(
2656            "{}{prompt}",
2657            "(<image>./</image>)".repeat(image_indexes.len())
2658        )
2659    }
2660}
2661
2662impl VisionModelLoader for MiniCpmOLoader {
2663    fn load(
2664        &self,
2665        config: &str,
2666        vb: ShardedVarBuilder,
2667        normal_loading_metadata: NormalLoadingMetadata,
2668        attention_mechanism: AttentionImplementation,
2669    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2670        let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2671        Ok(Box::new(MiniCpmOModel::new(
2672            &cfg,
2673            vb,
2674            self.is_gptx(config),
2675            normal_loading_metadata,
2676            attention_mechanism,
2677        )?))
2678    }
2679    fn is_gptx(&self, _config: &str) -> bool {
2680        true
2681    }
2682    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2683        let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2684        Ok(Box::new(cfg))
2685    }
2686    fn get_processor(
2687        &self,
2688        _model_config: &str,
2689        processor_config: Option<ProcessorConfig>,
2690        preprocessor_config: PreProcessorConfig,
2691        max_edge: Option<u32>,
2692    ) -> Arc<dyn Processor + Send + Sync> {
2693        Arc::new(MiniCpmOProcessor::new(
2694            processor_config.unwrap_or_default(),
2695            preprocessor_config,
2696            max_edge,
2697        ))
2698    }
2699    fn supports_paged_attention(&self, _config: &str) -> bool {
2700        true
2701    }
2702    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2703        Arc::new(MiniCpmOPrefixer)
2704    }
2705    fn modalities(&self, _config: &str) -> Result<Modalities> {
2706        Ok(Modalities {
2707            input: vec![SupportedModality::Text, SupportedModality::Vision],
2708            output: vec![SupportedModality::Text],
2709        })
2710    }
2711}
2712
2713impl IsqModelLoader for MiniCpmOLoader {
2714    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2715        Ok(vec![
2716            Regex::new(r"llm.lm_head\.(weight|bias)$")?,
2717            // Attention
2718            Regex::new(r"llm.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2719            Regex::new(r"llm.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2720            Regex::new(r"llm.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2721            Regex::new(r"llm.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2722            // MLP
2723            Regex::new(r"llm.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2724            Regex::new(r"llm.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2725            Regex::new(r"llm.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2726        ])
2727    }
2728    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2729        self.isq_layer_regexes(config)
2730    }
2731}
2732
2733impl DeviceMappedModelLoader for MiniCpmOLoader {
2734    fn mapped_max_act_size_elems(
2735        &self,
2736        config: &str,
2737        params: &AutoDeviceMapParams,
2738    ) -> Result<usize> {
2739        let AutoDeviceMapParams::Vision {
2740            max_seq_len,
2741            max_batch_size,
2742            max_image_shape: _,
2743            max_num_images,
2744        } = params
2745        else {
2746            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2747        };
2748
2749        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2750
2751        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2752        let img_seq_len = (num_patches + 1) * max_num_images;
2753
2754        let max_text_attn = {
2755            // This model injects the vision information directly into the input embeddings
2756            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2757            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2758        };
2759
2760        Ok(max_text_attn)
2761    }
2762
2763    fn non_mapped_max_act_size_elems(
2764        &self,
2765        config: &str,
2766        params: &AutoDeviceMapParams,
2767    ) -> Result<usize> {
2768        let AutoDeviceMapParams::Vision {
2769            max_seq_len: _,
2770            max_batch_size,
2771            max_image_shape: _,
2772            max_num_images,
2773        } = params
2774        else {
2775            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2776        };
2777
2778        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2779
2780        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2781        let img_seq_len = num_patches + 1;
2782
2783        let max_vision_attn = {
2784            // do_image_splitting = true
2785            let images_factor = 5;
2786
2787            (max_batch_size * images_factor * max_num_images)
2788                * cfg.vision_config.num_attention_heads
2789                * img_seq_len
2790                * img_seq_len
2791        };
2792
2793        Ok(max_vision_attn)
2794    }
2795
2796    fn non_mapped_size_in_bytes(
2797        &self,
2798        config: &str,
2799        dtype: DType,
2800        weight_pack_factor: usize,
2801        _matformer_config: Option<&MatformerSliceConfig>,
2802    ) -> Result<usize> {
2803        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2804        let text_elems = {
2805            let cfg = &cfg.text_config;
2806
2807            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2808            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2809            let norm = cfg.hidden_size;
2810            embed_tokens + lm_head + norm
2811        };
2812
2813        let vision_transformer = {
2814            let cfg = &cfg.vision_config;
2815
2816            let post_layernorm = cfg.hidden_size;
2817
2818            let conv_config = Conv2dConfig {
2819                stride: cfg.patch_size,
2820                ..Default::default()
2821            };
2822            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2823                * cfg.patch_size
2824                * cfg.patch_size;
2825
2826            let num_patches_per_side = cfg.image_size / cfg.patch_size;
2827            let num_patches = num_patches_per_side.pow(2);
2828            let position_embedding = num_patches * cfg.hidden_size;
2829
2830            let layer_elems = {
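                // bias_if! contributes the bias element count only when its condition is true;
                // these vision encoder layers always carry biases, so it is passed `true`.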
2831                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2832                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2833
2834                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2835                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2836
2837                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2838                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2839                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2840                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2841
2842                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2843            };
2844
2845            post_layernorm
2846                + patch_embedding
2847                + position_embedding
2848                + layer_elems * cfg.num_hidden_layers
2849        };
2850
2851        let elems = text_elems + vision_transformer;
2852
2853        Ok(elems * dtype.size_in_bytes())
2854    }
2855
2856    fn layer_sizes_in_bytes(
2857        &self,
2858        config: &str,
2859        dtype: DType,
2860        weight_pack_factor: usize,
2861        _matformer_config: Option<&MatformerSliceConfig>,
2862    ) -> Result<Vec<usize>> {
2863        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2864        let cfg = cfg.text_config;
2865        let per_layer_elems = {
2866            let input_layernorm = cfg.hidden_size;
2867            let post_attention_layernorm = cfg.hidden_size;
2868
2869            let size_in = cfg.hidden_size;
2870            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2871            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2872            let q_proj = size_in * size_q / weight_pack_factor;
2873            let k_proj = size_in * size_kv / weight_pack_factor;
2874            let v_proj = size_in * size_kv / weight_pack_factor;
2875            let o_proj = size_q * size_in / weight_pack_factor;
2876
2877            let h_size = cfg.hidden_size;
2878            let i_size = cfg.intermediate_size;
2879            let gate_proj = h_size * i_size / weight_pack_factor;
2880            let up_proj = h_size * i_size / weight_pack_factor;
2881            let down_proj = i_size * h_size / weight_pack_factor;
2882
2883            input_layernorm
2884                + post_attention_layernorm
2885                + q_proj
2886                + k_proj
2887                + v_proj
2888                + o_proj
2889                + gate_proj
2890                + up_proj
2891                + down_proj
2892        };
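        // One entry per decoder layer: every layer has the same element count, so the
        // per-layer byte size is simply repeated `num_hidden_layers` times.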
2893        Ok(vec![
2894            per_layer_elems * dtype.size_in_bytes();
2895            cfg.num_hidden_layers
2896        ])
2897    }
2898
2899    fn num_layers(&self, config: &str) -> Result<usize> {
2900        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2901        Ok(cfg.text_config.num_hidden_layers)
2902    }
2903    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2904        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2905        let cfg = &cfg.text_config;
2906
2907        let cfg = ModelConfigMetadata {
2908            max_seq_len: cfg.max_position_embeddings,
2909            num_layers: cfg.num_hidden_layers,
2910            hidden_size: cfg.hidden_size,
2911            num_kv_heads: cfg.num_key_value_heads,
2912            num_attn_heads: cfg.num_attention_heads,
2913            sliding_window: None,
2914            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2915            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2916        };
2917
2918        Ok(Box::new(cfg))
2919    }
2920}
2921
2922// ======================== Phi 4MM Loader
2923
2924/// [`VisionLoader`] for a Phi 4MM Vision model.
2925///
2926/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2927pub struct Phi4MMLoader;
2928
2929pub struct Phi4MMPrefixer;
2930
2931impl MultimodalPromptPrefixer for Phi4MMPrefixer {
2932    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2933        // Image indexing starts at 0.
2934
2935        format!(
2936            "{}{prompt}",
2937            image_indexes
2938                .into_iter()
2939                .map(|image_index| format!("<|image_{}|>", image_index + 1))
2940                .join("")
2941        )
2942    }
2943    fn prefix_audio(&self, audio_indexes: Vec<usize>, prompt: &str) -> String {
2944        // Audio indexing starts at 0.
2945
2946        format!(
2947            "{}{prompt}",
2948            audio_indexes
2949                .into_iter()
2950                .map(|audio_index| format!("<|audio_{}|>", audio_index + 1))
2951                .join("")
2952        )
2953    }
2954}
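// Illustrative example (not exercised here): prefixing "Describe these." with two images
// yields "<|image_1|><|image_2|>Describe these.", since the placeholder tags above are
// 1-based even though image indexing starts at 0.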
2955
2956impl VisionModelLoader for Phi4MMLoader {
2957    fn load(
2958        &self,
2959        config: &str,
2960        vb: ShardedVarBuilder,
2961        normal_loading_metadata: NormalLoadingMetadata,
2962        attention_mechanism: AttentionImplementation,
2963    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2964        let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
2965        Ok(Box::new(Phi4MMModel::new(
2966            &cfg,
2967            vb,
2968            self.is_gptx(config),
2969            normal_loading_metadata,
2970            attention_mechanism,
2971        )?))
2972    }
2973    fn is_gptx(&self, _config: &str) -> bool {
2974        true
2975    }
2976    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2977        let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
2978        Ok(Box::new(cfg))
2979    }
2980    fn get_processor(
2981        &self,
2982        _model_config: &str,
2983        processor_config: Option<ProcessorConfig>,
2984        preprocessor_config: PreProcessorConfig,
2985        _max_edge: Option<u32>,
2986    ) -> Arc<dyn Processor + Send + Sync> {
2987        Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
2988    }
2989    fn supports_paged_attention(&self, _config: &str) -> bool {
2990        true
2991    }
2992    fn supports_prefix_cacher(&self, _config: &str) -> bool {
2993        true
2994    }
2995    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2996        Arc::new(Phi4MMPrefixer)
2997    }
2998    fn modalities(&self, _config: &str) -> Result<Modalities> {
2999        Ok(Modalities {
3000            input: vec![
3001                SupportedModality::Text,
3002                SupportedModality::Vision,
3003                SupportedModality::Audio,
3004            ],
3005            output: vec![SupportedModality::Text],
3006        })
3007    }
3008}
3009
3010impl IsqModelLoader for Phi4MMLoader {
3011    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3012        Ok(vec![
3013            Regex::new(r"lm_head\.(weight|bias)$")?,
3014            // Attention
3015            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
3016            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3017            // MLP
3018            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
3019            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3020        ])
3021    }
3022    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3023        self.isq_layer_regexes(config)
3024    }
3025}
3026
3027impl DeviceMappedModelLoader for Phi4MMLoader {
3028    fn mapped_max_act_size_elems(
3029        &self,
3030        config: &str,
3031        params: &AutoDeviceMapParams,
3032    ) -> Result<usize> {
3033        // NOTE: we ignore the fact that max_num_images can, in practice, only be one...
3034        let AutoDeviceMapParams::Vision {
3035            max_seq_len,
3036            max_batch_size,
3037            max_image_shape: _,
3038            max_num_images,
3039        } = params
3040        else {
3041            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3042        };
3043
3044        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3045
3046        let vcfg = &PHI4_MM_VISION_CFG;
3047
3048        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3049        let img_seq_len = (num_patches + 1) * max_num_images;
3050
3051        let max_text_attn = {
3052            // This model injects the vision information directly into the input embeddings
3053            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3054            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3055        };
3056
3057        Ok(max_text_attn)
3058    }
3059
3060    fn non_mapped_max_act_size_elems(
3061        &self,
3062        _config: &str,
3063        params: &AutoDeviceMapParams,
3064    ) -> Result<usize> {
3065        let AutoDeviceMapParams::Vision {
3066            max_seq_len: _,
3067            max_batch_size,
3068            max_image_shape,
3069            max_num_images,
3070        } = params
3071        else {
3072            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3073        };
3074
3075        let vcfg = &PHI4_MM_VISION_CFG;
3076
3077        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3078        let img_seq_len = num_patches + 1;
3079
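        // Scale the batch by the number of dynamic-HD tiles: the image is split into
        // DYHD_BASE_RESOLUTION-sized crops along each dimension, plus one global view
        // (hence the `+ 1`). This is an assumption about the exact tiling; the source of
        // truth is `phi4::inputs_processor`.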
3080        let max_batch_size = max_batch_size
3081            * (max_image_shape
3082                .0
3083                .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3084                * max_image_shape
3085                    .1
3086                    .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3087                + 1);
3088
3089        let max_vision_attn = (max_batch_size * max_num_images)
3090            * vcfg.num_attention_heads
3091            * img_seq_len
3092            * img_seq_len;
3093        let max_qkv = 3
3094            * (max_batch_size
3095                * vcfg.num_attention_heads
3096                * img_seq_len
3097                * (vcfg.hidden_size / vcfg.num_attention_heads));
3098
3099        Ok(max_vision_attn + max_qkv)
3100    }
3101
3102    fn non_mapped_size_in_bytes(
3103        &self,
3104        config: &str,
3105        dtype: DType,
3106        weight_pack_factor: usize,
3107        _matformer_config: Option<&MatformerSliceConfig>,
3108    ) -> Result<usize> {
3109        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3110        let elems = {
3111            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3112            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3113            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3114                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3115            } else {
3116                0
3117            };
3118            let norm = cfg.hidden_size;
3119
3120            let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
3121                let projection_cls = img_embed
3122                    .projection_cls
3123                    .clone()
3124                    .unwrap_or("linear".to_string());
3125                let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
3126                let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
3127                let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;
3128
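                // With the HD transform the MLP projector sees 4 * image_dim_out input
                // features (presumably from concatenating neighbouring patches); without it
                // the projector input is just image_dim_out.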
3129                let proj = match (projection_cls.as_str(), use_hd_transform) {
3130                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
3131                    ("mlp", true) => {
3132                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
3133                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3134                        a + b
3135                    }
3136                    ("mlp", false) => {
3137                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
3138                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3139                        a + b
3140                    }
3141                    _ => {
3142                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
3143                    }
3144                };
3145
3146                let (glb_gn, sub_gn) = if with_learnable_separator {
3147                    let glb_gn = image_dim_out * 4;
3148                    let sub_gn = image_dim_out * 4;
3149                    (glb_gn, sub_gn)
3150                } else {
3151                    (0, 0)
3152                };
3153
3154                let vision_transformer = {
3155                    let cfg = &PHI4_MM_VISION_CFG;
3156
3157                    let post_layernorm = cfg.hidden_size;
3158
3159                    let conv_config = Conv2dConfig {
3160                        stride: cfg.patch_size,
3161                        ..Default::default()
3162                    };
3163                    let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3164                        * cfg.patch_size
3165                        * cfg.patch_size;
3166
3167                    let num_patches_per_side = cfg.image_size / cfg.patch_size;
3168                    let num_patches = num_patches_per_side.pow(2);
3169                    let position_embedding = num_patches * cfg.hidden_size;
3170
3171                    let layer_elems = {
3172                        let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3173                        let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3174
3175                        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3176                        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3177
3178                        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3179                        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3180                        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3181                        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3182
3183                        layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3184                    };
3185
3186                    post_layernorm
3187                        + patch_embedding
3188                        + position_embedding
3189                        + layer_elems * cfg.num_hidden_layers
3190                };
3191
3192                proj + glb_gn + sub_gn + vision_transformer
3193            } else {
3194                0
3195            };
3196
3197            embed_tokens + lm_head + norm + image_embed
3198        };
3199
3200        Ok(elems * dtype.size_in_bytes())
3201    }
3202
3203    fn layer_sizes_in_bytes(
3204        &self,
3205        config: &str,
3206        dtype: DType,
3207        weight_pack_factor: usize,
3208        _matformer_config: Option<&MatformerSliceConfig>,
3209    ) -> Result<Vec<usize>> {
3210        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3211        let per_layer_elems = {
3212            let input_layernorm = cfg.hidden_size;
3213            let post_attention_layernorm = cfg.hidden_size;
3214
3215            let size_in = cfg.hidden_size;
3216            let head_dim = cfg.head_dim();
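            // Phi-4 fuses the q/k/v projections into a single `qkv_proj` (q heads plus two
            // sets of kv heads) and the gate/up projections into `gate_up_proj`, which is why
            // the fused output widths appear below.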
3217            let op_size =
3218                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
3219            let qkv_proj = size_in * op_size / weight_pack_factor;
3220            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;
3221
3222            let h_size = cfg.hidden_size;
3223            let i_size = cfg.intermediate_size;
3224            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
3225            let down_proj = h_size * i_size / weight_pack_factor;
3226
3227            input_layernorm
3228                + post_attention_layernorm
3229                + qkv_proj
3230                + o_proj
3231                + gate_up_proj
3232                + down_proj
3233        };
3234        Ok(vec![
3235            per_layer_elems * dtype.size_in_bytes();
3236            cfg.num_hidden_layers
3237        ])
3238    }
3239
3240    fn num_layers(&self, config: &str) -> Result<usize> {
3241        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3242        Ok(cfg.num_hidden_layers)
3243    }
3244
3245    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3246        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3247
3248        let cfg = ModelConfigMetadata {
3249            max_seq_len: cfg.max_position_embeddings,
3250            num_layers: cfg.num_hidden_layers,
3251            hidden_size: cfg.hidden_size,
3252            num_kv_heads: cfg.num_key_value_heads(),
3253            num_attn_heads: cfg.num_attention_heads,
3254            sliding_window: cfg.sliding_window,
3255            k_head_dim: cfg.head_dim(),
3256            v_head_dim: cfg.head_dim(),
3257        };
3258
3259        Ok(Box::new(cfg))
3260    }
3261
3262    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3263        Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
3264    }
3265}
3266
3267// ======================== Qwen2_5VL Loader
3268
3269/// [`VisionLoader`] for a Qwen2_5VL model.
3270///
3271/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
3272pub struct Qwen2_5VLLoader;
3273
3274pub struct Qwen2_5VLPrefixer;
3275
3276impl MultimodalPromptPrefixer for Qwen2_5VLPrefixer {
3277    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
3278        format!(
3279            "{}{prompt}",
3280            format!(
3281                "{}{}{}",
3282                Qwen2_5VLProcessor::VISION_START,
3283                Qwen2_5VLProcessor::IMAGE_PAD,
3284                Qwen2_5VLProcessor::VISION_END
3285            )
3286            .repeat(image_indexes.len())
3287        )
3288    }
3289}
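// Each image contributes one `VISION_START IMAGE_PAD VISION_END` marker triplet in front of
// the prompt; the actual token strings are the constants defined on `Qwen2_5VLProcessor`.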
3290
3291impl VisionModelLoader for Qwen2_5VLLoader {
3292    fn load(
3293        &self,
3294        config: &str,
3295        vb: ShardedVarBuilder,
3296        normal_loading_metadata: NormalLoadingMetadata,
3297        attention_mechanism: AttentionImplementation,
3298    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3299        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3300        Ok(Box::new(Qwen2_5VLModel::new(
3301            &cfg,
3302            vb,
3303            self.is_gptx(config),
3304            normal_loading_metadata,
3305            attention_mechanism,
3306        )?))
3307    }
3308    fn is_gptx(&self, _config: &str) -> bool {
3309        true
3310    }
3311    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3312        let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
3313        Ok(Box::new(config))
3314    }
3315    fn get_processor(
3316        &self,
3317        _model_config: &str,
3318        _processor_config: Option<ProcessorConfig>,
3319        _preprocessor_config: PreProcessorConfig,
3320        max_edge: Option<u32>,
3321    ) -> Arc<dyn Processor + Send + Sync> {
3322        Arc::new(Qwen2_5VLProcessor::new(max_edge))
3323    }
3324    fn supports_paged_attention(&self, _config: &str) -> bool {
3325        false
3326    }
3327    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3328        Arc::new(Qwen2_5VLPrefixer)
3329    }
3330    fn modalities(&self, _config: &str) -> Result<Modalities> {
3331        Ok(Modalities {
3332            input: vec![SupportedModality::Text, SupportedModality::Vision],
3333            output: vec![SupportedModality::Text],
3334        })
3335    }
3336}
3337
3338impl IsqModelLoader for Qwen2_5VLLoader {
3339    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3340        Ok(vec![
3341            Regex::new(r"lm_head\.(weight|bias)$")?,
3342            // Attention
3343            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3344            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3345            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3346            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3347            // MLP
3348            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3349            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3350            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3351        ])
3352    }
3353    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3354        self.isq_layer_regexes(config)
3355    }
3356}
3357
3358impl DeviceMappedModelLoader for Qwen2_5VLLoader {
3359    fn mapped_max_act_size_elems(
3360        &self,
3361        config: &str,
3362        params: &AutoDeviceMapParams,
3363    ) -> Result<usize> {
3364        let AutoDeviceMapParams::Vision {
3365            max_seq_len,
3366            max_batch_size,
3367            max_image_shape,
3368            max_num_images,
3369        } = params
3370        else {
3371            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3372        };
3373
3374        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3375
3376        let img_seq_len = {
3377            let cfg = &cfg.vision_config;
3378            let grid_t = max_num_images / cfg.temporal_patch_size;
3379            let grid_h = max_image_shape.0 / cfg.patch_size;
3380            let grid_w = max_image_shape.1 / cfg.patch_size;
3381            grid_t * grid_h * grid_w
3382        };
3383        let img_seq_len = img_seq_len * max_num_images;
3384
3385        let max_text_attn = {
3386            // This model injects the vision information directly into the input embeddings
3387            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3388            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3389        };
3390
3391        Ok(max_text_attn)
3392    }
3393
3394    fn non_mapped_max_act_size_elems(
3395        &self,
3396        config: &str,
3397        params: &AutoDeviceMapParams,
3398    ) -> Result<usize> {
3399        let AutoDeviceMapParams::Vision {
3400            max_seq_len: _,
3401            max_batch_size,
3402            max_image_shape,
3403            max_num_images,
3404        } = params
3405        else {
3406            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3407        };
3408
3409        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3410
3411        let img_seq_len = {
3412            let cfg = &cfg.vision_config;
3413            let grid_t = max_num_images / cfg.temporal_patch_size;
3414            let grid_h = max_image_shape.0 / cfg.patch_size;
3415            let grid_w = max_image_shape.1 / cfg.patch_size;
3416            grid_t * grid_h * grid_w
3417        };
3418
3419        let max_vision_attn = {
3420            let cfg = &cfg.vision_config;
3421            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
3422        };
3423
3424        Ok(max_vision_attn)
3425    }
3426
3427    fn non_mapped_size_in_bytes(
3428        &self,
3429        config: &str,
3430        dtype: DType,
3431        weight_pack_factor: usize,
3432        _matformer_config: Option<&MatformerSliceConfig>,
3433    ) -> Result<usize> {
3434        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3435        let text_elems = {
3436            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3437            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3438            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3439                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3440            } else {
3441                0
3442            };
3443            let norm = cfg.hidden_size;
3444            embed_tokens + lm_head + norm
3445        };
3446
3447        let patch_merger = {
3448            let cfg = &cfg.vision_config;
3449            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
3450
3451            let mlp0 = hidden_size * hidden_size + hidden_size;
3452            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
3453
3454            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3455
3456            mlp0 + mlp2 + ln_q
3457        };
3458
3459        let patch_embed = {
3460            let cfg = &cfg.vision_config;
3461            let conv_cfg = Conv3dConfig {
3462                stride: cfg.patch_size,
3463                ..Default::default()
3464            };
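            // Conv3d patch embedding: the kernel spans `temporal_patch_size` frames and a
            // `patch_size` x `patch_size` spatial window, so its weight has
            // in_chans * hidden_size * temporal_patch_size * patch_size^2 elements.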
3465            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
3466            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
3467                * kernel_sizes[0]
3468                * kernel_sizes[1]
3469                * kernel_sizes[2]
3470        };
3471
3472        let encoder_layer = {
3473            let cfg = &cfg.vision_config;
3474            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3475            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3476
3477            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
3478            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3479            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
3480
3481            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
3482            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3483
3484            norm1 + norm2 + fc1 + fc2 + qkv + out
3485        };
3486
3487        let elems =
3488            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
3489
3490        Ok(elems * dtype.size_in_bytes())
3491    }
3492
3493    fn layer_sizes_in_bytes(
3494        &self,
3495        config: &str,
3496        dtype: DType,
3497        weight_pack_factor: usize,
3498        _matformer_config: Option<&MatformerSliceConfig>,
3499    ) -> Result<Vec<usize>> {
3500        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3501        let per_layer_elems = {
3502            let input_layernorm = cfg.hidden_size;
3503            let post_attention_layernorm = cfg.hidden_size;
3504
3505            let size_in = cfg.hidden_size;
3506            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3507            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3508            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3509            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3510            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3511            let o_proj = size_q * size_in / weight_pack_factor;
3512
3513            let h_size = cfg.hidden_size;
3514            let i_size = cfg.intermediate_size;
3515            let gate_proj = h_size * i_size / weight_pack_factor;
3516            let up_proj = h_size * i_size / weight_pack_factor;
3517            let down_proj = i_size * h_size / weight_pack_factor;
3518
3519            input_layernorm
3520                + post_attention_layernorm
3521                + q_proj
3522                + k_proj
3523                + v_proj
3524                + o_proj
3525                + gate_proj
3526                + up_proj
3527                + down_proj
3528        };
3529        Ok(vec![
3530            per_layer_elems * dtype.size_in_bytes();
3531            cfg.num_hidden_layers
3532        ])
3533    }
3534
3535    fn num_layers(&self, config: &str) -> Result<usize> {
3536        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3537        Ok(cfg.num_hidden_layers)
3538    }
3539
3540    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3541        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3542
3543        let cfg = ModelConfigMetadata {
3544            max_seq_len: cfg.max_position_embeddings,
3545            num_layers: cfg.num_hidden_layers,
3546            hidden_size: cfg.hidden_size,
3547            num_kv_heads: cfg.num_key_value_heads,
3548            num_attn_heads: cfg.num_attention_heads,
3549            sliding_window: cfg.sliding_window,
3550            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3551            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3552        };
3553
3554        Ok(Box::new(cfg))
3555    }
3556
3557    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3558        Some(vec![NonMappedSubModel::Vision])
3559    }
3560}
3561
3562// ======================== Gemma 3 Loader
3563
3564/// [`VisionLoader`] for a Gemma 3 model.
3565///
3566/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
3567pub struct Gemma3Loader;
3568
3569pub struct Gemma3Prefixer;
3570
3571impl MultimodalPromptPrefixer for Gemma3Prefixer {
3572    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
3573        prompt.to_string()
3574    }
3575}
3576
3577impl VisionModelLoader for Gemma3Loader {
3578    fn load(
3579        &self,
3580        config: &str,
3581        vb: ShardedVarBuilder,
3582        normal_loading_metadata: NormalLoadingMetadata,
3583        attention_mechanism: AttentionImplementation,
3584    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3585        let cfg: Gemma3Config = serde_json::from_str(config)?;
3586        Ok(Box::new(Gemma3Model::new(
3587            &cfg,
3588            vb,
3589            self.is_gptx(config),
3590            normal_loading_metadata,
3591            attention_mechanism,
3592        )?))
3593    }
3594    fn is_gptx(&self, _config: &str) -> bool {
3595        true
3596    }
3597    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3598        let config: Gemma3Config = serde_json::from_str(config)?;
3599        Ok(Box::new(config))
3600    }
3601    fn get_processor(
3602        &self,
3603        config: &str,
3604        processor_config: Option<ProcessorConfig>,
3605        _preprocessor_config: PreProcessorConfig,
3606        _max_edge: Option<u32>,
3607    ) -> Arc<dyn Processor + Send + Sync> {
3608        let config: Gemma3Config = serde_json::from_str(config).unwrap();
3609        // Handle the Gemma 3 1b case here
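        // Text-only checkpoints (such as Gemma 3 1B) deserialize to `Gemma3Config::Text`, so
        // the processor is constructed without vision/image-token support.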
3610        Arc::new(Gemma3Processor::new(
3611            processor_config.unwrap_or_default(),
3612            matches!(config, Gemma3Config::WithVision { .. }),
3613        ))
3614    }
3615    fn supports_paged_attention(&self, _config: &str) -> bool {
3616        true
3617    }
3618    fn supports_prefix_cacher(&self, _config: &str) -> bool {
3619        true
3620    }
3621    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3622        Arc::new(Gemma3Prefixer)
3623    }
3624    fn modalities(&self, _config: &str) -> Result<Modalities> {
3625        Ok(Modalities {
3626            input: vec![SupportedModality::Text, SupportedModality::Vision],
3627            output: vec![SupportedModality::Text],
3628        })
3629    }
3630}
3631
3632impl IsqModelLoader for Gemma3Loader {
3633    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3634        Ok(vec![
3635            Regex::new(r"lm_head\.(weight|bias)$")?,
3636            // Attention
3637            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3638            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3639            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3640            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3641            // MLP
3642            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3643            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3644            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3645        ])
3646    }
3647    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
3648        Ok(vec![
3649            Regex::new(r"lm_head\.(weight|bias)$")?,
3650            // Attention
3651            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3652            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3653            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3654            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3655            // MLP
3656            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3657            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3658            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3659        ])
3660    }
3661}
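// Note: unlike `isq_layer_regexes`, the immediate-ISQ predicates above anchor on the fully
// qualified `language_model.model.layers` paths, presumably because they are matched against
// complete weight names while tensors are being loaded.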
3662
3663impl DeviceMappedModelLoader for Gemma3Loader {
3664    fn mapped_max_act_size_elems(
3665        &self,
3666        config: &str,
3667        params: &AutoDeviceMapParams,
3668    ) -> Result<usize> {
3669        let AutoDeviceMapParams::Vision {
3670            max_seq_len,
3671            max_batch_size,
3672            max_image_shape: _,
3673            max_num_images,
3674        } = params
3675        else {
3676            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3677        };
3678
3679        let cfg: Gemma3Config = serde_json::from_str(config)?;
3680
3681        match cfg {
3682            Gemma3Config::Text(text_config) => Ok(max_batch_size
3683                * text_config.num_attention_heads
3684                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)),
3685            Gemma3Config::WithVision {
3686                text_config,
3687                vision_config,
3688                ..
3689            } => {
3690                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3691                let img_seq_len = (num_patches + 1) * max_num_images;
3692
3693                let max_text_attn = {
3694                    // This model injects the vision information directly into the input embeddings
3695                    let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3696                    max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
3697                };
3698                Ok(max_text_attn)
3699            }
3700        }
3701    }
3702
3703    fn non_mapped_max_act_size_elems(
3704        &self,
3705        config: &str,
3706        params: &AutoDeviceMapParams,
3707    ) -> Result<usize> {
3708        let AutoDeviceMapParams::Vision {
3709            max_seq_len: _,
3710            max_batch_size,
3711            max_image_shape: _,
3712            max_num_images,
3713        } = params
3714        else {
3715            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3716        };
3717
3718        let cfg: Gemma3Config = serde_json::from_str(config)?;
3719
3720        match cfg {
3721            Gemma3Config::WithVision { vision_config, .. } => {
3722                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3723                let img_seq_len = num_patches + 1;
3724
3725                let max_vision_attn = {
3726                    (max_batch_size * max_num_images)
3727                        * vision_config.num_attention_heads
3728                        * img_seq_len
3729                        * img_seq_len
3730                };
3731
3732                Ok(max_vision_attn)
3733            }
3734            Gemma3Config::Text(_) => Ok(0),
3735        }
3736    }
3737
3738    fn non_mapped_size_in_bytes(
3739        &self,
3740        config: &str,
3741        dtype: DType,
3742        weight_pack_factor: usize,
3743        _matformer_config: Option<&MatformerSliceConfig>,
3744    ) -> Result<usize> {
3745        let cfg: Gemma3Config = serde_json::from_str(config)?;
3746
3747        let text_elems = {
3748            let cfg = match &cfg {
3749                Gemma3Config::Text(cfg) => cfg,
3750                Gemma3Config::WithVision { text_config, .. } => text_config,
3751            };
3752            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3753            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3754            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3755                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3756            } else {
3757                0
3758            };
3759            let norm = cfg.hidden_size;
3760            embed_tokens + lm_head + norm
3761        };
3762
3763        let vision_transformer = if let Gemma3Config::WithVision {
3764            vision_config: cfg, ..
3765        } = &cfg
3766        {
3767            let post_layernorm = cfg.hidden_size;
3768
3769            let conv_config = Conv2dConfig {
3770                stride: cfg.patch_size,
3771                ..Default::default()
3772            };
3773            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3774                * cfg.patch_size
3775                * cfg.patch_size;
3776
3777            let num_patches_per_side = cfg.image_size / cfg.patch_size;
3778            let num_patches = num_patches_per_side.pow(2);
3779            let position_embedding = num_patches * cfg.hidden_size;
3780
3781            let layer_elems = {
3782                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3783                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3784
3785                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3786                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3787
3788                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3789                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3790                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3791                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3792
3793                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3794            };
3795
3796            post_layernorm
3797                + patch_embedding
3798                + position_embedding
3799                + layer_elems * cfg.num_hidden_layers
3800        } else {
3801            0
3802        };
3803
3804        let elems = text_elems + vision_transformer;
3805
3806        Ok(elems * dtype.size_in_bytes())
3807    }
3808
3809    fn layer_sizes_in_bytes(
3810        &self,
3811        config: &str,
3812        dtype: DType,
3813        weight_pack_factor: usize,
3814        _matformer_config: Option<&MatformerSliceConfig>,
3815    ) -> Result<Vec<usize>> {
3816        let cfg: Gemma3Config = serde_json::from_str(config)?;
3817
3818        let txt_cfg = match &cfg {
3819            Gemma3Config::Text(cfg) => cfg,
3820            Gemma3Config::WithVision { text_config, .. } => text_config,
3821        };
3822        let per_layer_elems = {
3823            let cfg = txt_cfg;
3824
3825            let input_layernorm = cfg.hidden_size;
3826            let post_attention_layernorm = cfg.hidden_size;
3827
3828            let size_in = cfg.hidden_size;
3829            let size_q = cfg.head_dim * cfg.num_attention_heads;
3830            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
3831            let q_proj =
3832                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
3833            let k_proj =
3834                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3835            let v_proj =
3836                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3837            let o_proj =
3838                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
3839
3840            let h_size = cfg.hidden_size;
3841            let i_size = cfg.intermediate_size;
3842            let gate_proj = h_size * i_size / weight_pack_factor;
3843            let up_proj = h_size * i_size / weight_pack_factor;
3844            let down_proj = i_size * h_size / weight_pack_factor;
3845
3846            input_layernorm
3847                + post_attention_layernorm
3848                + q_proj
3849                + k_proj
3850                + v_proj
3851                + o_proj
3852                + gate_proj
3853                + up_proj
3854                + down_proj
3855        };
3856        Ok(vec![
3857            per_layer_elems * dtype.size_in_bytes();
3858            txt_cfg.num_hidden_layers
3859        ])
3860    }
3861
3862    fn num_layers(&self, config: &str) -> Result<usize> {
3863        let cfg: Gemma3Config = serde_json::from_str(config)?;
3864
3865        let txt_cfg = match &cfg {
3866            Gemma3Config::Text(cfg) => cfg,
3867            Gemma3Config::WithVision { text_config, .. } => text_config,
3868        };
3869
3870        Ok(txt_cfg.num_hidden_layers)
3871    }
3872
3873    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3874        let cfg: Gemma3Config = serde_json::from_str(config)?;
3875
3876        let cfg = match &cfg {
3877            Gemma3Config::Text(cfg) => cfg,
3878            Gemma3Config::WithVision { text_config, .. } => text_config,
3879        };
3880
3881        let cfg = ModelConfigMetadata {
3882            max_seq_len: cfg.max_position_embeddings,
3883            num_layers: cfg.num_hidden_layers,
3884            hidden_size: cfg.hidden_size,
3885            num_kv_heads: cfg.num_key_value_heads,
3886            num_attn_heads: cfg.num_attention_heads,
3887            sliding_window: None, // None to be more forgiving; not all layers use sliding-window attention
3888            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3889            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3890        };
3891
3892        Ok(Box::new(cfg))
3893    }
3894
3895    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3896        Some(vec![NonMappedSubModel::Vision])
3897    }
3898}
3899
3900// ======================== Mistral 3 Loader
3901
3902/// [`VisionLoader`] for a Mistral 3 model.
3903///
3904/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
3905pub struct Mistral3Loader;
3906
3907pub struct Mistral3Prefixer;
3908
3909impl MultimodalPromptPrefixer for Mistral3Prefixer {
3910    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
3911        prompt.to_string()
3912    }
3913}
3914
3915impl VisionModelLoader for Mistral3Loader {
3916    fn load(
3917        &self,
3918        config: &str,
3919        vb: ShardedVarBuilder,
3920        normal_loading_metadata: NormalLoadingMetadata,
3921        attention_mechanism: AttentionImplementation,
3922    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3923        let mut cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
3924        cfg.propagate_quantization_config();
3925        Ok(Box::new(Mistral3Model::new(
3926            &cfg,
3927            vb,
3928            self.is_gptx(config),
3929            normal_loading_metadata,
3930            attention_mechanism,
3931        )?))
3932    }
3933    fn is_gptx(&self, _config: &str) -> bool {
3934        true
3935    }
3936    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3937        let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
3938        Ok(Box::new(cfg))
3939    }
3940    fn get_processor(
3941        &self,
3942        _model_config: &str,
3943        processor_config: Option<ProcessorConfig>,
3944        _preprocessor_config: PreProcessorConfig,
3945        _max_edge: Option<u32>,
3946    ) -> Arc<dyn Processor + Send + Sync> {
3947        Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
3948    }
3949    fn supports_paged_attention(&self, _config: &str) -> bool {
3950        true
3951    }
3952    fn supports_prefix_cacher(&self, _config: &str) -> bool {
3953        true
3954    }
3955    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3956        Arc::new(Mistral3Prefixer)
3957    }
3958    fn modalities(&self, _config: &str) -> Result<Modalities> {
3959        Ok(Modalities {
3960            input: vec![SupportedModality::Text, SupportedModality::Vision],
3961            output: vec![SupportedModality::Text],
3962        })
3963    }
3964}
3965
3966impl IsqModelLoader for Mistral3Loader {
3967    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3968        Ok(vec![
3969            Regex::new(r"lm_head\.(weight|bias)$")?,
3970            // Attention
3971            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3972            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3973            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3974            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3975            // MLP
3976            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3977            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3978            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3979        ])
3980    }
3981    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
3982        Ok(vec![
3983            Regex::new(r"lm_head\.(weight|bias)$")?,
3984            // Attention
3985            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3986            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3987            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3988            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3989            // MLP
3990            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3991            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3992            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3993        ])
3994    }
3995}
3996
3997#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
3998impl DeviceMappedModelLoader for Mistral3Loader {
3999    fn mapped_max_act_size_elems(
4000        &self,
4001        config: &str,
4002        params: &AutoDeviceMapParams,
4003    ) -> Result<usize> {
4004        let cfg: Mistral3Config = serde_json::from_str(config)?;
4005        let vcfg = &cfg.vision_config;
4006        let tcfg = &cfg.text_config;
4007
4008        let AutoDeviceMapParams::Vision {
4009            max_seq_len,
4010            max_batch_size,
4011            max_image_shape: (mut height, mut width),
4012            max_num_images,
4013        } = params
4014        else {
4015            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4016        };
4017
4018        let img_seq_len = {
4019            // Reshaping algorithm
4020
4021            // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/preprocessor_config.json#L29
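            // The estimate mirrors the preprocessor: the longer side is scaled down so neither
            // dimension exceeds 1540, both dimensions are rounded up to whole patches, and each
            // row of patches contributes one extra token (presumably an image-break token),
            // giving (num_width_tokens + 1) * num_height_tokens.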
4022            let (max_height, max_width) = (1540, 1540);
4023            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4024            if ratio > 1. {
4025                height = (height as f64 / ratio).floor() as usize;
4026                width = (width as f64 / ratio).floor() as usize;
4027            }
4028
4029            let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
4030            let num_width_tokens = (width - 1) / vcfg.patch_size + 1;
4031
4032            height = num_height_tokens * vcfg.patch_size;
4033            width = num_width_tokens * vcfg.patch_size;
4034
4035            let num_height_tokens = height / vcfg.patch_size;
4036            let num_width_tokens = width / vcfg.patch_size;
4037
4038            (num_width_tokens + 1) * num_height_tokens
4039        };
4040
4041        // This model injects the vision information directly into the input embeddings
4042        let max_seq_len = img_seq_len * max_num_images + *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
4043        Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
4044    }
4045
4046    fn non_mapped_max_act_size_elems(
4047        &self,
4048        config: &str,
4049        params: &AutoDeviceMapParams,
4050    ) -> Result<usize> {
4051        let cfg: Mistral3Config = serde_json::from_str(config)?;
4052        let cfg = &cfg.vision_config;
4053
4054        let AutoDeviceMapParams::Vision {
4055            max_seq_len: _,
4056            max_batch_size,
4057            max_image_shape: (mut height, mut width),
4058            max_num_images,
4059        } = params
4060        else {
4061            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4062        };
4063
4064        let img_seq_len = {
4065            // Reshaping algorithm
4066
4067            // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/preprocessor_config.json#L29
4068            let (max_height, max_width) = (1540, 1540);
4069            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4070            if ratio > 1. {
4071                height = (height as f64 / ratio).floor() as usize;
4072                width = (width as f64 / ratio).floor() as usize;
4073            }
4074
4075            let num_height_tokens = (height - 1) / cfg.patch_size + 1;
4076            let num_width_tokens = (width - 1) / cfg.patch_size + 1;
4077
4078            height = num_height_tokens * cfg.patch_size;
4079            width = num_width_tokens * cfg.patch_size;
4080
4081            let num_height_tokens = height / cfg.patch_size;
4082            let num_width_tokens = width / cfg.patch_size;
4083
4084            (num_width_tokens + 1) * num_height_tokens
4085        };
4086
4087        Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
4088    }
4089
4090    fn non_mapped_size_in_bytes(
4091        &self,
4092        config: &str,
4093        dtype: DType,
4094        weight_pack_factor: usize,
4095        _matformer_config: Option<&MatformerSliceConfig>,
4096    ) -> Result<usize> {
4097        let cfg: Mistral3Config = serde_json::from_str(config)?;
4098
4099        let text_elems = {
4100            let cfg = &cfg.text_config;
4101
4102            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4103            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
4104            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4105                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4106            } else {
4107                0
4108            };
4109            let norm = cfg.hidden_size;
4110            embed_tokens + lm_head + norm
4111        };
4112
4113        let vision_elems = {
4114            let cfg = &cfg.vision_config;
4115
4116            let patch_embed = {
4117                let conv_cfg = Conv2dConfig {
4118                    stride: cfg.patch_size,
4119                    ..Default::default()
4120                };
4121                // Conv2d patch embedding: the kernel covers a patch_size x patch_size window over num_channels inputs
4122                cfg.num_channels * cfg.hidden_size / conv_cfg.groups
4123                    * cfg.patch_size
4124                    * cfg.patch_size
4125            };
4126            let ln_pre = cfg.hidden_size;
4127            let vision_layer = {
4128                let attn_norm = cfg.hidden_size;
4129                let ffn_norm = cfg.hidden_size;
4130
4131                let gate = cfg.hidden_size * cfg.intermediate_size;
4132                let up = cfg.hidden_size * cfg.intermediate_size;
4133                let down = cfg.hidden_size * cfg.intermediate_size;
4134
4135                let q = cfg.hidden_size * cfg.hidden_size;
4136                let k = cfg.hidden_size * cfg.hidden_size;
4137                let v = cfg.hidden_size * cfg.hidden_size;
4138                let o = cfg.hidden_size * cfg.hidden_size;
4139
4140                attn_norm + ffn_norm + gate + up + down + q + k + v + o
4141            };
4142
4143            patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
4144        };
4145
4146        let elems = text_elems + vision_elems;
4147
4148        Ok(elems * dtype.size_in_bytes())
4149    }
4150
4151    fn layer_sizes_in_bytes(
4152        &self,
4153        config: &str,
4154        dtype: DType,
4155        weight_pack_factor: usize,
4156        _matformer_config: Option<&MatformerSliceConfig>,
4157    ) -> Result<Vec<usize>> {
4158        let cfg: Mistral3Config = serde_json::from_str(config)?;
4159        let cfg = &cfg.text_config;
4160
4161        let per_layer_elems = {
4162            let input_layernorm = cfg.hidden_size;
4163            let post_attention_layernorm = cfg.hidden_size;
4164
4165            let size_in = cfg.hidden_size;
4166            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
4167            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
4168            let q_proj = size_in * size_q / weight_pack_factor;
4169            let k_proj = size_in * size_kv / weight_pack_factor;
4170            let v_proj = size_in * size_kv / weight_pack_factor;
4171            let o_proj = size_q * size_in / weight_pack_factor;
4172
4173            let h_size = cfg.hidden_size;
4174            let i_size = cfg.intermediate_size;
4175            let gate_proj = h_size * i_size / weight_pack_factor;
4176            let up_proj = h_size * i_size / weight_pack_factor;
4177            let down_proj = i_size * h_size / weight_pack_factor;
4178
4179            input_layernorm
4180                + post_attention_layernorm
4181                + q_proj
4182                + k_proj
4183                + v_proj
4184                + o_proj
4185                + gate_proj
4186                + up_proj
4187                + down_proj
4188        };
4189        Ok(vec![
4190            per_layer_elems * dtype.size_in_bytes();
4191            cfg.num_hidden_layers
4192        ])
4193    }
4194
4195    fn num_layers(&self, config: &str) -> Result<usize> {
4196        let cfg: Mistral3Config = serde_json::from_str(config)?;
4197        let cfg = &cfg.text_config;
4198        Ok(cfg.num_hidden_layers)
4199    }
4200
4201    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4202        let cfg: Mistral3Config = serde_json::from_str(config)?;
4203        let cfg = &cfg.text_config;
4204
4205        let cfg = ModelConfigMetadata {
4206            max_seq_len: cfg.max_position_embeddings,
4207            num_layers: cfg.num_hidden_layers,
4208            hidden_size: cfg.hidden_size,
4209            num_kv_heads: cfg.num_key_value_heads,
4210            num_attn_heads: cfg.num_attention_heads,
4211            sliding_window: cfg.sliding_window,
4212            k_head_dim: cfg.head_dim(),
4213            v_head_dim: cfg.head_dim(),
4214        };
4215
4216        Ok(Box::new(cfg))
4217    }
4218
4219    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4220        Some(vec![NonMappedSubModel::Vision])
4221    }
4222}
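
// Hedged sketch (illustration only, not used by the loader): mirrors the per-layer
// arithmetic in `layer_sizes_in_bytes` above with made-up values; the hidden size,
// head counts, intermediate size and bf16 width are assumptions, not a real Mistral 3
// config.
#[allow(dead_code)]
fn _example_mistral3_per_layer_bytes() -> usize {
    let (hidden_size, num_attention_heads, num_key_value_heads, intermediate_size) =
        (4096usize, 32usize, 8usize, 14336usize);
    let weight_pack_factor = 1usize; // unquantized
    let dtype_size = 2usize; // e.g. bf16
    let head_dim = hidden_size / num_attention_heads;
    let size_q = head_dim * num_attention_heads;
    let size_kv = head_dim * num_key_value_heads;
    // two RMSNorm weights + attention projections + gate/up/down MLP projections
    let elems = 2 * hidden_size
        + (hidden_size * size_q + 2 * hidden_size * size_kv + size_q * hidden_size)
            / weight_pack_factor
        + 3 * hidden_size * intermediate_size / weight_pack_factor;
    elems * dtype_size
}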
4223
4224// ======================== Llama 4 Loader
4225
4226/// [`VisionLoader`] for a Llama 4 Vision model.
4227///
4228/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
4229pub struct VLlama4Loader;
4230
4231pub struct VLlama4Prefixer;
4232
4233impl MultimodalPromptPrefixer for VLlama4Prefixer {
4234    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
4235        format!(
4236            "{}{prompt}",
4237            llama4::IMAGE_TOKEN.repeat(image_indexes.len())
4238        )
4239    }
4240}
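
// Hedged usage sketch: demonstrates how the prefixer above expands one token per queued
// image; the exact output text depends on the literal value of `llama4::IMAGE_TOKEN`.
#[allow(dead_code)]
fn _example_llama4_prefix_two_images() -> String {
    VLlama4Prefixer.prefix_image(vec![0, 1], "Describe both images.")
}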
4241
4242impl VisionModelLoader for VLlama4Loader {
4243    fn load(
4244        &self,
4245        config: &str,
4246        vb: ShardedVarBuilder,
4247        normal_loading_metadata: NormalLoadingMetadata,
4248        attention_mechanism: AttentionImplementation,
4249    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
4250        let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4251        Ok(Box::new(Llama4Model::new(
4252            &cfg,
4253            vb,
4254            self.is_gptx(config),
4255            normal_loading_metadata,
4256            attention_mechanism,
4257        )?))
4258    }
4259    fn is_gptx(&self, _config: &str) -> bool {
4260        false
4261    }
4262    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4263        let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4264        Ok(Box::new(cfg))
4265    }
4266    fn get_processor(
4267        &self,
4268        _model_config: &str,
4269        processor_config: Option<ProcessorConfig>,
4270        _preprocessor_config: PreProcessorConfig,
4271        _max_edge: Option<u32>,
4272    ) -> Arc<dyn Processor + Send + Sync> {
4273        Arc::new(Llama4Processor::new(&processor_config.unwrap()))
4274    }
4275    fn supports_paged_attention(&self, _config: &str) -> bool {
4276        true
4277    }
4278    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4279        Arc::new(VLlama4Prefixer)
4280    }
4281    fn modalities(&self, _config: &str) -> Result<Modalities> {
4282        Ok(Modalities {
4283            input: vec![SupportedModality::Text, SupportedModality::Vision],
4284            output: vec![SupportedModality::Text],
4285        })
4286    }
4287}
4288
4289impl IsqModelLoader for VLlama4Loader {
4290    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4291        Ok(vec![
4292            Regex::new(r"lm_head\.(weight|bias)$")?,
4293            // Attention
4294            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4295            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4296            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4297            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4298            // FF MoE
4299            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$")?,
4300            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$")?,
4301            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$")?,
4302            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$")?,
4303            Regex::new(r"layers\.(\d+)\.feed_forward\.router\.(weight|bias)$")?,
4304            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.gate_proj\.(weight|bias)$")?,
4305            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.up_proj\.(weight|bias)$")?,
4306            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.down_proj\.(weight|bias)$")?,
4307            // FF MLP
4308            Regex::new(r"layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$")?,
4309            Regex::new(r"layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$")?,
4310            Regex::new(r"layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$")?,
4311        ])
4312    }
4313    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4314        Ok(vec![
4315            Regex::new(r"lm_head\.(weight|bias)$")?,
4316            // Attention
4317            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4318            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4319            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4320            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4321            // FF MoE
4322            Regex::new(
4323                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_up_proj\.(weight|bias)$",
4324            )?,
4325            Regex::new(
4326                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
4327            )?,
4328            Regex::new(
4329                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.up_proj\.(weight|bias)$",
4330            )?,
4331            Regex::new(
4332                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.down_proj\.(weight|bias)$",
4333            )?,
4334            Regex::new(
4335                r"language_model\.model\.layers\.(\d+)\.feed_forward\.router\.(weight|bias)$",
4336            )?,
4337            Regex::new(
4338                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.gate_proj\.(weight|bias)$",
4339            )?,
4340            Regex::new(
4341                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.up_proj\.(weight|bias)$",
4342            )?,
4343            Regex::new(
4344                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.down_proj\.(weight|bias)$",
4345            )?,
4346            // FF MLP
4347            Regex::new(
4348                r"language_model\.model\.layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$",
4349            )?,
4350            Regex::new(
4351                r"language_model\.model\.layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$",
4352            )?,
4353            Regex::new(
4354                r"language_model\.model\.layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$",
4355            )?,
4356        ])
4357    }
4358}
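
// Hedged sketch: the kind of tensor name the attention regexes above are meant to match;
// the example name is illustrative, not read from a real Llama 4 checkpoint.
#[allow(dead_code)]
fn _example_llama4_isq_regex_match() -> bool {
    let re = Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$").unwrap();
    re.is_match("language_model.model.layers.12.self_attn.q_proj.weight")
}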
4359
4360impl VLlama4Loader {
4361    /// This incorporates the max batch size!
4362    /// Returns (pixels max batch size, num text image tokens)
4363    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
4364    fn run_dummy_processing(
4365        &self,
4366        cfg: &Llama4Config,
4367        height: usize,
4368        width: usize,
4369        max_num_images: usize,
4370        max_batch_size: usize,
4371    ) -> Result<(usize, usize)> {
4372        let cfg = &cfg.vision_config;
4373
4374        let img_processor =
4375            Llama4ImageProcessor::new(Some(cfg.patch_size), Some(cfg.pixel_shuffle_ratio));
4376        let image = DynamicImage::new(width as u32, height as u32, ColorType::Rgb8);
4377        let res = img_processor.preprocess(
4378            vec![image; max_num_images],
4379            vec![],
4380            &PreProcessorConfig::default(),
4381            &Device::Cpu,
4382            (max_batch_size, max_num_images),
4383        )?;
4384
4385        let pixels_batch_size = res.pixel_values.dim(0)?;
4386        let pixels_max_batch_size = pixels_batch_size * max_batch_size;
4387
4388        let (image_h, image_w) = (
4389            res.pixel_values.dim(D::Minus2).unwrap(),
4390            res.pixel_values.dim(D::Minus1).unwrap(),
4391        );
4392        let num_patches_per_chunk = (image_h / img_processor.patch_size)
4393            * (image_w / img_processor.patch_size)
4394            / img_processor.downsample_ratio;
4395
4396        Ok((
4397            pixels_max_batch_size,
4398            num_patches_per_chunk * pixels_max_batch_size,
4399        ))
4400    }
4401}
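
// Hedged sketch: the mapped activation bound in `mapped_max_act_size_elems` below is
// batch * heads * seq_len^2, where seq_len is the chunk-capped prompt length plus the
// image tokens returned by `run_dummy_processing`. All numbers here are assumptions.
#[allow(dead_code)]
fn _example_llama4_mapped_act_elems() -> usize {
    let max_batch_size = 2usize;
    let num_attention_heads = 40usize; // assumed text-config value
    let chunked_prompt_len = 1024usize; // max_seq_len already capped at the attention chunk size
    let num_text_image_toks = 2304usize; // assumed result of the dummy image processing
    let seq_len = chunked_prompt_len + num_text_image_toks;
    max_batch_size * num_attention_heads * seq_len * seq_len
}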
4402
4403impl DeviceMappedModelLoader for VLlama4Loader {
4404    fn mapped_max_act_size_elems(
4405        &self,
4406        config: &str,
4407        params: &AutoDeviceMapParams,
4408    ) -> Result<usize> {
4409        let AutoDeviceMapParams::Vision {
4410            max_seq_len,
4411            max_batch_size,
4412            max_image_shape: (height, width),
4413            max_num_images,
4414        } = params
4415        else {
4416            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4417        };
4418
4419        let cfg: Llama4Config = serde_json::from_str(config)?;
4420
4421        let (_pixels_batch_size, num_text_image_toks) =
4422            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4423
4424        let max_seq_len = max_seq_len.min(&ATTENTION_CHUNK_SIZE) + num_text_image_toks;
4425
4426        Ok(max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len)
4427    }
4428    fn non_mapped_max_act_size_elems(
4429        &self,
4430        config: &str,
4431        params: &AutoDeviceMapParams,
4432    ) -> Result<usize> {
4433        let AutoDeviceMapParams::Vision {
4434            max_seq_len: _,
4435            max_batch_size,
4436            max_image_shape: (height, width),
4437            max_num_images,
4438        } = params
4439        else {
4440            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4441        };
4442
4443        let cfg: Llama4Config = serde_json::from_str(config)?;
4444
4445        let (pixels_batch_size, _num_text_image_toks) =
4446            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4447        let max_seq_len = cfg.vision_config.num_patches();
4448
4449        Ok((max_batch_size * pixels_batch_size)
4450            * cfg.vision_config.num_attention_heads
4451            * max_seq_len
4452            * max_seq_len)
4453    }
4454
4455    fn non_mapped_size_in_bytes(
4456        &self,
4457        config: &str,
4458        dtype: DType,
4459        weight_pack_factor: usize,
4460        _matformer_config: Option<&MatformerSliceConfig>,
4461    ) -> Result<usize> {
4462        let cfg: Llama4Config = serde_json::from_str(config)?;
4463        let tcfg = &cfg.text_config;
4464
4465        let text_elems = {
4466            let embed_tokens = tcfg.hidden_size * tcfg.vocab_size / weight_pack_factor;
4467            let lm_head = if !tcfg.tie_word_embeddings {
4468                tcfg.hidden_size * tcfg.vocab_size
4469            } else {
4470                0
4471            };
4472            let norm = tcfg.hidden_size;
4473            embed_tokens + lm_head + norm
4474        };
4475
4476        let vision_elems = {
4477            let cfg = &cfg.vision_config;
4478
4479            let num_patches = cfg.num_patches();
4480
4481            let unfold_elems =
4482                (cfg.num_channels * cfg.patch_size * cfg.patch_size) * cfg.hidden_size;
4483            let class_embedding_elems = cfg.hidden_size;
4484            let positional_embedding_vlm_elems = num_patches * cfg.hidden_size;
4485            let layernorm_pre_elems = cfg.hidden_size;
4486            let layernorm_post_elems = cfg.hidden_size;
4487
4488            let pixel_shuffle_elems = cfg.intermediate_size * cfg.projector_input_dim
4489                / weight_pack_factor
4490                + cfg.projector_input_dim * cfg.projector_output_dim / weight_pack_factor;
4491
4492            let encoder_layer = {
4493                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
4494                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
4495
4496                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
4497                let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4498                    / weight_pack_factor
4499                    + cfg.num_attention_heads * head_dim;
4500                let k_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4501                    / weight_pack_factor
4502                    + cfg.num_attention_heads * head_dim;
4503                let v_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4504                    / weight_pack_factor
4505                    + cfg.num_attention_heads * head_dim;
4506                let o_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4507                    / weight_pack_factor
4508                    + cfg.num_attention_heads * head_dim;
4509
4510                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
4511                    + cfg.intermediate_size;
4512                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
4513                    + cfg.hidden_size;
4514
4515                input_layernorm
4516                    + post_attention_layernorm
4517                    + q_proj
4518                    + k_proj
4519                    + v_proj
4520                    + o_proj
4521                    + fc1
4522                    + fc2
4523            };
4524
4525            unfold_elems
4526                + class_embedding_elems
4527                + positional_embedding_vlm_elems
4528                + layernorm_post_elems
4529                + layernorm_pre_elems
4530                + pixel_shuffle_elems
4531                + encoder_layer * cfg.num_hidden_layers
4532        };
4533
4534        let elems = text_elems + vision_elems;
4535
4536        Ok(elems * dtype.size_in_bytes())
4537    }
4538
4539    fn layer_sizes_in_bytes(
4540        &self,
4541        config: &str,
4542        dtype: DType,
4543        weight_pack_factor: usize,
4544        _matformer_config: Option<&MatformerSliceConfig>,
4545    ) -> Result<Vec<usize>> {
4546        let cfg: Llama4Config = serde_json::from_str(config)?;
4547        let tcfg = &cfg.text_config;
4548
4549        let mut per_layer_elems = Vec::new();
4550
4551        for layer_idx in 0..tcfg.num_hidden_layers {
4552            let input_layernorm = tcfg.hidden_size;
4553            let post_attention_layernorm = tcfg.hidden_size;
4554
4555            let size_in = tcfg.hidden_size;
4556            let size_q = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_attention_heads;
4557            let size_kv = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_key_value_heads;
4558            let q_proj = size_in * size_q / weight_pack_factor;
4559            let k_proj = size_in * size_kv / weight_pack_factor;
4560            let v_proj = size_in * size_kv / weight_pack_factor;
4561            let o_proj = size_q * size_in / weight_pack_factor;
4562
4563            let use_moe = tcfg.moe_layers().contains(&layer_idx);
4564            let moe_block = if use_moe {
4565                let h_size = tcfg.hidden_size;
4566                let i_size = tcfg.intermediate_size;
4567                let gate_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4568                let up_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4569                let down_proj = tcfg.num_local_experts * i_size * h_size / weight_pack_factor;
4570
4571                gate_proj + up_proj + down_proj
4572            } else {
4573                let h_size = tcfg.hidden_size;
4574                let i_size = tcfg.intermediate_size_mlp;
4575                let gate_proj = h_size * i_size / weight_pack_factor;
4576                let up_proj = h_size * i_size / weight_pack_factor;
4577                let down_proj = i_size * h_size / weight_pack_factor;
4578
4579                gate_proj + up_proj + down_proj
4580            };
4581
4582            per_layer_elems.push(
4583                input_layernorm
4584                    + post_attention_layernorm
4585                    + q_proj
4586                    + k_proj
4587                    + v_proj
4588                    + o_proj
4589                    + moe_block,
4590            );
4591        }
4592
4593        Ok(per_layer_elems
4594            .into_iter()
4595            .map(|x| x * dtype.size_in_bytes())
4596            .collect())
4597    }
4598
4599    fn num_layers(&self, config: &str) -> Result<usize> {
4600        let cfg: Llama4Config = serde_json::from_str(config)?;
4601        Ok(cfg.text_config.num_hidden_layers)
4602    }
4603
4604    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4605        let cfg: Llama4Config = serde_json::from_str(config)?;
4606        let cfg = &cfg.text_config;
4607
4608        let cfg = ModelConfigMetadata {
4609            max_seq_len: cfg.max_position_embeddings,
4610            num_layers: cfg.num_hidden_layers,
4611            hidden_size: cfg.hidden_size,
4612            num_kv_heads: cfg.num_attention_heads,
4613            num_attn_heads: cfg.num_attention_heads,
4614            sliding_window: None,
4615            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4616            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4617        };
4618
4619        Ok(Box::new(cfg))
4620    }
4621
4622    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4623        Some(vec![NonMappedSubModel::Vision])
4624    }
4625}
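
// Hedged sketch: standalone version of the MoE branch above, with made-up values
// (hidden size, intermediate size and expert count are assumptions, not real Llama 4
// text-config values).
#[allow(dead_code)]
fn _example_llama4_moe_layer_elems() -> usize {
    let (hidden_size, intermediate_size, num_local_experts, weight_pack_factor) =
        (5120usize, 8192usize, 16usize, 1usize);
    // gate_proj + up_proj + down_proj, replicated across the routed experts
    3 * num_local_experts * hidden_size * intermediate_size / weight_pack_factor
}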
4626
4627// ======================== Gemma 3n Loader
4628
4629/// [`VisionLoader`] for a Gemma 3n model.
4630///
4631/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
4632pub struct Gemma3nLoader;
4633
4634#[allow(dead_code)]
4635pub struct Gemma3nPrefixer;
4636
4637impl MultimodalPromptPrefixer for Gemma3nPrefixer {
4638    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
4639        prompt.to_string()
4640    }
4641}
4642
4643impl VisionModelLoader for Gemma3nLoader {
4644    fn load(
4645        &self,
4646        config: &str,
4647        vb: ShardedVarBuilder,
4648        normal_loading_metadata: NormalLoadingMetadata,
4649        attention_mechanism: AttentionImplementation,
4650    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
4651        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4652        Ok(Box::new(Gemma3nModel::new(
4653            &cfg,
4654            vb,
4655            self.is_gptx(config),
4656            normal_loading_metadata,
4657            attention_mechanism,
4658        )?))
4659    }
4660    fn is_gptx(&self, _config: &str) -> bool {
4661        true
4662    }
4663    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4664        let config: Gemma3nConfig = serde_json::from_str(config)?;
4665        Ok(Box::new(config))
4666    }
4667    fn get_processor(
4668        &self,
4669        _config: &str,
4670        processor_config: Option<ProcessorConfig>,
4671        _preprocessor_config: PreProcessorConfig,
4672        _max_edge: Option<u32>,
4673    ) -> Arc<dyn Processor + Send + Sync> {
4674        // Handle the Gemma 3 1b case here
4675        Arc::new(Gemma3nProcessor::new(
4676            processor_config.unwrap_or_default(),
4677            true,
4678        ))
4679    }
4680    fn supports_paged_attention(&self, _config: &str) -> bool {
4681        false
4682    }
4683    fn supports_prefix_cacher(&self, _config: &str) -> bool {
4684        true
4685    }
4686    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4687        Arc::new(Gemma3Prefixer)
4688    }
4689    fn modalities(&self, _config: &str) -> Result<Modalities> {
4690        Ok(Modalities {
4691            input: vec![
4692                SupportedModality::Text,
4693                SupportedModality::Vision,
4694                SupportedModality::Audio,
4695            ],
4696            output: vec![SupportedModality::Text],
4697        })
4698    }
4699}
4700
4701impl IsqModelLoader for Gemma3nLoader {
4702    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4703        Ok(vec![
4704            Regex::new(r"lm_head\.(weight|bias)$")?,
4705            // Language model attention
4706            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4707            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4708            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4709            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4710            // Language model MLP
4711            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4712            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4713            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4714            // Audio conformer attention layers
4715            Regex::new(r"conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$")?,
4716            Regex::new(r"conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$")?,
4717            Regex::new(r"conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$")?,
4718            Regex::new(
4719                r"conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4720            )?,
4721            Regex::new(r"conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4722            // Audio conformer FFW layers
4723            Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$")?,
4724            Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$")?,
4725            Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$")?,
4726            Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$")?,
4727            // Audio conformer conv1d layers
4728            Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$")?,
4729            Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$")?,
4730            // Audio subsample projection
4731            Regex::new(r"subsample_conv_projection\.input_proj_linear\.(weight|bias)$")?,
4732            // Multimodal embedders
4733            Regex::new(r"embed_vision\.embedding_projection\.(weight|bias)$")?,
4734            Regex::new(r"embed_audio\.embedding_projection\.(weight|bias)$")?,
4735        ])
4736    }
4737    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4738        Ok(vec![
4739            Regex::new(r"lm_head\.(weight|bias)$")?,
4740            // Language model attention
4741            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4742            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4743            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4744            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4745            // Language model MLP
4746            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4747            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4748            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4749            // Projections
4750            Regex::new(r"model\.language_model\.per_layer_model_projection\.(weight|bias)$")?,
4751            Regex::new(r"model\.language_model\.altup_projections\.(\d+)\.(weight|bias)$")?,
4752            Regex::new(r"model\.language_model\.altup_unembed_projections\.(\d+)\.(weight|bias)$")?,
4753            // Audio conformer attention layers
4754            Regex::new(
4755                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$",
4756            )?,
4757            Regex::new(
4758                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$",
4759            )?,
4760            Regex::new(
4761                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$",
4762            )?,
4763            Regex::new(
4764                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4765            )?,
4766            Regex::new(r"model\.audio_tower\.conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4767            // Audio conformer FFW layers
4768            Regex::new(
4769                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$",
4770            )?,
4771            Regex::new(
4772                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$",
4773            )?,
4774            Regex::new(
4775                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$",
4776            )?,
4777            Regex::new(
4778                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$",
4779            )?,
4780            // Audio conformer conv1d layers
4781            Regex::new(
4782                r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$",
4783            )?,
4784            Regex::new(
4785                r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$",
4786            )?,
4787            // Audio subsample projection
4788            Regex::new(
4789                r"model\.audio_tower\.subsample_conv_projection\.input_proj_linear\.(weight|bias)$",
4790            )?,
4791            // Multimodal embedders
4792            Regex::new(r"model\.embed_vision\.embedding_projection\.(weight|bias)$")?,
4793            Regex::new(r"model\.embed_audio\.embedding_projection\.(weight|bias)$")?,
4794        ])
4795    }
4796}
4797
4798impl DeviceMappedModelLoader for Gemma3nLoader {
4799    fn mapped_max_act_size_elems(
4800        &self,
4801        config: &str,
4802        params: &AutoDeviceMapParams,
4803    ) -> Result<usize> {
4804        let AutoDeviceMapParams::Vision {
4805            max_seq_len,
4806            max_batch_size,
4807            max_image_shape: _,
4808            max_num_images,
4809        } = params
4810        else {
4811            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4812        };
4813
4814        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4815        let text_cfg = &cfg.text_config;
4816
4817        // Gemma3n is an "inject into the prompt" model, similar to Gemma3
4818        // We need to account for vision and audio tokens in the sequence length
4819
4820        let mut total_seq_len = *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
4821
4822        // Add vision tokens
4823        {
4824            // Vision tokens are injected into the prompt
4825            // MSFA outputs fixed 16x16 features regardless of input size
4826            let msfa_spatial_size = 16; // Fixed from vision.rs line 1115
4827            let vision_tokens_per_image = msfa_spatial_size * msfa_spatial_size; // 256 tokens
4828            total_seq_len += vision_tokens_per_image * max_num_images;
4829        }
4830
4831        // Add audio tokens
4832        {
4833            // Audio tokens are injected into the prompt
4834            // From config field audio_soft_tokens_per_image (typically 188)
4835            let audio_tokens = cfg.audio_soft_tokens_per_image;
4836            total_seq_len += audio_tokens;
4837        }
4838
4839        // Calculate max attention size for text model with all injected tokens
4840        let max_text_attn =
4841            max_batch_size * text_cfg.num_attention_heads * total_seq_len * total_seq_len;
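        // Illustrative numbers (assumptions, not from a real config): with the prompt
        // capped at 1024 tokens, one image (16 * 16 = 256 injected vision tokens) and
        // 188 injected audio tokens, total_seq_len = 1024 + 256 + 188 = 1468, so a
        // batch of 1 with 8 attention heads bounds the scores at
        // 1 * 8 * 1468 * 1468 ≈ 17.2M elements.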
4842
4843        Ok(max_text_attn)
4844    }
4845
4846    fn non_mapped_max_act_size_elems(
4847        &self,
4848        config: &str,
4849        params: &AutoDeviceMapParams,
4850    ) -> Result<usize> {
4851        let AutoDeviceMapParams::Vision {
4852            max_seq_len: _,
4853            max_batch_size,
4854            max_image_shape: _,
4855            max_num_images,
4856        } = params
4857        else {
4858            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4859        };
4860
4861        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4862
4863        // Calculate max activation sizes for each modality
4864        let mut max_activation = 0;
4865
4866        // Vision activation size
4867        {
4868            // The vision tower is Gemma3n's MobileNetV5 architecture with Multi-Query Attention
4869            // The peak activation is in the Multi-Query Attention layers
4870
4871            // From the architecture: stages 3 and 4 have MMQA blocks
4872            // Input images are 768x768 (from inputs_processor.rs)
4873            // Stage 3: 640 channels at 48x48 (768/16 downsampling), MMQA with num_heads=12, kv_dim=64
4874            // Stage 4: 1280 channels at 24x24 (768/32 downsampling), MMQA with num_heads=16, kv_dim=96
4875            // MSFA output: 2048 channels at fixed 16x16
4876
4877            let vision_tower_act = {
4878                // Peak is during MMQA attention computation in stage 4
4879                // Stage 4 has higher memory usage than Stage 3 due to more heads (16 vs 12)
4880                // From vision.rs: Stage 4 has num_heads=16, kv_dim=96, kv_stride=1
4881                let num_heads = 16; // Stage 4 configuration
4882                let spatial_size = 24; // 768 / 32 = 24 (input 768x768, stage 4 has 32x downsampling)
4883                let seq_len = spatial_size * spatial_size;
4884
4885                // Attention scores: [B * num_images, num_heads, seq_len, seq_len]
4886                max_batch_size * max_num_images * num_heads * seq_len * seq_len
4887            };
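            // Illustrative check of the bound above (assumed single 768x768 image,
            // batch size 1): seq_len = 24 * 24 = 576, so the stage-4 scores take
            // 1 * 1 * 16 * 576 * 576 ≈ 5.3M elements.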
4888
4889            // Vision embedder activations
4890            let vision_embed_act = {
4891                // MSFA output: 2048 channels at fixed 16x16 spatial (from vision.rs line 1115)
4892                let msfa_channels = 2048; // MSFA_OUT_CHANNELS from vision.rs
4893                let spatial_size = 16; // Fixed output resolution from MSFA
4894                let vision_features =
4895                    max_batch_size * max_num_images * msfa_channels * spatial_size * spatial_size;
4896
4897                // After embedding projection to text hidden size
4898                let projected = max_batch_size
4899                    * max_num_images
4900                    * spatial_size
4901                    * spatial_size
4902                    * cfg.text_config.hidden_size;
4903
4904                vision_features.max(projected)
4905            };
4906
4907            max_activation = max_activation.max(vision_tower_act).max(vision_embed_act);
4908        }
4909
4910        // Audio activation size
4911        {
4912            let audio_cfg = &cfg.audio_config;
4913
4914            // Calculate max audio sequence length based on config
4915            // Audio uses conformer with subsampling and reduction
4916
4917            // A rough estimate of max_audio_frames
4918            let max_audio_frames = 1280;
4919
4920            let subsample_factor: usize = audio_cfg
4921                .sscp_conv_stride_size
4922                .iter()
4923                .map(|stride| stride[0]) // Time dimension stride
4924                .product();
4925            let audio_seq_after_subsample = max_audio_frames / subsample_factor;
4926
4927            // Audio encoder activations
4928            let audio_encoder_act = {
4929                // Conformer FFW layers have expansion factor from config
4930                let intermediate_size = audio_cfg.hidden_size * 4; // FFW expansion factor
4931
4932                // Peak is in the FFW layers before reduction
4933                max_batch_size * audio_seq_after_subsample * intermediate_size
4934            };
4935
4936            // Audio attention activations
4937            let audio_attn_act = {
4938                // Attention uses chunked processing with specific context sizes
4939                let chunk_size = audio_cfg.conf_attention_chunk_size;
4940                let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
4941                    + audio_cfg.conf_attention_context_right;
4942
4943                // Peak is attention scores: [B, num_heads, num_chunks, chunk_size, context_size]
4944                let num_chunks = audio_seq_after_subsample.div_ceil(chunk_size);
4945
4946                max_batch_size
4947                    * audio_cfg.conf_num_attention_heads
4948                    * num_chunks
4949                    * chunk_size
4950                    * context_size
4951            };
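            // Illustrative bound (assumed values, not necessarily the shipped config):
            // chunk_size = 12, left context = 13, right context = 0 gives
            // context_size = 12 + 13 - 1 + 0 = 24; with 160 post-subsample frames,
            // num_chunks = ceil(160 / 12) = 14, so the scores take
            // batch * heads * 14 * 12 * 24 elements.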
4952
4953            max_activation = max_activation.max(audio_encoder_act).max(audio_attn_act);
4954        }
4955
4956        Ok(max_activation)
4957    }
4958
4959    fn non_mapped_size_in_bytes(
4960        &self,
4961        config: &str,
4962        dtype: DType,
4963        weight_pack_factor: usize,
4964        matformer_config: Option<&MatformerSliceConfig>,
4965    ) -> Result<usize> {
4966        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4967
4968        // Apply matformer slicing if configured
4969        let text_cfg = if let Some(matformer_cfg) = matformer_config {
4970            use crate::device_map::DummyDeviceMapper;
4971            use crate::vision_models::gemma3n::text::handle_matformer_slicing;
4972
4973            let dummy_mapper = DummyDeviceMapper {
4974                nm_device: Device::Cpu,
4975            };
4976            let (adjusted_cfg, _, _, _, _) = handle_matformer_slicing(
4977                &cfg.text_config,
4978                &Some(matformer_cfg.clone()),
4979                &dummy_mapper,
4980            )?;
4981            adjusted_cfg
4982        } else {
4983            cfg.text_config.clone()
4984        };
4985
4986        let text_cfg = &text_cfg;
4987
4988        // Text components that are not device-mapped
4989        let text_elems = {
4990            // Embeddings
4991            let embed_tokens = text_cfg.hidden_size * text_cfg.vocab_size;
4992            let embed_tokens_per_layer = text_cfg.num_hidden_layers
4993                * text_cfg.hidden_size_per_layer_input
4994                * text_cfg.vocab_size_per_layer_input;
4995
4996            // LM head (if not tied)
4997            let lm_head = if !text_cfg.tie_word_embeddings || weight_pack_factor != 1 {
4998                text_cfg.hidden_size * text_cfg.vocab_size / weight_pack_factor
4999            } else {
5000                0
5001            };
5002
5003            // Final layer norm
5004            let norm = text_cfg.hidden_size;
5005
5006            // AltUp projections (not device-mapped)
5007            let altup_projections =
5008                (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
5009                    / weight_pack_factor;
5010            let altup_unembed_projections =
5011                (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
5012                    / weight_pack_factor;
5013
5014            // Per-layer model projection
5015            let per_layer_model_projection = text_cfg.num_hidden_layers
5016                * text_cfg.hidden_size
5017                * text_cfg.hidden_size_per_layer_input
5018                / weight_pack_factor;
5019            let per_layer_projection_norm = text_cfg.hidden_size;
5020
5021            embed_tokens
5022                + embed_tokens_per_layer
5023                + lm_head
5024                + norm
5025                + altup_projections
5026                + altup_unembed_projections
5027                + per_layer_model_projection
5028                + per_layer_projection_norm
5029        };
5030
5031        // Vision components
5032        let vision_elems = {
5033            let vision_cfg = &cfg.vision_config;
5034            // Vision tower - calculated from actual Gemma3n architecture
5035            // NOTE: Vision tower uses only Conv2d layers, NOT Arc<dyn QuantMethod>,
5036            // so NONE of these should be divided by weight_pack_factor
5037            let vision_tower_elems = {
5038                use crate::vision_models::gemma3n::vision::{
5039                    gemma3n_mobilenet_def, make_divisible, BlockType, INPUT_CHANNELS,
5040                    MSFA_EXPANSION_RATIO, MSFA_IN_CHANNELS, MSFA_OUT_CHANNELS, STEM_KERNEL_SIZE,
5041                    STEM_OUT_CHANNELS,
5042                };
5043
5044                // Stem: ConvNormAct (Conv2d + RMSNorm)
5045                let stem_conv =
5046                    INPUT_CHANNELS * STEM_OUT_CHANNELS * STEM_KERNEL_SIZE * STEM_KERNEL_SIZE;
5047                let stem_norm = STEM_OUT_CHANNELS; // RMSNorm weight
5048
5049                // Track input channels through the network
5050                let mut in_chs = STEM_OUT_CHANNELS;
5051                let mut total_elems = stem_conv + stem_norm;
5052
5053                // Process all stages from gemma3n_mobilenet_def
5054                let block_defs = gemma3n_mobilenet_def();
5055
5056                for stage_blocks in block_defs.iter() {
5057                    for block_type in stage_blocks.iter() {
5058                        match block_type {
5059                            BlockType::EdgeResidual {
5060                                out_channels,
5061                                kernel_size,
5062                                stride: _,
5063                                expand_ratio,
5064                                ..
5065                            } => {
5066                                #[allow(clippy::cast_precision_loss)]
5067                                let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
5068                                // EdgeResidual: all Conv2d layers, not quantizable
5069                                total_elems += in_chs * mid_chs * kernel_size * kernel_size; // conv_exp (Conv2d)
5070                                total_elems += mid_chs; // bn1 weight
5071                                total_elems += mid_chs * out_channels; // conv_pwl (Conv2d)
5072                                total_elems += out_channels; // bn2 weight
5073                                in_chs = *out_channels;
5074                            }
5075                            BlockType::UniversalInvertedResidual {
5076                                out_channels,
5077                                start_kernel_size,
5078                                mid_kernel_size,
5079                                stride: _,
5080                                expand_ratio,
5081                                ..
5082                            } => {
5083                                #[allow(clippy::cast_precision_loss)]
5084                                let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
5085                                // UniversalInvertedResidual: all Conv2d layers, not quantizable
5086                                if *expand_ratio != 1.0 {
5087                                    total_elems += in_chs * mid_chs; // expand conv (Conv2d)
5088                                    total_elems += mid_chs; // expand norm
5089                                }
5090                                if *start_kernel_size > 0 {
5091                                    total_elems += mid_chs * start_kernel_size * start_kernel_size; // depthwise start (Conv2d)
5092                                    total_elems += mid_chs; // norm
5093                                }
5094                                if *mid_kernel_size > 0 {
5095                                    total_elems += mid_chs * mid_kernel_size * mid_kernel_size; // depthwise mid (Conv2d)
5096                                    total_elems += mid_chs; // norm
5097                                }
5098                                total_elems += mid_chs * out_channels; // project conv (Conv2d)
5099                                total_elems += out_channels; // project norm
5100                                total_elems += out_channels; // layer scale gamma
5101                                in_chs = *out_channels;
5102                            }
5103                            BlockType::MultiQueryAttention {
5104                                num_heads,
5105                                kv_dim,
5106                                kv_stride: _,
5107                                ..
5108                            } => {
5109                                // MMQA: all Conv2d layers, not quantizable
5110                                let dw_kernel_size = 3; // Default dw_kernel_size for MMQA
5111                                total_elems += in_chs; // norm weight
5112                                total_elems += in_chs * num_heads * kv_dim; // query_proj (Conv2d)
5113                                total_elems += in_chs * kv_dim; // key_proj (Conv2d)
5114                                total_elems += in_chs * dw_kernel_size * dw_kernel_size; // key_dw_conv (Conv2d)
5115                                total_elems += *kv_dim; // value_down_conv (Conv2d)
5116                                total_elems += 1; // value_norm weight
5117                                total_elems += *kv_dim; // value_proj (Conv2d)
5118                                total_elems += num_heads * kv_dim * in_chs; // output_proj (Conv2d)
5119                                total_elems += in_chs; // layer scale
5120                            }
5121                        }
5122                    }
5123                }
5124
5125                // Multi-scale fusion adapter (msfa) - also uses Conv2d layers
5126                let msfa_in = MSFA_IN_CHANNELS.iter().sum::<usize>();
5127                let msfa_out = MSFA_OUT_CHANNELS;
5128                #[allow(clippy::cast_precision_loss)]
5129                let msfa_mid = make_divisible(msfa_in as f64 * MSFA_EXPANSION_RATIO, 8);
5130
5131                // MSFA FFN (UIR with expansion_ratio) - Conv2d layers, not quantizable
5132                total_elems += msfa_in * msfa_mid; // expand (Conv2d)
5133                total_elems += msfa_mid; // expand norm
5134                total_elems += msfa_mid * msfa_out; // project (Conv2d)
5135                total_elems += msfa_out; // project norm
5136                total_elems += msfa_out; // final norm
5137
5138                total_elems
5139            };
5140
5141            // Vision multimodal embedder components
5142            let embed_vision_elems = {
5143                // Embedding layer (not quantizable)
5144                let embedding = vision_cfg.vocab_size * vision_cfg.hidden_size;
5145
5146                // Normalization layers (not quantizable)
5147                let hard_norm = vision_cfg.hidden_size;
5148                let soft_norm = vision_cfg.hidden_size;
5149
5150                // Projection from vision to text hidden size (IS Arc<dyn QuantMethod>, so quantizable)
5151                let projection = vision_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5152
5153                // Post-projection norm (not quantizable)
5154                let post_norm = text_cfg.hidden_size;
5155
5156                embedding + hard_norm + soft_norm + projection + post_norm
5157            };
5158
5159            vision_tower_elems + embed_vision_elems
5160        };
5161
5162        // Audio components - based on actual audio.rs structure
5163        let audio_elems = {
5164            let audio_cfg = &cfg.audio_config;
5165
5166            // SubSampleConvProjection components
5167            let subsample_conv_projection_elems = {
5168                // Conv blocks (Conv2d layers - NOT quantizable)
5169                let mut conv_elems = 0;
5170
5171                // conv_0: Conv2d from 1 channel to first channel size
5172                let in_ch_0 = 1;
5173                let out_ch_0 = audio_cfg.sscp_conv_channel_size[0];
5174                let kernel_0 = &audio_cfg.sscp_conv_kernel_size[0];
5175                conv_elems += in_ch_0 * out_ch_0 * kernel_0[0] * kernel_0[1];
5176
5177                // conv_1: Conv2d from first to second channel size
5178                let in_ch_1 = out_ch_0;
5179                let out_ch_1 = audio_cfg.sscp_conv_channel_size[1];
5180                let kernel_1 = &audio_cfg.sscp_conv_kernel_size[1];
5181                conv_elems += in_ch_1 * out_ch_1 * kernel_1[0] * kernel_1[1];
5182
5183                // CumulativeGroupNorm for each conv block (weight only, no bias by default)
5184                let norm_0 = out_ch_0; // norm weight for conv_0
5185                let norm_1 = out_ch_1; // norm weight for conv_1
5186
5187                // input_proj_linear (Arc<dyn QuantMethod> - IS quantizable)
5188                let mut f_out = audio_cfg.input_feat_size;
5189                for i in 0..2 {
5190                    let kernel_w = audio_cfg.sscp_conv_kernel_size[i][1];
5191                    let stride_w = audio_cfg.sscp_conv_stride_size[i][1];
5192                    let pad_left = 1;
5193                    let pad_right = 1;
5194                    f_out = (f_out + pad_left + pad_right + stride_w - kernel_w) / stride_w;
5195                }
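                // Illustrative trace (assumed values, not a real Gemma 3n config): with
                // input_feat_size = 128, kernel_w = 3 and stride_w = 2 in both conv blocks,
                // f_out goes 128 -> (128 + 1 + 1 + 2 - 3) / 2 = 64 -> (64 + 1 + 1 + 2 - 3) / 2 = 32.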
5196                let input_proj_in_features = out_ch_1 * f_out;
5197                let input_proj_linear =
5198                    input_proj_in_features * audio_cfg.hidden_size / weight_pack_factor;
5199
5200                conv_elems + norm_0 + norm_1 + input_proj_linear
5201            };
5202
5203            // Conformer blocks
5204            let conformer_elems = {
5205                let mut total = 0;
5206
5207                for _ in 0..audio_cfg.conf_num_hidden_layers {
5208                    // ConformerAttention
5209                    let attention_elems = {
5210                        // Norms (NOT quantizable)
5211                        let pre_attn_norm = audio_cfg.hidden_size;
5212                        let post_norm = audio_cfg.hidden_size;
5213
5214                        // Attention projections (Arc<dyn QuantMethod> - IS quantizable)
5215                        let q_proj =
5216                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5217                        let k_proj =
5218                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5219                        let v_proj =
5220                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5221                        let post =
5222                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5223
5224                        // RelativePositionEmbedding
5225                        let pos_proj =
5226                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5227                        let per_dim_scale =
5228                            audio_cfg.hidden_size / audio_cfg.conf_num_attention_heads; // head_dim
5229                        let inv_timescales = audio_cfg.hidden_size / 2; // num_timescales
5230                        let pos_indices = audio_cfg.conf_attention_context_left
5231                            + audio_cfg.conf_attention_context_right
5232                            + 1;
5233
5234                        // Local causal masks (precomputed tensors)
5235                        let chunk_size = audio_cfg.conf_attention_chunk_size;
5236                        let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
5237                            + audio_cfg.conf_attention_context_right;
5238                        let local_causal_valid_mask = chunk_size * context_size; // U8 tensor
5239                        let invalid_logits_tensor = 1; // single f32 value
5240
5241                        pre_attn_norm
5242                            + post_norm
5243                            + q_proj
5244                            + k_proj
5245                            + v_proj
5246                            + post
5247                            + pos_proj
5248                            + per_dim_scale
5249                            + inv_timescales
5250                            + pos_indices
5251                            + local_causal_valid_mask
5252                            + invalid_logits_tensor
5253                    };
5254
5255                    // ConformerFeedForward (start and end)
5256                    let ffw_elems = {
5257                        // Each FFW has:
5258                        // - pre_layer_norm (NOT quantizable)
5259                        // - ffw_layer_1 (Arc<dyn QuantMethod> - IS quantizable)
5260                        // - ffw_layer_2 (Arc<dyn QuantMethod> - IS quantizable)
5261                        // - post_layer_norm (NOT quantizable)
5262                        let intermediate_size = audio_cfg.hidden_size * 4;
5263
5264                        let ffw_start = {
5265                            let pre_norm = audio_cfg.hidden_size;
5266                            let layer_1 =
5267                                audio_cfg.hidden_size * intermediate_size / weight_pack_factor;
5268                            let layer_2 =
5269                                intermediate_size * audio_cfg.hidden_size / weight_pack_factor;
5270                            let post_norm = audio_cfg.hidden_size;
5271                            pre_norm + layer_1 + layer_2 + post_norm
5272                        };
5273
5274                        let ffw_end = ffw_start; // Same structure
5275
5276                        ffw_start + ffw_end
5277                    };
5278
5279                    // ConformerLightConv1d
5280                    let lconv1d_elems = {
5281                        // Norms (NOT quantizable)
5282                        let pre_layer_norm = audio_cfg.hidden_size;
5283                        let conv_norm = audio_cfg.hidden_size;
5284
5285                        // Linear layers (Arc<dyn QuantMethod> - IS quantizable)
5286                        let linear_start = audio_cfg.hidden_size * (audio_cfg.hidden_size * 2)
5287                            / weight_pack_factor;
5288                        let linear_end =
5289                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5290
5291                        // depthwise_conv1d (Conv1d - NOT quantizable)
5292                        let depthwise = audio_cfg.hidden_size * audio_cfg.conf_conv_kernel_size;
5293
5294                        pre_layer_norm + conv_norm + linear_start + linear_end + depthwise
5295                    };
5296
5297                    // Final norm for conformer block (NOT quantizable)
5298                    let block_norm = audio_cfg.hidden_size;
5299
5300                    total += attention_elems + ffw_elems + lconv1d_elems + block_norm;
5301                }
5302
5303                total
5304            };
5305
5306            // Audio multimodal embedder (embed_audio)
5307            let embed_audio_elems = {
5308                // Embedding layer (ScaledEmbedding - NOT quantizable)
5309                let embedding = audio_cfg.vocab_size * audio_cfg.hidden_size;
5310
5311                // RMS norms (NOT quantizable)
5312                let hard_embedding_norm = audio_cfg.hidden_size; // with scale
5313                let soft_embedding_norm = audio_cfg.hidden_size; // with scale
5314                let embedding_post_projection_norm = text_cfg.hidden_size; // without scale
5315
5316                // Projection (Arc<dyn QuantMethod> - IS quantizable)
5317                let embedding_projection =
5318                    audio_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5319
5320                embedding
5321                    + hard_embedding_norm
5322                    + soft_embedding_norm
5323                    + embedding_post_projection_norm
5324                    + embedding_projection
5325            };
5326
5327            subsample_conv_projection_elems + conformer_elems + embed_audio_elems
5328        };
5329
5330        let vision_dtype = if dtype == DType::F16 {
5331            // f16 -> f32 for vision model in particular.
5332            DType::F32
5333        } else {
5334            dtype
5335        };
5336
5337        let total_bytes = text_elems * dtype.size_in_bytes()
5338            + vision_elems * vision_dtype.size_in_bytes()
5339            + audio_elems * dtype.size_in_bytes();
5340
5341        Ok(total_bytes)
5342    }
5343
5344    fn layer_sizes_in_bytes(
5345        &self,
5346        config: &str,
5347        dtype: DType,
5348        weight_pack_factor: usize,
5349        matformer_config: Option<&MatformerSliceConfig>,
5350    ) -> Result<Vec<usize>> {
5351        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5352
5353        // Apply matformer slicing if configured
5354        let (text_cfg, _layer_rename_map, _layers_skipped) = if let Some(matformer_cfg) =
5355            matformer_config
5356        {
5357            use crate::device_map::DummyDeviceMapper;
5358            use crate::vision_models::gemma3n::text::handle_matformer_slicing;
5359
5360            let dummy_mapper = DummyDeviceMapper {
5361                nm_device: Device::Cpu,
5362            };
5363            let (adjusted_cfg, _, _, layer_rename_map, layers_skipped) = handle_matformer_slicing(
5364                &cfg.text_config,
5365                &Some(matformer_cfg.clone()),
5366                &dummy_mapper,
5367            )?;
5368            (adjusted_cfg, layer_rename_map, layers_skipped)
5369        } else {
5370            (cfg.text_config.clone(), None, None)
5371        };
5372
5373        let text_cfg = &text_cfg;
5374
5375        // When matformer slicing is applied, we only include the layers that are kept
5376        let mut layer_sizes = Vec::new();
5377
5378        // Note: We don't need orig_intermediate_sizes anymore since the adjusted config
5379        // already has the correct intermediate sizes after matformer slicing
5380
5381        for layer_idx in 0..text_cfg.num_hidden_layers {
5382            let per_layer_elems = {
5383                // Layer norms
5384                let input_layernorm = text_cfg.hidden_size;
5385                let post_attention_layernorm = text_cfg.hidden_size;
5386                let pre_feedforward_layernorm = text_cfg.hidden_size;
5387                let post_feedforward_layernorm = text_cfg.hidden_size;
5388                let post_per_layer_input_norm = text_cfg.hidden_size;
5389
5390                // Attention components
5391                let size_in = text_cfg.hidden_size;
5392                let size_q = text_cfg.num_attention_heads * text_cfg.head_dim;
5393                let size_kv = text_cfg.num_key_value_heads * text_cfg.head_dim;
5394
5395                let q_proj = size_in * size_q / weight_pack_factor;
5396                let k_proj = size_in * size_kv / weight_pack_factor;
5397                let v_proj = size_in * size_kv / weight_pack_factor;
5398                let o_proj = size_q * size_in / weight_pack_factor;
5399
5400                // Q, K, V norms
5401                let q_norm = text_cfg.head_dim;
5402                let k_norm = text_cfg.head_dim;
5403                let v_norm = text_cfg.head_dim; // No bias for v_norm
5404
5405                // MLP components - use the adjusted intermediate sizes from matformer
5406                let intermediate_size = match &text_cfg.intermediate_size {
5407                    IntermediateSize::Single(size) => *size,
5408                    IntermediateSize::PerLayer(sizes) => sizes[layer_idx],
5409                    IntermediateSize::Matformer(sizes, _) => sizes[layer_idx],
5410                };
                let gate_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
                let up_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
                let down_proj = intermediate_size * text_cfg.hidden_size / weight_pack_factor;

                // AltUp components (per layer)
                let altup_elems = {
                    let correct_output_scale = text_cfg.hidden_size;
                    let correction_coefs = text_cfg.altup_num_inputs * text_cfg.altup_num_inputs;
                    let prediction_coefs =
                        text_cfg.altup_num_inputs * text_cfg.altup_num_inputs.pow(2);
                    let modality_router = text_cfg.hidden_size * text_cfg.altup_num_inputs;
                    let router_norm = text_cfg.hidden_size;

                    correct_output_scale
                        + correction_coefs
                        + prediction_coefs
                        + modality_router
                        + router_norm
                };
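                // Worked example with hypothetical sizes (hidden_size = 2048, altup_num_inputs = 4):
                // correction_coefs = 4 * 4 = 16, prediction_coefs = 4 * 4^2 = 64,
                // modality_router = 2048 * 4 = 8192, plus 2 * 2048 for correct_output_scale and
                // router_norm, i.e. about 12.4k elements per layer before the dtype multiplier.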

                // Laurel block components
                let laurel_elems = {
                    let left = text_cfg.hidden_size * text_cfg.laurel_rank;
                    let right = text_cfg.laurel_rank * text_cfg.hidden_size;
                    let post_norm = text_cfg.hidden_size;

                    left + right + post_norm
                };

                // Per-layer input components
                let per_layer_input_gate =
                    text_cfg.hidden_size * text_cfg.hidden_size_per_layer_input;
                let per_layer_projection =
                    text_cfg.hidden_size_per_layer_input * text_cfg.hidden_size;

                input_layernorm
                    + post_attention_layernorm
                    + pre_feedforward_layernorm
                    + post_feedforward_layernorm
                    + post_per_layer_input_norm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + q_norm
                    + k_norm
                    + v_norm
                    + gate_proj
                    + up_proj
                    + down_proj
                    + altup_elems
                    + laurel_elems
                    + per_layer_input_gate
                    + per_layer_projection
            };

            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
        }

        Ok(layer_sizes)
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
        let cfg = cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None, // None to be more forgiving; some configs do not set a sliding window
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
    }
}

// ======================== Qwen3VL Loader

/// [`VisionLoader`] for a Qwen3VL model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
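///
/// A minimal, hedged usage sketch of the config-introspection helpers this loader
/// implements (the JSON below is a hypothetical, heavily abbreviated config; real
/// Qwen3-VL `config.json` files carry many more required fields, so treat this as
/// illustrative only and keep the `DeviceMappedModelLoader` trait in scope):
///
/// ```ignore
/// let config = r#"{ "text_config": { "num_hidden_layers": 36, "hidden_size": 4096 } }"#;
/// let loader = Qwen3VLLoader;
/// // `num_layers` parses `text_config.num_hidden_layers` from the raw config JSON.
/// let n_layers = loader.num_layers(config)?;
/// assert_eq!(n_layers, 36);
/// ```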
pub struct Qwen3VLLoader;

pub struct Qwen3VLPrefixer;

impl MultimodalPromptPrefixer for Qwen3VLPrefixer {
    // No-op: With MessagesAction::Keep, the chat template handles image tokens
    // when it sees {"type": "image"} entries in the content.
}

impl VisionModelLoader for Qwen3VLLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen3VLModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Qwen3VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen3VLProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Qwen3VLPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Qwen3VLLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
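        // These patterns select the quantizable linear layers by their fully qualified
        // tensor names; e.g. a name like
        // `model.language_model.layers.0.self_attn.q_proj.weight` is matched by the
        // q_proj pattern below, while the vision tower is left unquantized because none
        // of its tensors match.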
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Qwen3VLLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;

        // For images, grid_t=1. After spatial merging, grid_h and grid_w are reduced.
        let img_seq_len = {
            let cfg = &cfg.vision_config;
            // grid_t is 1 for images (temporal dimension is for video only)
            let grid_t = 1;
            // After patch embedding and spatial merge, the effective grid dimensions are reduced
            let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
            let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
            grid_t * grid_h * grid_w * max_num_images
        };
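        // Worked example with hypothetical values (patch_size = 16, spatial_merge_size = 2,
        // a 1024x1024 max image): grid_h = grid_w = (1024 / 16) / 2 = 32, so each image
        // contributes 32 * 32 = 1024 merged tokens to the language-model sequence.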

        let max_text_attn = {
            let cfg = &cfg.text_config;
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;

        // For the vision encoder, before spatial merging
        let img_seq_len = {
            let cfg = &cfg.vision_config;
            // grid_t is 1 for images
            let grid_t = 1;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
        let tie = cfg.tie_word_embeddings;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !tie || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
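        // Sketch of the accounting above with hypothetical numbers: for hidden_size = 4096 and
        // a ~152k vocab with tied embeddings and weight_pack_factor = 1, only the ~0.6B
        // embedding elements are counted once; untied (or packed) checkpoints budget the
        // lm_head matrix a second time.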

        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };
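        // The patch embedding is a Conv3d whose weight holds
        // in_chans * hidden_size * temporal_patch_size * patch_size * patch_size elements
        // (assuming the default single conv group); hypothetically, 3 * 1152 * 2 * 16 * 16
        // is about 1.77M parameters.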

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;

            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim,
            v_head_dim: cfg.head_dim,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Qwen3VLMoE Loader

/// [`VisionLoader`] for a Qwen3VLMoE model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Qwen3VLMoELoader;

pub struct Qwen3VLMoEPrefixer;

impl MultimodalPromptPrefixer for Qwen3VLMoEPrefixer {
    // No-op: With MessagesAction::Keep, the chat template handles image tokens
    // when it sees {"type": "image"} entries in the content.
}

impl VisionModelLoader for Qwen3VLMoELoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen3VLMoEModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Qwen3VLMoEConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen3VLMoEProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Qwen3VLMoEPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Qwen3VLMoELoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP (dense layers)
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
            // MoE router
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate\.(weight|bias)$")?,
            // MoE experts - now unpacked into individual experts
            Regex::new(
                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$",
            )?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Qwen3VLMoELoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;

        // For images, grid_t=1. After spatial merging, grid_h and grid_w are reduced.
        let img_seq_len = {
            let cfg = &cfg.vision_config;
            // grid_t is 1 for images (temporal dimension is for video only)
            let grid_t = 1;
            // After patch embedding and spatial merge, the effective grid dimensions are reduced
            let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
            let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
            grid_t * grid_h * grid_w * max_num_images
        };

        let max_text_attn = {
            let cfg = &cfg.text_config;
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;

        // For the vision encoder, before spatial merging
        let img_seq_len = {
            let cfg = &cfg.vision_config;
            // grid_t is 1 for images
            let grid_t = 1;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
        let tie = cfg.tie_word_embeddings;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !tie || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;

            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
        let text_cfg = &cfg.text_config;

        let mut layer_sizes = Vec::with_capacity(text_cfg.num_hidden_layers);

        for layer_idx in 0..text_cfg.num_hidden_layers {
            let input_layernorm = text_cfg.hidden_size;
            let post_attention_layernorm = text_cfg.hidden_size;

            let size_in = text_cfg.hidden_size;
            let size_q = (text_cfg.hidden_size / text_cfg.num_attention_heads)
                * text_cfg.num_attention_heads;
            let size_kv = (text_cfg.hidden_size / text_cfg.num_attention_heads)
                * text_cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            // Check if this is a MoE layer
            let is_moe = !text_cfg.mlp_only_layers.contains(&layer_idx)
                && (text_cfg.num_experts > 0
                    && (layer_idx + 1) % text_cfg.decoder_sparse_step == 0);
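            // Example with hypothetical settings: decoder_sparse_step = 1 and an empty
            // mlp_only_layers makes every layer MoE; decoder_sparse_step = 2 would make
            // layers 1, 3, 5, ... (0-indexed) MoE while the rest use the dense MLP below.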

            let mlp_elems = if is_moe {
                // MoE layer: gate + experts
                let gate = text_cfg.hidden_size * text_cfg.num_experts;
                let per_expert = {
                    let h_size = text_cfg.hidden_size;
                    let i_size = text_cfg.moe_intermediate_size;
                    let gate_proj = h_size * i_size / weight_pack_factor;
                    let up_proj = h_size * i_size / weight_pack_factor;
                    let down_proj = i_size * h_size / weight_pack_factor;
                    gate_proj + up_proj + down_proj
                };
                gate + per_expert * text_cfg.num_experts
            } else {
                // Dense MLP layer
                let h_size = text_cfg.hidden_size;
                let i_size = text_cfg.intermediate_size;
                let gate_proj = h_size * i_size / weight_pack_factor;
                let up_proj = h_size * i_size / weight_pack_factor;
                let down_proj = i_size * h_size / weight_pack_factor;
                gate_proj + up_proj + down_proj
            };
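            // Rough illustration with hypothetical sizes (hidden_size = 2048,
            // moe_intermediate_size = 768, 128 experts, no packing): each expert holds
            // 3 * 2048 * 768 ≈ 4.7M elements, so the expert stack dominates the layer at
            // ~600M elements versus roughly 9-17M for the attention projections,
            // depending on the KV head count.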

            let per_layer_elems = input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + mlp_elems;

            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
        }

        Ok(layer_sizes)
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim,
            v_head_dim: cfg.head_dim,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}