mistralrs_core/pipeline/loaders/
vision_loaders.rs

1use std::any::Any;
2use std::sync::Arc;
3use std::{fmt::Debug, str::FromStr};
4
5use anyhow::Result;
6use candle_core::{DType, Device, Tensor, D};
7use candle_nn::Conv2dConfig;
8use image::{ColorType, DynamicImage};
9use itertools::Itertools;
10use mistralrs_quant::log::once_log_info;
11use mistralrs_quant::ShardedVarBuilder;
12
13#[cfg(feature = "pyo3_macros")]
14use pyo3::pyclass;
15
16use regex::Regex;
17use serde::Deserialize;
18
19use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};
20
21use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
22use crate::amoe::AnyMoeBaseModelMixin;
23use crate::attention::ATTENTION_CHUNK_SIZE;
24use crate::device_map::DeviceMapper;
25use crate::layers::Conv3dConfig;
26use crate::matformer::MatformerSliceConfig;
27use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
28use crate::pipeline::isq::IsqModelLoader;
29use crate::pipeline::loaders::AutoDeviceMapParams;
30use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
31use crate::pipeline::{
32    EitherCache, IsqModel, Modalities, MultimodalPromptPrefixer, Processor, ProcessorCreator,
33    SupportedModality,
34};
35use crate::utils::varbuilder_utils::DeviceForLoadTensor;
36use crate::vision_models::clip::ClipConfig;
37use crate::vision_models::gemma3::config::Gemma3Config;
38use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
39use crate::vision_models::gemma3n::config::{Gemma3nConfig, IntermediateSize};
40use crate::vision_models::gemma3n::{Gemma3nModel, Gemma3nProcessor};
41use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
42use crate::vision_models::idefics2_input_processor::Idefics2Processor;
43use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
44use crate::vision_models::image_processor::ImagePreProcessor;
45use crate::vision_models::inputs_processor::Phi4MMProcessor;
46use crate::vision_models::llama4::{
47    self, Llama4Config, Llama4ImageProcessor, Llama4Model, Llama4Processor,
48};
49use crate::vision_models::llava::config::Config as LLaVAConfig;
50use crate::vision_models::llava15::Model as LLaVA;
51use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
52use crate::vision_models::llava_next::Model as LLaVANext;
53use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
54use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
55use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
56use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
57use crate::vision_models::phi3_inputs_processor::Phi3Processor;
58use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
59use crate::vision_models::preprocessor_config::PreProcessorConfig;
60use crate::vision_models::processor_config::ProcessorConfig;
61use crate::vision_models::qwen2_5_vl::{
62    Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
63};
64use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
65use crate::vision_models::qwen3_vl::{Config as Qwen3VLConfig, Qwen3VLModel, Qwen3VLProcessor};
66use crate::vision_models::qwen3_vl_moe::{
67    Config as Qwen3VLMoEConfig, Qwen3VLMoEModel, Qwen3VLMoEProcessor,
68};
69use crate::vision_models::{minicpmo, phi4};
70
71pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin {
72    // pixel_values and pixel_attention_mask only specified for prompt seqs
73    #[allow(clippy::too_many_arguments)]
74    fn forward(
75        &self,
76        input_ids: &Tensor,
77        pixel_values: Option<Tensor>,
78        seqlen_offsets: &[usize],
79        context_lens: Vec<(usize, usize)>,
80        position_ids: Vec<usize>,
81        model_specific_args: Box<dyn Any>, // pixel attention mask, or image sizes, or anything else
82        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
83        flash_params: &FlashParams,
84    ) -> candle_core::Result<Tensor>;
85    fn device(&self) -> &Device;
86    fn cache(&self) -> &EitherCache;
87    fn cache_mut(&mut self) -> &mut EitherCache;
88    fn max_seq_len(&self) -> usize;
89    fn config(&self) -> &ModelConfigMetadata;
90    /// For a prompt without images. Requires batch size of 1!
91    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
92}
93
94pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
95    fn load(
96        &self,
97        config: &str,
98        vb: ShardedVarBuilder,
99        normal_loading_metadata: NormalLoadingMetadata,
100        attention_mechanism: AttentionImplementation,
101    ) -> Result<Box<dyn VisionModel + Send + Sync>>;
102    fn is_gptx(&self, config: &str) -> bool;
103    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
104    fn get_processor(
105        &self,
106        model_config: &str,
107        processor_config: Option<ProcessorConfig>,
108        preprocessor_config: PreProcessorConfig,
109        max_edge: Option<u32>,
110    ) -> Arc<dyn Processor + Send + Sync>;
111    fn supports_paged_attention(&self, config: &str) -> bool;
112    fn supports_prefix_cacher(&self, _config: &str) -> bool {
113        // Default is false, specific model must override.
114        false
115    }
116    fn modalities(&self, config: &str) -> Result<Modalities>;
117    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer>;
118    fn get_device_for_tensor(
119        &self,
120        config: &str,
121        _mapper: &dyn DeviceMapper,
122        loading_isq: bool,
123    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
124        if loading_isq {
125            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
126        } else {
127            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
128            let num_layers = self.model_config(config)?.num_layers();
129            let closure = move |name: String| {
130                if let Some(captures) = re.captures(&name) {
131                    captures
132                        .get(1)
133                        .and_then(|m| m.as_str().parse::<usize>().ok())
134                        .map(|l| l.min(num_layers))
135                        .map(DeviceForLoadTensor::Idx)
136                        .unwrap_or(DeviceForLoadTensor::Base)
137                } else {
138                    DeviceForLoadTensor::Base
139                }
140            };
141
142            Ok(Arc::new(closure))
143        }
144    }
145}
146
147#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
148#[derive(Clone, Debug, Deserialize, PartialEq)]
149/// The architecture to load the vision model as.
150pub enum VisionLoaderType {
151    #[serde(rename = "phi3v")]
152    Phi3V,
153    #[serde(rename = "idefics2")]
154    Idefics2,
155    #[serde(rename = "llava_next")]
156    LLaVANext,
157    #[serde(rename = "llava")]
158    LLaVA,
159    #[serde(rename = "vllama")]
160    VLlama,
161    #[serde(rename = "qwen2vl")]
162    Qwen2VL,
163    #[serde(rename = "idefics3")]
164    Idefics3,
165    #[serde(rename = "minicpmo")]
166    MiniCpmO,
167    #[serde(rename = "phi4mm")]
168    Phi4MM,
169    #[serde(rename = "qwen2_5vl")]
170    Qwen2_5VL,
171    #[serde(rename = "gemma3")]
172    Gemma3,
173    #[serde(rename = "mistral3")]
174    Mistral3,
175    #[serde(rename = "llama4")]
176    Llama4,
177    #[serde(rename = "gemma3n")]
178    Gemma3n,
179    #[serde(rename = "qwen3vl")]
180    Qwen3VL,
181    #[serde(rename = "qwen3vlmoe")]
182    Qwen3VLMoE,
183}
184
185// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
186impl VisionLoaderType {
187    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
188        match name {
189            "Phi3VForCausalLM" => Ok(Self::Phi3V),
190            "Idefics2ForConditionalGeneration" => Ok(Self::Idefics2),
191            "LlavaNextForConditionalGeneration" => Ok(Self::LLaVANext),
192            "LlavaForConditionalGeneration" => Ok(Self::LLaVA),
193            "MllamaForConditionalGeneration" => Ok(Self::VLlama),
194            "Qwen2VLForConditionalGeneration" => Ok(Self::Qwen2VL),
195            "Idefics3ForConditionalGeneration" => Ok(Self::Idefics3),
196            "MiniCPMO" => Ok(Self::MiniCpmO),
197            "Phi4MMForCausalLM" => Ok(Self::Phi4MM),
198            "Qwen2_5_VLForConditionalGeneration" => Ok(Self::Qwen2_5VL),
199            "Gemma3ForConditionalGeneration" | "Gemma3ForCausalLM" => Ok(Self::Gemma3),
200            "Mistral3ForConditionalGeneration" => Ok(Self::Mistral3),
201            "Llama4ForConditionalGeneration" => Ok(Self::Llama4),
202            "Gemma3nForConditionalGeneration" => Ok(Self::Gemma3n),
203            "Qwen3VLForConditionalGeneration" => Ok(Self::Qwen3VL),
204            "Qwen3VLMoeForConditionalGeneration" => Ok(Self::Qwen3VLMoE),
205            other => anyhow::bail!(
206                "Unsupported Hugging Face Transformers -CausalLM model class `{other}`. Please raise an issue."
207            ),
208        }
209    }
210}
211
212impl FromStr for VisionLoaderType {
213    type Err = String;
214    fn from_str(s: &str) -> Result<Self, Self::Err> {
215        match s {
216            "phi3v" => Ok(Self::Phi3V),
217            "idefics2" => Ok(Self::Idefics2),
218            "llava_next" => Ok(Self::LLaVANext),
219            "llava" => Ok(Self::LLaVA),
220            "vllama" => Ok(Self::VLlama),
221            "qwen2vl" => Ok(Self::Qwen2VL),
222            "idefics3" => Ok(Self::Idefics3),
223            "minicpmo" => Ok(Self::MiniCpmO),
224            "phi4mm" => Ok(Self::Phi4MM),
225            "qwen2_5vl" => Ok(Self::Qwen2_5VL),
226            "gemma3" => Ok(Self::Gemma3),
227            "mistral3" => Ok(Self::Mistral3),
228            "llama4" => Ok(Self::Llama4),
229            "gemma3n" => Ok(Self::Gemma3n),
230            "qwen3vl" => Ok(Self::Qwen3VL),
231            "qwen3vlmoe" => Ok(Self::Qwen3VLMoE),
232            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`, `llama4`, `gemma3n`, `qwen3vl`, `qwen3vlmoe`.")),
233        }
234    }
235}
236
237impl std::fmt::Display for VisionLoaderType {
238    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239        let name = match self {
240            VisionLoaderType::Phi3V => "phi3v",
241            VisionLoaderType::Idefics2 => "idefics2",
242            VisionLoaderType::LLaVANext => "llava_next",
243            VisionLoaderType::LLaVA => "llava",
244            VisionLoaderType::VLlama => "vllama",
245            VisionLoaderType::Qwen2VL => "qwen2vl",
246            VisionLoaderType::Idefics3 => "idefics3",
247            VisionLoaderType::MiniCpmO => "minicpmo",
248            VisionLoaderType::Phi4MM => "phi4mm",
249            VisionLoaderType::Qwen2_5VL => "qwen2_5vl",
250            VisionLoaderType::Gemma3 => "gemma3",
251            VisionLoaderType::Mistral3 => "mistral3",
252            VisionLoaderType::Llama4 => "llama4",
253            VisionLoaderType::Gemma3n => "gemma3n",
254            VisionLoaderType::Qwen3VL => "qwen3vl",
255            VisionLoaderType::Qwen3VLMoE => "qwen3vlmoe",
256        };
257        write!(f, "{name}")
258    }
259}
260
261#[derive(Deserialize)]
262struct AutoVisionLoaderConfig {
263    architectures: Vec<String>,
264}
265
266/// Automatically selects a VisionModelLoader implementation based on the JSON `architectures` field.
267pub struct AutoVisionLoader;
268
269impl AutoVisionLoader {
270    fn get_loader(config: &str) -> Result<Box<dyn VisionModelLoader>> {
271        let auto_cfg: AutoVisionLoaderConfig = serde_json::from_str(config)?;
272        if auto_cfg.architectures.len() != 1 {
273            anyhow::bail!("Expected exactly one architecture in config");
274        }
275
276        let name = &auto_cfg.architectures[0];
277        let tp = VisionLoaderType::from_causal_lm_name(name)?;
278
279        once_log_info(format!("Automatic loader type determined to be `{tp}`"));
280
281        // Delegate to the concrete loader
282        Ok(match tp {
283            VisionLoaderType::Phi3V => Box::new(Phi3VLoader),
284            VisionLoaderType::Idefics2 => Box::new(Idefics2Loader),
285            VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader),
286            VisionLoaderType::LLaVA => Box::new(LLaVALoader),
287            VisionLoaderType::VLlama => Box::new(VLlamaLoader),
288            VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
289            VisionLoaderType::Idefics3 => Box::new(Idefics3Loader),
290            VisionLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
291            VisionLoaderType::Phi4MM => Box::new(Phi4MMLoader),
292            VisionLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
293            VisionLoaderType::Gemma3 => Box::new(Gemma3Loader),
294            VisionLoaderType::Mistral3 => Box::new(Mistral3Loader),
295            VisionLoaderType::Llama4 => Box::new(VLlama4Loader),
296            VisionLoaderType::Gemma3n => Box::new(Gemma3nLoader),
297            VisionLoaderType::Qwen3VL => Box::new(Qwen3VLLoader),
298            VisionLoaderType::Qwen3VLMoE => Box::new(Qwen3VLMoELoader),
299        })
300    }
301}
302
303impl VisionModelLoader for AutoVisionLoader {
304    fn load(
305        &self,
306        config: &str,
307        vb: ShardedVarBuilder,
308        normal_loading_metadata: NormalLoadingMetadata,
309        attention_mechanism: AttentionImplementation,
310    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
311        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
312    }
313
314    fn is_gptx(&self, config: &str) -> bool {
315        Self::get_loader(config)
316            .expect("AutoVisionLoader get_loader")
317            .is_gptx(config)
318    }
319
320    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
321        Self::get_loader(config)?.get_config_repr(config)
322    }
323
324    fn get_processor(
325        &self,
326        model_config: &str,
327        proc_cfg: Option<ProcessorConfig>,
328        preproc_cfg: PreProcessorConfig,
329        max_edge: Option<u32>,
330    ) -> Arc<dyn Processor + Send + Sync> {
331        Self::get_loader(model_config)
332            .expect("AutoVisionLoader get_loader")
333            .get_processor(model_config, proc_cfg, preproc_cfg, max_edge)
334    }
335
336    fn supports_paged_attention(&self, config: &str) -> bool {
337        Self::get_loader(config)
338            .expect("AutoVisionLoader")
339            .supports_paged_attention(config)
340    }
341
342    fn modalities(&self, config: &str) -> Result<Modalities> {
343        Self::get_loader(config)?.modalities(config)
344    }
345
346    fn supports_prefix_cacher(&self, config: &str) -> bool {
347        Self::get_loader(config)
348            .expect("AutoVisionLoader")
349            .supports_prefix_cacher(config)
350    }
351
352    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
353        Self::get_loader(config)
354            .expect("AutoVisionLoader")
355            .prefixer(config)
356    }
357
358    fn get_device_for_tensor(
359        &self,
360        config: &str,
361        mapper: &dyn DeviceMapper,
362        loading_isq: bool,
363    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
364        Self::get_loader(config)?.get_device_for_tensor(config, mapper, loading_isq)
365    }
366}
367
368impl IsqModelLoader for AutoVisionLoader {
369    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
370        Self::get_loader(config)?.isq_layer_regexes(config)
371    }
372    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
373        Self::get_loader(config)?.immediate_isq_predicates(config)
374    }
375}
376
377impl DeviceMappedModelLoader for AutoVisionLoader {
378    fn mapped_max_act_size_elems(
379        &self,
380        config: &str,
381        params: &AutoDeviceMapParams,
382    ) -> Result<usize> {
383        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
384    }
385    fn non_mapped_max_act_size_elems(
386        &self,
387        config: &str,
388        params: &AutoDeviceMapParams,
389    ) -> Result<usize> {
390        Self::get_loader(config)?.non_mapped_max_act_size_elems(config, params)
391    }
392    fn non_mapped_size_in_bytes(
393        &self,
394        config: &str,
395        dtype: DType,
396        weight_pack_factor: usize,
397        _matformer_config: Option<&MatformerSliceConfig>,
398    ) -> Result<usize> {
399        Self::get_loader(config)?.non_mapped_size_in_bytes(
400            config,
401            dtype,
402            weight_pack_factor,
403            _matformer_config,
404        )
405    }
406    fn layer_sizes_in_bytes(
407        &self,
408        config: &str,
409        dtype: DType,
410        weight_pack_factor: usize,
411        _matformer_config: Option<&MatformerSliceConfig>,
412    ) -> Result<Vec<usize>> {
413        Self::get_loader(config)?.layer_sizes_in_bytes(
414            config,
415            dtype,
416            weight_pack_factor,
417            _matformer_config,
418        )
419    }
420    fn num_layers(&self, config: &str) -> Result<usize> {
421        Self::get_loader(config)?.num_layers(config)
422    }
423    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
424        Self::get_loader(config)?.model_config(config)
425    }
426}
427
428macro_rules! bias_if {
429    ($cond:expr, $size:expr) => {
430        if $cond {
431            $size
432        } else {
433            0
434        }
435    };
436}
437
438fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
439    let pre_layer_norm = cfg.hidden_size;
440    let final_layer_norm = cfg.hidden_size;
441
442    let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
443    let num_positions = num_patches + 1;
444
445    let class_embedding = cfg.hidden_size;
446
447    let position_ids = num_positions;
448    let position_embedding = num_positions * cfg.hidden_size;
449
450    let conv2dconfig = Conv2dConfig {
451        stride: cfg.patch_size,
452        ..Default::default()
453    };
454    let patch_embedding =
455        cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;
456
457    let encoder_layer_elems = {
458        let layer_norm1 = cfg.hidden_size;
459        let layer_norm2 = cfg.hidden_size;
460
461        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
462        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
463        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
464        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
465
466        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
467        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
468
469        layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
470    };
471
472    pre_layer_norm
473        + final_layer_norm
474        + class_embedding
475        + position_ids
476        + position_embedding
477        + patch_embedding
478        + cfg.num_hidden_layers * encoder_layer_elems
479}
480
481// ======================== Phi 3 loader
482
483/// [`VisionLoader`] for a Phi 3 Vision model.
484///
485/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
486pub struct Phi3VLoader;
487
488pub struct Phi3VPrefixer;
489
490impl MultimodalPromptPrefixer for Phi3VPrefixer {
491    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
492        // Image indexing starts at 0.
493        format!(
494            "{}{prompt}",
495            image_indexes
496                .into_iter()
497                .map(|image_index| format!("<|image_{}|>", image_index + 1))
498                .join("")
499        )
500    }
501}
502
503impl VisionModelLoader for Phi3VLoader {
504    fn load(
505        &self,
506        config: &str,
507        vb: ShardedVarBuilder,
508        normal_loading_metadata: NormalLoadingMetadata,
509        attention_mechanism: AttentionImplementation,
510    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
511        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
512        Ok(Box::new(Phi3::new(
513            &cfg,
514            vb,
515            self.is_gptx(config),
516            normal_loading_metadata,
517            attention_mechanism,
518        )?))
519    }
520    fn is_gptx(&self, _config: &str) -> bool {
521        true
522    }
523    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
524        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
525        Ok(Box::new(cfg))
526    }
527    fn get_processor(
528        &self,
529        _model_config: &str,
530        processor_config: Option<ProcessorConfig>,
531        preprocessor_config: PreProcessorConfig,
532        _max_edge: Option<u32>,
533    ) -> Arc<dyn Processor + Send + Sync> {
534        Phi3Processor::new_processor(processor_config, preprocessor_config)
535    }
536    fn supports_paged_attention(&self, _config: &str) -> bool {
537        true
538    }
539    fn supports_prefix_cacher(&self, _config: &str) -> bool {
540        true
541    }
542    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
543        Arc::new(Phi3VPrefixer)
544    }
545    fn modalities(&self, _config: &str) -> Result<Modalities> {
546        Ok(Modalities {
547            input: vec![SupportedModality::Text, SupportedModality::Vision],
548            output: vec![SupportedModality::Text],
549        })
550    }
551}
552
553impl IsqModelLoader for Phi3VLoader {
554    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
555        Ok(vec![
556            Regex::new(r"lm_head\.(weight|bias)$")?,
557            // Attention
558            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
559            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
560            // MLP
561            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
562            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
563        ])
564    }
565    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
566        self.isq_layer_regexes(config)
567    }
568}
569
570impl DeviceMappedModelLoader for Phi3VLoader {
571    fn mapped_max_act_size_elems(
572        &self,
573        config: &str,
574        params: &AutoDeviceMapParams,
575    ) -> Result<usize> {
576        // NOTE: we ignore max_num_images although it can only be one...
577        let AutoDeviceMapParams::Vision {
578            max_seq_len,
579            max_batch_size,
580            max_image_shape: _,
581            max_num_images,
582        } = params
583        else {
584            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
585        };
586
587        let cfg: Phi3Config = serde_json::from_str(config)?;
588
589        let vcfg = &PHI3V_CLIP_CONFIG;
590
591        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
592        let img_seq_len = (num_patches + 1) * max_num_images;
593
594        let max_text_attn = {
595            // This model injects the vision information directly into the input embeddings
596            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
597            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
598        };
599
600        Ok(max_text_attn)
601    }
602
603    fn non_mapped_max_act_size_elems(
604        &self,
605        config: &str,
606        params: &AutoDeviceMapParams,
607    ) -> Result<usize> {
608        // NOTE: we ignore max_num_images although it can only be one...
609        let AutoDeviceMapParams::Vision {
610            max_seq_len: _,
611            max_batch_size,
612            max_image_shape: _,
613            max_num_images,
614        } = params
615        else {
616            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
617        };
618
619        let cfg: Phi3Config = serde_json::from_str(config)?;
620
621        let vcfg = &PHI3V_CLIP_CONFIG;
622
623        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
624        let img_seq_len = num_patches + 1;
625
626        let max_vision_attn = {
627            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
628        };
629
630        Ok(max_vision_attn)
631    }
632
633    fn non_mapped_size_in_bytes(
634        &self,
635        config: &str,
636        dtype: DType,
637        weight_pack_factor: usize,
638        _matformer_config: Option<&MatformerSliceConfig>,
639    ) -> Result<usize> {
640        let cfg: Phi3Config = serde_json::from_str(config)?;
641        let elems = {
642            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
643            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
644            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
645                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
646            } else {
647                0
648            };
649            let norm = cfg.hidden_size;
650
651            let image_embed = {
652                let projection_cls = cfg
653                    .embd_layer
654                    .projection_cls
655                    .clone()
656                    .unwrap_or("linear".to_string());
657                let with_learnable_separator =
658                    cfg.embd_layer.with_learnable_separator.unwrap_or(false);
659                let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
660                let image_dim_out = cfg.img_processor.image_dim_out;
661
662                let proj = match (projection_cls.as_str(), use_hd_transform) {
663                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
664                    ("mlp", true) => {
665                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
666                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
667                        a + b
668                    }
669                    ("mlp", false) => {
670                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
671                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
672                        a + b
673                    }
674                    _ => {
675                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
676                    }
677                };
678
679                let (glb_gn, sub_gn) = if with_learnable_separator {
680                    let glb_gn = image_dim_out * 4;
681                    let sub_gn = image_dim_out * 4;
682                    (glb_gn, sub_gn)
683                } else {
684                    (0, 0)
685                };
686
687                let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);
688
689                proj + glb_gn + sub_gn + clip_vit
690            };
691
692            embed_tokens + lm_head + norm + image_embed
693        };
694
695        Ok(elems * dtype.size_in_bytes())
696    }
697
698    fn layer_sizes_in_bytes(
699        &self,
700        config: &str,
701        dtype: DType,
702        weight_pack_factor: usize,
703        _matformer_config: Option<&MatformerSliceConfig>,
704    ) -> Result<Vec<usize>> {
705        let cfg: Phi3Config = serde_json::from_str(config)?;
706        let per_layer_elems = {
707            let input_layernorm = cfg.hidden_size;
708            let post_attention_layernorm = cfg.hidden_size;
709
710            let size_in = cfg.hidden_size;
711            let head_dim = cfg.head_dim();
712            let op_size =
713                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
714            let qkv_proj = size_in * op_size / weight_pack_factor;
715            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;
716
717            let h_size = cfg.hidden_size;
718            let i_size = cfg.intermediate_size;
719            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
720            let down_proj = h_size * i_size / weight_pack_factor;
721
722            input_layernorm
723                + post_attention_layernorm
724                + qkv_proj
725                + o_proj
726                + gate_up_proj
727                + down_proj
728        };
729        Ok(vec![
730            per_layer_elems * dtype.size_in_bytes();
731            cfg.num_hidden_layers
732        ])
733    }
734
735    fn num_layers(&self, config: &str) -> Result<usize> {
736        let cfg: Phi3Config = serde_json::from_str(config)?;
737        Ok(cfg.num_hidden_layers)
738    }
739
740    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
741        let cfg: Phi3Config = serde_json::from_str(config)?;
742
743        let cfg = ModelConfigMetadata {
744            max_seq_len: cfg.max_position_embeddings,
745            num_layers: cfg.num_hidden_layers,
746            hidden_size: cfg.hidden_size,
747            num_kv_heads: cfg.num_key_value_heads,
748            num_attn_heads: cfg.num_attention_heads,
749            sliding_window: cfg.sliding_window,
750            k_head_dim: cfg.head_dim(),
751            v_head_dim: cfg.head_dim(),
752        };
753
754        Ok(Box::new(cfg))
755    }
756
757    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
758        Some(vec![NonMappedSubModel::Vision])
759    }
760}
761
762// ======================== Idefics 2 loader
763
764/// [`VisionLoader`] for an Idefics 2 Vision model.
765///
766/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
767pub struct Idefics2Loader;
768
769pub struct Idefics2Prefixer;
770
771impl MultimodalPromptPrefixer for Idefics2Prefixer {
772    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
773        // Chat template does it
774        prompt.to_string()
775    }
776}
777
778impl VisionModelLoader for Idefics2Loader {
779    fn load(
780        &self,
781        config: &str,
782        vb: ShardedVarBuilder,
783        normal_loading_metadata: NormalLoadingMetadata,
784        attention_mechanism: AttentionImplementation,
785    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
786        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
787        Ok(Box::new(Idefics2::new(
788            &cfg,
789            vb,
790            self.is_gptx(config),
791            normal_loading_metadata,
792            attention_mechanism,
793        )?))
794    }
795    fn is_gptx(&self, _config: &str) -> bool {
796        true
797    }
798    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
799        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
800        Ok(Box::new(cfg))
801    }
802    fn get_processor(
803        &self,
804        _model_config: &str,
805        processor_config: Option<ProcessorConfig>,
806        preprocessor_config: PreProcessorConfig,
807        max_edge: Option<u32>,
808    ) -> Arc<dyn Processor + Send + Sync> {
809        Arc::new(Idefics2Processor::new(
810            processor_config.unwrap(),
811            preprocessor_config,
812            max_edge,
813        ))
814    }
815    fn supports_paged_attention(&self, _config: &str) -> bool {
816        true
817    }
818    fn supports_prefix_cacher(&self, _config: &str) -> bool {
819        true
820    }
821    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
822        Arc::new(Idefics2Prefixer)
823    }
824    fn modalities(&self, _config: &str) -> Result<Modalities> {
825        Ok(Modalities {
826            input: vec![SupportedModality::Text, SupportedModality::Vision],
827            output: vec![SupportedModality::Text],
828        })
829    }
830}
831
832impl IsqModelLoader for Idefics2Loader {
833    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
834        Ok(vec![
835            Regex::new(r"lm_head\.(weight|bias)$")?,
836            // Attention
837            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
838            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
839            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
840            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
841            // MLP
842            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
843            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
844            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
845        ])
846    }
847    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
848        Ok(vec![
849            Regex::new(r"lm_head\.(weight|bias)$")?,
850            // Attention
851            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
852            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
853            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
854            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
855            // MLP
856            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
857            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
858            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
859        ])
860    }
861}
862
863impl DeviceMappedModelLoader for Idefics2Loader {
864    fn mapped_max_act_size_elems(
865        &self,
866        config: &str,
867        params: &AutoDeviceMapParams,
868    ) -> Result<usize> {
869        let AutoDeviceMapParams::Vision {
870            max_seq_len,
871            max_batch_size,
872            max_image_shape: _,
873            max_num_images,
874        } = params
875        else {
876            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
877        };
878
879        let cfg: Idefics2Config = serde_json::from_str(config)?;
880
881        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
882        let img_seq_len = (num_patches + 1) * max_num_images;
883
884        let max_text_attn = {
885            // This model injects the vision information directly into the input embeddings
886            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
887            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
888        };
889
890        Ok(max_text_attn)
891    }
892
893    fn non_mapped_max_act_size_elems(
894        &self,
895        config: &str,
896        params: &AutoDeviceMapParams,
897    ) -> Result<usize> {
898        let AutoDeviceMapParams::Vision {
899            max_seq_len: _,
900            max_batch_size,
901            max_image_shape: _,
902            max_num_images,
903        } = params
904        else {
905            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
906        };
907
908        let cfg: Idefics2Config = serde_json::from_str(config)?;
909
910        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
911        let img_seq_len = num_patches + 1;
912
913        let max_vision_attn = {
914            // do_image_splitting = true
915            let images_factor = 5;
916
917            (max_batch_size * images_factor * max_num_images)
918                * cfg.vision_config.num_attention_heads
919                * img_seq_len
920                * img_seq_len
921        };
922
923        Ok(max_vision_attn)
924    }
925
926    fn non_mapped_size_in_bytes(
927        &self,
928        config: &str,
929        dtype: DType,
930        weight_pack_factor: usize,
931        _matformer_config: Option<&MatformerSliceConfig>,
932    ) -> Result<usize> {
933        let cfg: Idefics2Config = serde_json::from_str(config)?;
934        let text_elems = {
935            let tie_word_embeddings = cfg.tie_word_embeddings;
936            let cfg = &cfg.text_config;
937
938            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
939            let lm_head = if !tie_word_embeddings {
940                cfg.hidden_size * cfg.vocab_size
941            } else {
942                0
943            };
944            let norm = cfg.hidden_size;
945            embed_tokens + lm_head + norm
946        };
947
948        let connector_elems = {
949            let tcfg = &cfg.text_config;
950            let vcfg = &cfg.vision_config;
951            let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
952            let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
953            let down_proj = tcfg.intermediate_size * tcfg.hidden_size;
954
955            let perceiver_elems = {
956                let tcfg = &cfg.text_config;
957                let pcfg = &cfg.perceiver_config;
958
959                let n_latents = pcfg.resampler_n_latents;
960                let hidden_size = tcfg.hidden_size;
961                let depth = pcfg.resampler_depth;
962
963                let norm = tcfg.hidden_size;
964                let latents = n_latents * hidden_size;
965
966                let layer_elems = {
967                    let input_latents_norm = hidden_size;
968                    let input_context_norm = hidden_size;
969                    let post_attn_norm = hidden_size;
970
971                    let num_heads = pcfg.resampler_n_heads;
972                    let head_dim = pcfg.resampler_head_dim;
973                    let num_key_value_heads = pcfg.num_key_value_heads;
974
975                    let q_proj = hidden_size * num_heads * head_dim;
976                    let k_proj = hidden_size * num_key_value_heads * head_dim;
977                    let v_proj = hidden_size * num_key_value_heads * head_dim;
978                    let o_proj = num_heads * head_dim * hidden_size;
979
980                    let gate_proj = hidden_size * hidden_size * 4;
981                    let up_proj = hidden_size * hidden_size * 4;
982                    let down_proj = hidden_size * 4 * hidden_size;
983
984                    input_latents_norm
985                        + input_context_norm
986                        + post_attn_norm
987                        + q_proj
988                        + k_proj
989                        + v_proj
990                        + o_proj
991                        + gate_proj
992                        + up_proj
993                        + down_proj
994                };
995
996                norm + latents + layer_elems * depth
997            };
998
999            gate_proj + up_proj + down_proj + perceiver_elems
1000        };
1001
1002        let vision_transformer = {
1003            let cfg = &cfg.vision_config;
1004
1005            let post_layernorm = cfg.hidden_size;
1006
1007            let conv_config = Conv2dConfig {
1008                stride: cfg.patch_size,
1009                ..Default::default()
1010            };
1011            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
1012                * cfg.patch_size
1013                * cfg.patch_size;
1014
1015            let num_patches_per_side = cfg.image_size / cfg.patch_size;
1016            let num_patches = num_patches_per_side.pow(2);
1017            let position_embedding = num_patches * cfg.hidden_size;
1018
1019            let layer_elems = {
1020                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
1021                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
1022
1023                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
1024                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
1025
1026                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1027                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1028                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1029                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1030
1031                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
1032            };
1033
1034            post_layernorm + patch_embedding + position_embedding + layer_elems
1035        };
1036
1037        let elems = text_elems + connector_elems + vision_transformer;
1038
1039        Ok(elems * dtype.size_in_bytes())
1040    }
1041
1042    fn layer_sizes_in_bytes(
1043        &self,
1044        config: &str,
1045        dtype: DType,
1046        weight_pack_factor: usize,
1047        _matformer_config: Option<&MatformerSliceConfig>,
1048    ) -> Result<Vec<usize>> {
1049        let cfg: Idefics2Config = serde_json::from_str(config)?;
1050        let cfg = cfg.text_config;
1051        let per_layer_elems = {
1052            let input_layernorm = cfg.hidden_size;
1053            let post_attention_layernorm = cfg.hidden_size;
1054
1055            let size_in = cfg.hidden_size;
1056            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1057            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1058            let q_proj = size_in * size_q / weight_pack_factor;
1059            let k_proj = size_in * size_kv / weight_pack_factor;
1060            let v_proj = size_in * size_kv / weight_pack_factor;
1061            let o_proj = size_q * size_in / weight_pack_factor;
1062
1063            let h_size = cfg.hidden_size;
1064            let i_size = cfg.intermediate_size;
1065            let gate_proj = h_size * i_size / weight_pack_factor;
1066            let up_proj = h_size * i_size / weight_pack_factor;
1067            let down_proj = i_size * h_size / weight_pack_factor;
1068
1069            input_layernorm
1070                + post_attention_layernorm
1071                + q_proj
1072                + k_proj
1073                + v_proj
1074                + o_proj
1075                + gate_proj
1076                + up_proj
1077                + down_proj
1078        };
1079        Ok(vec![
1080            per_layer_elems * dtype.size_in_bytes();
1081            cfg.num_hidden_layers
1082        ])
1083    }
1084
1085    fn num_layers(&self, config: &str) -> Result<usize> {
1086        let cfg: Idefics2Config = serde_json::from_str(config)?;
1087        Ok(cfg.text_config.num_hidden_layers)
1088    }
1089    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1090        let cfg: Idefics2Config = serde_json::from_str(config)?;
1091        let cfg = &cfg.text_config;
1092
1093        let cfg = ModelConfigMetadata {
1094            max_seq_len: cfg.max_position_embeddings,
1095            num_layers: cfg.num_hidden_layers,
1096            hidden_size: cfg.hidden_size,
1097            num_kv_heads: cfg.num_key_value_heads,
1098            num_attn_heads: cfg.num_attention_heads,
1099            sliding_window: cfg.sliding_window,
1100            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1101            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1102        };
1103
1104        Ok(Box::new(cfg))
1105    }
1106
1107    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1108        Some(vec![NonMappedSubModel::Vision])
1109    }
1110}
1111
1112// ======================== LLaVANext Loader
1113
1114/// [`VisionLoader`] for an LLaVANext Vision model.
1115///
1116/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
1117pub struct LLaVANextLoader;
1118
1119pub struct LLaVANextPrefixer;
1120
1121impl MultimodalPromptPrefixer for LLaVANextPrefixer {
1122    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1123        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
1124    }
1125}
1126
1127impl VisionModelLoader for LLaVANextLoader {
1128    fn load(
1129        &self,
1130        config: &str,
1131        vb: ShardedVarBuilder,
1132        normal_loading_metadata: NormalLoadingMetadata,
1133        attention_mechanism: AttentionImplementation,
1134    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1135        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1136        Ok(Box::new(LLaVANext::new(
1137            &cfg,
1138            vb,
1139            self.is_gptx(config),
1140            normal_loading_metadata,
1141            attention_mechanism,
1142        )?))
1143    }
1144    fn is_gptx(&self, _config: &str) -> bool {
1145        false
1146    }
1147    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1148        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1149        Ok(Box::new(cfg))
1150    }
1151    fn get_processor(
1152        &self,
1153        model_config: &str,
1154        _processor_config: Option<ProcessorConfig>,
1155        _preprocessor_config: PreProcessorConfig,
1156        _max_edge: Option<u32>,
1157    ) -> Arc<dyn Processor + Send + Sync> {
1158        Arc::new(LLaVANextProcessor::new(model_config))
1159    }
1160    fn supports_paged_attention(&self, _config: &str) -> bool {
1161        true
1162    }
1163    fn supports_prefix_cacher(&self, _config: &str) -> bool {
1164        true
1165    }
1166    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1167        Arc::new(LLaVANextPrefixer)
1168    }
1169    fn modalities(&self, _config: &str) -> Result<Modalities> {
1170        Ok(Modalities {
1171            input: vec![SupportedModality::Text, SupportedModality::Vision],
1172            output: vec![SupportedModality::Text],
1173        })
1174    }
1175}
1176
1177impl IsqModelLoader for LLaVANextLoader {
1178    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1179        Ok(vec![
1180            Regex::new(r"lm_head\.(weight|bias)$")?,
1181            // Attention
1182            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1183            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1184            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1185            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1186            // MLP
1187            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1188            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1189            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1190        ])
1191    }
1192    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1193        Ok(vec![
1194            Regex::new(r"lm_head\.(weight|bias)$")?,
1195            // Attention
1196            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1197            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1198            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1199            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1200            // MLP
1201            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1202            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1203            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1204        ])
1205    }
1206}
1207
1208impl DeviceMappedModelLoader for LLaVANextLoader {
1209    fn mapped_max_act_size_elems(
1210        &self,
1211        config: &str,
1212        params: &AutoDeviceMapParams,
1213    ) -> Result<usize> {
1214        let AutoDeviceMapParams::Vision {
1215            max_seq_len,
1216            max_batch_size,
1217            max_image_shape,
1218            max_num_images,
1219        } = params
1220        else {
1221            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1222        };
1223
1224        let config: LLaVAConfig = serde_json::from_str(config)?;
1225
1226        #[allow(clippy::cast_possible_truncation)]
1227        let img_seq_len =
1228            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
1229                &config,
1230                (max_image_shape.0 as u32, max_image_shape.1 as u32),
1231            );
1232        let img_seq_len = img_seq_len * max_num_images;
1233
1234        let max_text_attn = {
1235            let cfg = &config.text_config;
1236            // This model injects the vision information directly into the input embeddings
1237            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
1238
1239            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1240        };
1241
1242        Ok(max_text_attn)
1243    }
1244
1245    fn non_mapped_max_act_size_elems(
1246        &self,
1247        config: &str,
1248        params: &AutoDeviceMapParams,
1249    ) -> Result<usize> {
1250        let AutoDeviceMapParams::Vision {
1251            max_seq_len: _,
1252            max_batch_size,
1253            max_image_shape,
1254            max_num_images,
1255        } = params
1256        else {
1257            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1258        };
1259
1260        let config: LLaVAConfig = serde_json::from_str(config)?;
1261
1262        #[allow(clippy::cast_possible_truncation)]
1263        let img_seq_len =
1264            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
1265                &config,
1266                (max_image_shape.0 as u32, max_image_shape.1 as u32),
1267            );
1268
1269        let max_vision_attn = {
1270            (max_batch_size * max_num_images)
1271                * config.vision_config.num_attention_heads
1272                * img_seq_len
1273                * img_seq_len
1274        };
1275
1276        Ok(max_vision_attn)
1277    }
1278
1279    fn non_mapped_size_in_bytes(
1280        &self,
1281        config: &str,
1282        dtype: DType,
1283        weight_pack_factor: usize,
1284        _matformer_config: Option<&MatformerSliceConfig>,
1285    ) -> Result<usize> {
1286        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1287        let text_elems = {
1288            let cfg = &cfg.text_config;
1289            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1290            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1291            let norm = cfg.hidden_size;
1292            embed_tokens + lm_head + norm
1293        };
1294
1295        let image_newline = cfg.text_config.hidden_size;
1296        let mmproj = {
1297            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
1298                + cfg.text_config.hidden_size;
1299            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
1300                + cfg.text_config.hidden_size;
1301
1302            linear_1 + linear_2
1303        };
1304        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());
1305
1306        let elems = text_elems + image_newline + mmproj + vision_tower;
1307        Ok(elems * dtype.size_in_bytes())
1308    }
1309
1310    fn layer_sizes_in_bytes(
1311        &self,
1312        config: &str,
1313        dtype: DType,
1314        weight_pack_factor: usize,
1315        _matformer_config: Option<&MatformerSliceConfig>,
1316    ) -> Result<Vec<usize>> {
1317        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1318        let per_layer_elems = {
1319            let cfg = &cfg.text_config;
1320            let input_layernorm = cfg.hidden_size;
1321            let post_attention_layernorm = cfg.hidden_size;
1322
1323            let size_in = cfg.hidden_size;
1324            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1325            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1326            let q_proj = size_in * size_q / weight_pack_factor;
1327            let k_proj = size_in * size_kv / weight_pack_factor;
1328            let v_proj = size_in * size_kv / weight_pack_factor;
1329            let o_proj = size_q * size_in / weight_pack_factor;
1330
1331            let h_size = cfg.hidden_size;
1332            let i_size = cfg.intermediate_size;
1333            let gate_proj = h_size * i_size / weight_pack_factor;
1334            let up_proj = h_size * i_size / weight_pack_factor;
1335            let down_proj = i_size * h_size / weight_pack_factor;
1336
1337            input_layernorm
1338                + post_attention_layernorm
1339                + q_proj
1340                + k_proj
1341                + v_proj
1342                + o_proj
1343                + gate_proj
1344                + up_proj
1345                + down_proj
1346        };
1347        Ok(vec![
1348            per_layer_elems * dtype.size_in_bytes();
1349            cfg.text_config.num_hidden_layers
1350        ])
1351    }
1352
1353    fn num_layers(&self, config: &str) -> Result<usize> {
1354        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1355        Ok(cfg.text_config.num_hidden_layers)
1356    }
1357
1358    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1359        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1360        let cfg = &cfg.text_config;
1361
1362        let cfg = ModelConfigMetadata {
1363            max_seq_len: cfg.max_position_embeddings,
1364            num_layers: cfg.num_hidden_layers,
1365            hidden_size: cfg.hidden_size,
1366            num_kv_heads: cfg.num_key_value_heads,
1367            num_attn_heads: cfg.num_attention_heads,
1368            sliding_window: cfg.sliding_window,
1369            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1370            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1371        };
1372
1373        Ok(Box::new(cfg))
1374    }
1375
1376    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1377        Some(vec![NonMappedSubModel::Vision])
1378    }
1379}
1380
1381// ======================== LLaVA Loader
1382
1383/// [`VisionLoader`] for an LLaVA Vision model.
1384///
1385/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
1386pub struct LLaVALoader;
1387
1388pub struct LLaVAPrefixer;
1389
1390impl MultimodalPromptPrefixer for LLaVAPrefixer {
1391    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1392        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
1393    }
1394}
1395
1396impl VisionModelLoader for LLaVALoader {
1397    fn load(
1398        &self,
1399        config: &str,
1400        vb: ShardedVarBuilder,
1401        normal_loading_metadata: NormalLoadingMetadata,
1402        attention_mechanism: AttentionImplementation,
1403    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1404        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1405        Ok(Box::new(LLaVA::new(
1406            &cfg,
1407            vb,
1408            self.is_gptx(config),
1409            normal_loading_metadata,
1410            attention_mechanism,
1411        )?))
1412    }
1413    fn is_gptx(&self, _config: &str) -> bool {
1414        false
1415    }
1416    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1417        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1418        Ok(Box::new(cfg))
1419    }
1420    fn get_processor(
1421        &self,
1422        model_config: &str,
1423        _processor_config: Option<ProcessorConfig>,
1424        _preprocessor_config: PreProcessorConfig,
1425        _max_edge: Option<u32>,
1426    ) -> Arc<dyn Processor + Send + Sync> {
1427        Arc::new(LLaVAProcessor::new(model_config))
1428    }
1429    fn supports_paged_attention(&self, _config: &str) -> bool {
1430        true
1431    }
1432    fn supports_prefix_cacher(&self, _config: &str) -> bool {
1433        true
1434    }
1435    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1436        Arc::new(LLaVAPrefixer)
1437    }
1438    fn modalities(&self, _config: &str) -> Result<Modalities> {
1439        Ok(Modalities {
1440            input: vec![SupportedModality::Text, SupportedModality::Vision],
1441            output: vec![SupportedModality::Text],
1442        })
1443    }
1444}
1445
1446impl IsqModelLoader for LLaVALoader {
1447    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1448        Ok(vec![
1449            Regex::new(r"lm_head\.(weight|bias)$")?,
1450            // Attention
1451            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1452            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1453            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1454            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1455            // MLP
1456            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1457            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1458            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1459        ])
1460    }
1461    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1462        Ok(vec![
1463            Regex::new(r"lm_head\.(weight|bias)$")?,
1464            // Attention
1465            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1466            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1467            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1468            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1469            // MLP
1470            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1471            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1472            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1473        ])
1474    }
1475}
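// A hedged note on the two regex sets above: `isq_layer_regexes` matches on the bare
// `layers.N....` suffix, while `immediate_isq_predicates` additionally requires the longer
// `language_model.model.layers.N....` path to be present. For example, the (hypothetical)
// tensor name below satisfies both:
//
//     let name = "language_model.model.layers.12.mlp.down_proj.weight";
//     let re = Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$").unwrap();
//     assert!(re.is_match(name));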
1476
1477impl DeviceMappedModelLoader for LLaVALoader {
1478    fn mapped_max_act_size_elems(
1479        &self,
1480        config: &str,
1481        params: &AutoDeviceMapParams,
1482    ) -> Result<usize> {
1483        let AutoDeviceMapParams::Vision {
1484            max_seq_len,
1485            max_batch_size,
1486            max_image_shape: _,
1487            max_num_images,
1488        } = params
1489        else {
1490            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1491        };
1492
1493        let config: LLaVAConfig = serde_json::from_str(config)?;
1494
1495        let img_seq_len =
1496            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1497        let img_seq_len = img_seq_len * max_num_images;
1498
1499        let max_text_attn = {
1500            let cfg = &config.text_config;
1501            // This model injects the vision information directly into the input embeddings
1502            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
1503
1504            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1505        };
1506
1507        Ok(max_text_attn)
1508    }
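    // Back-of-the-envelope sketch for the activation bound above (all numbers purely
    // illustrative, not taken from a real config): with max_batch_size = 1,
    // num_attention_heads = 32, one image contributing img_seq_len = 576 tokens, and the
    // text portion capped at 1024 tokens, the bound is
    //   1 * 32 * (576 + 1024)^2 = 81,920,000 elements
    // for a single layer's attention score matrices.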
1509
1510    fn non_mapped_max_act_size_elems(
1511        &self,
1512        config: &str,
1513        params: &AutoDeviceMapParams,
1514    ) -> Result<usize> {
1515        let AutoDeviceMapParams::Vision {
1516            max_seq_len: _,
1517            max_batch_size,
1518            max_image_shape: _,
1519            max_num_images,
1520        } = params
1521        else {
1522            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1523        };
1524
1525        let config: LLaVAConfig = serde_json::from_str(config)?;
1526
1527        let img_seq_len =
1528            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1529
1530        let max_vision_attn = {
1531            (max_batch_size * max_num_images)
1532                * config.vision_config.num_attention_heads
1533                * img_seq_len
1534                * img_seq_len
1535        };
1536
1537        Ok(max_vision_attn)
1538    }
1539
1540    fn non_mapped_size_in_bytes(
1541        &self,
1542        config: &str,
1543        dtype: DType,
1544        weight_pack_factor: usize,
1545        _matformer_config: Option<&MatformerSliceConfig>,
1546    ) -> Result<usize> {
1547        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1548        let text_elems = {
1549            let cfg = &cfg.text_config;
1550            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1551            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1552            let norm = cfg.hidden_size;
1553            embed_tokens + lm_head + norm
1554        };
1555
1556        let image_newline = cfg.text_config.hidden_size;
1557        let mmproj = {
1558            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
1559                + cfg.text_config.hidden_size;
1560            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
1561                + cfg.text_config.hidden_size;
1562
1563            linear_1 + linear_2
1564        };
1565        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());
1566
1567        let elems = text_elems + image_newline + mmproj + vision_tower;
1568        Ok(elems * dtype.size_in_bytes())
1569    }
1570
1571    fn layer_sizes_in_bytes(
1572        &self,
1573        config: &str,
1574        dtype: DType,
1575        weight_pack_factor: usize,
1576        _matformer_config: Option<&MatformerSliceConfig>,
1577    ) -> Result<Vec<usize>> {
1578        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1579        let per_layer_elems = {
1580            let cfg = &cfg.text_config;
1581            let input_layernorm = cfg.hidden_size;
1582            let post_attention_layernorm = cfg.hidden_size;
1583
1584            let size_in = cfg.hidden_size;
1585            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1586            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1587            let q_proj = size_in * size_q / weight_pack_factor;
1588            let k_proj = size_in * size_kv / weight_pack_factor;
1589            let v_proj = size_in * size_kv / weight_pack_factor;
1590            let o_proj = size_q * size_in / weight_pack_factor;
1591
1592            let h_size = cfg.hidden_size;
1593            let i_size = cfg.intermediate_size;
1594            let gate_proj = h_size * i_size / weight_pack_factor;
1595            let up_proj = h_size * i_size / weight_pack_factor;
1596            let down_proj = i_size * h_size / weight_pack_factor;
1597
1598            input_layernorm
1599                + post_attention_layernorm
1600                + q_proj
1601                + k_proj
1602                + v_proj
1603                + o_proj
1604                + gate_proj
1605                + up_proj
1606                + down_proj
1607        };
1608        Ok(vec![
1609            per_layer_elems * dtype.size_in_bytes();
1610            cfg.text_config.num_hidden_layers
1611        ])
1612    }
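    // Worked example for the per-layer element count above (hypothetical 7B-class text
    // config, not tied to any specific checkpoint): with hidden_size = 4096,
    // intermediate_size = 11008, num_attention_heads = num_key_value_heads = 32 and
    // weight_pack_factor = 1:
    //   2 * 4096                (input / post-attention layernorms)
    // + 4 * 4096 * 4096         (q/k/v/o projections)
    // + 3 * 4096 * 11008        (gate/up/down projections)
    // = 202,383,360 elements, i.e. roughly 405 MB per decoder layer at 2 bytes/element.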
1613
1614    fn num_layers(&self, config: &str) -> Result<usize> {
1615        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1616        Ok(cfg.text_config.num_hidden_layers)
1617    }
1618
1619    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1620        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1621        let cfg = &cfg.text_config;
1622
1623        let cfg = ModelConfigMetadata {
1624            max_seq_len: cfg.max_position_embeddings,
1625            num_layers: cfg.num_hidden_layers,
1626            hidden_size: cfg.hidden_size,
1627            num_kv_heads: cfg.num_key_value_heads,
1628            num_attn_heads: cfg.num_attention_heads,
1629            sliding_window: cfg.sliding_window,
1630            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1631            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1632        };
1633
1634        Ok(Box::new(cfg))
1635    }
1636
1637    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1638        Some(vec![NonMappedSubModel::Vision])
1639    }
1640}
1641
1642// ======================== MLlama Loader
1643
1644/// [`VisionLoader`] for a Llama Vision model.
1645///
1646/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
1647pub struct VLlamaLoader;
1648
1649pub struct VLlamaPrefixer;
1650
1651impl MultimodalPromptPrefixer for VLlamaPrefixer {
1652    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1653        format!("{}{prompt}", "<|image|>".repeat(image_indexes.len()))
1654    }
1655}
1656
1657impl VisionModelLoader for VLlamaLoader {
1658    fn load(
1659        &self,
1660        config: &str,
1661        vb: ShardedVarBuilder,
1662        normal_loading_metadata: NormalLoadingMetadata,
1663        attention_mechanism: AttentionImplementation,
1664    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1665        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1666        Ok(Box::new(MLlamaModel::new(
1667            &cfg,
1668            vb,
1669            self.is_gptx(config),
1670            normal_loading_metadata,
1671            attention_mechanism,
1672        )?))
1673    }
1674    fn is_gptx(&self, _config: &str) -> bool {
1675        true
1676    }
1677    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1678        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1679        Ok(Box::new(cfg))
1680    }
1681    fn get_processor(
1682        &self,
1683        _model_config: &str,
1684        _processor_config: Option<ProcessorConfig>,
1685        _preprocessor_config: PreProcessorConfig,
1686        _max_edge: Option<u32>,
1687    ) -> Arc<dyn Processor + Send + Sync> {
1688        Arc::new(MLlamaProcessor::new())
1689    }
1690    fn supports_paged_attention(&self, _config: &str) -> bool {
1691        false
1692    }
1693    fn supports_prefix_cacher(&self, _config: &str) -> bool {
1694        true
1695    }
1696    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1697        Arc::new(VLlamaPrefixer)
1698    }
1699    fn modalities(&self, _config: &str) -> Result<Modalities> {
1700        Ok(Modalities {
1701            input: vec![SupportedModality::Text, SupportedModality::Vision],
1702            output: vec![SupportedModality::Text],
1703        })
1704    }
1705}
1706
1707impl IsqModelLoader for VLlamaLoader {
1708    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
1709        let config: MLlamaConfig = serde_json::from_str(config)?;
1710        let cross_attn_layers = &config.text_config.cross_attention_layers;
1711        let transformer_layers =
1712            (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
1713        let mut text_regexes = Vec::new();
1714        for layer in transformer_layers {
1715            text_regexes.extend(vec![
1716                // Attention text
1717                Regex::new(&format!(
1718                    r"language_model.model.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
1719                ))?,
1720                Regex::new(&format!(
1721                    r"language_model.model.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
1722                ))?,
1723                Regex::new(&format!(
1724                    r"language_model.model.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
1725                ))?,
1726                Regex::new(&format!(
1727                    r"language_model.model.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
1728                ))?,
1729                // MLP text
1730                Regex::new(&format!(
1731                    r"language_model.model.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
1732                ))?,
1733                Regex::new(&format!(
1734                    r"language_model.model.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
1735                ))?,
1736                Regex::new(&format!(
1737                    r"language_model.model.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
1738                ))?,
1739            ]);
1740        }
1741        let vision_regexes = vec![
1742            // Vision attention (transformer)
1743            Regex::new(
1744                r"vision_model.transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1745            )?,
1746            Regex::new(
1747                r"vision_model.transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1748            )?,
1749            Regex::new(
1750                r"vision_model.transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1751            )?,
1752            Regex::new(
1753                r"vision_model.transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1754            )?,
1755            // Vision attention (global transformer)
1756            Regex::new(
1757                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1758            )?,
1759            Regex::new(
1760                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1761            )?,
1762            Regex::new(
1763                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1764            )?,
1765            Regex::new(
1766                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1767            )?,
1768            // MLP vision
1769            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1770            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1771        ];
1772
1773        Ok([text_regexes, vision_regexes].concat())
1774    }
1775    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1776        self.isq_layer_regexes(config)
1777    }
1778}
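// Sketch of why the text regexes are built per-layer here (illustrative indices only):
// for a config whose `cross_attention_layers` is, say, [3, 8], layers 3 and 8 get no
// self-attention / MLP regexes at all, so ISQ leaves their weights untouched, while every
// other decoder layer contributes the usual q/k/v/o and gate/up/down patterns with its
// concrete layer index baked into the regex. The vision regexes apply regardless.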
1779
1780impl DeviceMappedModelLoader for VLlamaLoader {
1781    fn mapped_max_act_size_elems(
1782        &self,
1783        config: &str,
1784        params: &AutoDeviceMapParams,
1785    ) -> Result<usize> {
1786        let AutoDeviceMapParams::Vision {
1787            max_seq_len,
1788            max_batch_size,
1789            max_image_shape: _,
1790            max_num_images,
1791        } = params
1792        else {
1793            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1794        };
1795
1796        let config: MLlamaConfig = serde_json::from_str(config)?;
1797
1798        let img_seq_len = {
1799            let cfg = &config.vision_config;
1800            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1801            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1802            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1803        };
1804        let img_seq_len = img_seq_len * max_num_images;
1805
1806        let max_cross_text_attn = {
1807            let cfg = &config.text_config;
1808            max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
1809        };
1810
1811        let max_self_text_attn = {
1812            let cfg = &config.text_config;
1813            max_batch_size * cfg.num_attention_heads * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)
1814        };
1815
1816        Ok(max_self_text_attn.max(max_cross_text_attn))
1817    }
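    // Illustrative walk-through of img_seq_len above (numbers are assumptions, not read
    // from a real MLlama config): with image_size = 560, patch_size = 14 and
    // max_num_tiles = 4, num_patches = (560/14)^2 + 1 = 1601; the padding term rounds this
    // up to the next multiple of 8 (1601 + 7 = 1608), so the per-image sequence length is
    // 4 * 1608 = 6432 tokens before multiplying by max_num_images.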
1818
1819    fn non_mapped_max_act_size_elems(
1820        &self,
1821        config: &str,
1822        params: &AutoDeviceMapParams,
1823    ) -> Result<usize> {
1824        let AutoDeviceMapParams::Vision {
1825            max_seq_len: _,
1826            max_batch_size,
1827            max_image_shape: _,
1828            max_num_images,
1829        } = params
1830        else {
1831            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1832        };
1833
1834        let config: MLlamaConfig = serde_json::from_str(config)?;
1835
1836        let img_seq_len = {
1837            let cfg = &config.vision_config;
1838            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1839            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1840            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1841        };
1842        let max_vision_attn = {
1843            let cfg = &config.vision_config;
1844            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
1845        };
1846
1847        Ok(max_vision_attn)
1848    }
1849
1850    fn non_mapped_size_in_bytes(
1851        &self,
1852        config: &str,
1853        dtype: DType,
1854        weight_pack_factor: usize,
1855        _matformer_config: Option<&MatformerSliceConfig>,
1856    ) -> Result<usize> {
1857        let config: MLlamaConfig = serde_json::from_str(config)?;
1858        let text_elems = {
1859            let cfg = &config.text_config;
1860            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1861            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1862            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1863                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1864            } else {
1865                0
1866            };
1867            let norm = cfg.hidden_size;
1868            embed_tokens + lm_head + norm
1869        };
1870
1871        let vision_elems = {
1872            let cfg = &config.vision_config;
1873
1874            let conv_cfg = Conv2dConfig {
1875                stride: cfg.patch_size,
1876                ..Default::default()
1877            };
1878            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
1879                * cfg.patch_size
1880                * cfg.patch_size;
1881
1882            let class_embedding = cfg.hidden_size;
1883
1884            let gated_positional_embedding = {
1885                let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1886                let embedding = num_patches * cfg.hidden_size;
1887                let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
1888                    * (cfg.max_num_tiles * num_patches * cfg.hidden_size);
1889
1890                embedding + tile_embedding
1891            };
1892
1893            let pre_tile_positional_embedding =
1894                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1895            let post_tile_positional_embedding =
1896                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1897
1898            let layernorm_pre = cfg.hidden_size;
1899            let layernorm_post = cfg.hidden_size;
1900
1901            let encoder_layer = {
1902                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1903                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
1904
1905                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1906                let q_proj =
1907                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1908                let k_proj =
1909                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1910                let v_proj =
1911                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1912                let o_proj =
1913                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1914
1915                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
1916                    + cfg.intermediate_size;
1917                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
1918                    + cfg.hidden_size;
1919
1920                input_layernorm
1921                    + post_attention_layernorm
1922                    + q_proj
1923                    + k_proj
1924                    + v_proj
1925                    + o_proj
1926                    + fc1
1927                    + fc2
1928            };
1929
1930            patch_embedding
1931                + class_embedding
1932                + gated_positional_embedding
1933                + pre_tile_positional_embedding
1934                + post_tile_positional_embedding
1935                + layernorm_pre
1936                + layernorm_post
1937                + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
1938        };
1939
1940        let elems = text_elems + vision_elems;
1941        Ok(elems * dtype.size_in_bytes())
1942    }
1943
1944    fn layer_sizes_in_bytes(
1945        &self,
1946        config: &str,
1947        dtype: DType,
1948        weight_pack_factor: usize,
1949        _matformer_config: Option<&MatformerSliceConfig>,
1950    ) -> Result<Vec<usize>> {
1951        let config: MLlamaConfig = serde_json::from_str(config)?;
1952        let cfg = &config.text_config;
1953
1954        let mut layer_sizes = Vec::new();
1955
1956        for i in 0..cfg.num_hidden_layers {
1957            let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
1958                // Cross-attention layers are not ISQ-quantized, so skip the packing factor for them
1959                1
1960            } else {
1961                weight_pack_factor
1962            };
1963
1964            let per_layer_elems = {
1965                let input_layernorm = cfg.hidden_size;
1966                let post_attention_layernorm = cfg.hidden_size;
1967
1968                let size_in = cfg.hidden_size;
1969                let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1970                let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1971                let q_proj = size_in * size_q / weight_pack_factor;
1972                let k_proj = size_in * size_kv / weight_pack_factor;
1973                let v_proj = size_in * size_kv / weight_pack_factor;
1974                let o_proj = size_q * size_in / weight_pack_factor;
1975
1976                let h_size = cfg.hidden_size;
1977                let i_size = cfg.intermediate_size;
1978                let gate_proj = h_size * i_size / weight_pack_factor;
1979                let up_proj = h_size * i_size / weight_pack_factor;
1980                let down_proj = i_size * h_size / weight_pack_factor;
1981
1982                input_layernorm
1983                    + post_attention_layernorm
1984                    + q_proj
1985                    + k_proj
1986                    + v_proj
1987                    + o_proj
1988                    + gate_proj
1989                    + up_proj
1990                    + down_proj
1991            };
1992
1993            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
1994        }
1995
1996        Ok(layer_sizes)
1997    }
1998
1999    fn num_layers(&self, config: &str) -> Result<usize> {
2000        let config: MLlamaConfig = serde_json::from_str(config)?;
2001        Ok(config.text_config.num_hidden_layers)
2002    }
2003
2004    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2005        let cfg: MLlamaConfig = serde_json::from_str(config)?;
2006        let cfg = &cfg.text_config;
2007
2008        let cfg = ModelConfigMetadata {
2009            max_seq_len: cfg.max_position_embeddings,
2010            num_layers: cfg.num_hidden_layers,
2011            hidden_size: cfg.hidden_size,
2012            num_kv_heads: cfg.num_key_value_heads,
2013            num_attn_heads: cfg.num_attention_heads,
2014            sliding_window: None,
2015            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2016            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2017        };
2018
2019        Ok(Box::new(cfg))
2020    }
2021
2022    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2023        Some(vec![NonMappedSubModel::Vision])
2024    }
2025}
2026
2027// ======================== Qwen2VL Loader
2028
2029/// [`VisionLoader`] for a Qwen2-VL model.
2030///
2031/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2032pub struct Qwen2VLLoader;
2033
2034pub struct Qwen2VLPrefixer;
2035
2036impl MultimodalPromptPrefixer for Qwen2VLPrefixer {
2037    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2038        format!(
2039            "{}{prompt}",
2040            format!(
2041                "{}{}{}",
2042                Qwen2VLProcessor::VISION_START,
2043                Qwen2VLProcessor::IMAGE_PAD,
2044                Qwen2VLProcessor::VISION_END
2045            )
2046            .repeat(image_indexes.len())
2047        )
2048    }
2049}
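// Minimal sketch of the prefix produced above (hypothetical prompt): each image contributes
// one VISION_START / IMAGE_PAD / VISION_END triplet, so for two images that three-token
// marker is repeated twice in front of the prompt text. The concrete token strings are
// whatever `Qwen2VLProcessor` defines for those constants.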
2050
2051impl VisionModelLoader for Qwen2VLLoader {
2052    fn load(
2053        &self,
2054        config: &str,
2055        vb: ShardedVarBuilder,
2056        normal_loading_metadata: NormalLoadingMetadata,
2057        attention_mechanism: AttentionImplementation,
2058    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2059        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2060        Ok(Box::new(Qwen2VLModel::new(
2061            &cfg,
2062            vb,
2063            self.is_gptx(config),
2064            normal_loading_metadata,
2065            attention_mechanism,
2066        )?))
2067    }
2068    fn is_gptx(&self, _config: &str) -> bool {
2069        true
2070    }
2071    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2072        let config: Qwen2VLConfig = serde_json::from_str(config)?;
2073        Ok(Box::new(config))
2074    }
2075    fn get_processor(
2076        &self,
2077        _model_config: &str,
2078        _processor_config: Option<ProcessorConfig>,
2079        _preprocessor_config: PreProcessorConfig,
2080        max_edge: Option<u32>,
2081    ) -> Arc<dyn Processor + Send + Sync> {
2082        Arc::new(Qwen2VLProcessor::new(max_edge))
2083    }
2084    fn supports_paged_attention(&self, _config: &str) -> bool {
2085        false
2086    }
2087    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2088        Arc::new(Qwen2VLPrefixer)
2089    }
2090    fn modalities(&self, _config: &str) -> Result<Modalities> {
2091        Ok(Modalities {
2092            input: vec![SupportedModality::Text, SupportedModality::Vision],
2093            output: vec![SupportedModality::Text],
2094        })
2095    }
2096}
2097
2098impl IsqModelLoader for Qwen2VLLoader {
2099    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2100        Ok(vec![
2101            Regex::new(r"lm_head\.(weight|bias)$")?,
2102            // Attention
2103            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2104            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2105            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2106            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2107            // MLP
2108            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2109            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2110            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2111        ])
2112    }
2113    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2114        self.isq_layer_regexes(config)
2115    }
2116}
2117
2118impl DeviceMappedModelLoader for Qwen2VLLoader {
2119    fn mapped_max_act_size_elems(
2120        &self,
2121        config: &str,
2122        params: &AutoDeviceMapParams,
2123    ) -> Result<usize> {
2124        let AutoDeviceMapParams::Vision {
2125            max_seq_len,
2126            max_batch_size,
2127            max_image_shape,
2128            max_num_images,
2129        } = params
2130        else {
2131            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2132        };
2133
2134        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2135
2136        // For images, grid_t=1. After spatial merging, grid_h and grid_w are reduced.
2137        let img_seq_len = {
2138            let cfg = &cfg.vision_config;
2139            // grid_t is 1 for images (temporal dimension is for video only)
2140            let grid_t = 1;
2141            // After patch embedding and spatial merge, the effective grid dimensions are reduced
2142            let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
2143            let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
2144            grid_t * grid_h * grid_w * max_num_images
2145        };
2146
2147        let max_text_attn = {
2148            // This model injects the vision information directly into the input embeddings
2149            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2150            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
2151        };
2152
2153        Ok(max_text_attn)
2154    }
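    // Worked example of the merged image grid above (assumed, illustrative values only):
    // for max_image_shape = (1008, 1008) with patch_size = 14 and spatial_merge_size = 2,
    // grid_h = grid_w = (1008 / 14) / 2 = 36, so a single image contributes
    // 1 * 36 * 36 = 1296 tokens to the text sequence after spatial merging.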
2155
2156    fn non_mapped_max_act_size_elems(
2157        &self,
2158        config: &str,
2159        params: &AutoDeviceMapParams,
2160    ) -> Result<usize> {
2161        let AutoDeviceMapParams::Vision {
2162            max_seq_len: _,
2163            max_batch_size,
2164            max_image_shape,
2165            max_num_images,
2166        } = params
2167        else {
2168            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2169        };
2170
2171        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2172
2173        // For the vision encoder, before spatial merging
2174        let img_seq_len = {
2175            let cfg = &cfg.vision_config;
2176            // grid_t is 1 for images
2177            let grid_t = 1;
2178            let grid_h = max_image_shape.0 / cfg.patch_size;
2179            let grid_w = max_image_shape.1 / cfg.patch_size;
2180            grid_t * grid_h * grid_w
2181        };
2182
2183        let max_vision_attn = {
2184            let cfg = &cfg.vision_config;
2185            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
2186        };
2187
2188        Ok(max_vision_attn)
2189    }
2190
2191    fn non_mapped_size_in_bytes(
2192        &self,
2193        config: &str,
2194        dtype: DType,
2195        weight_pack_factor: usize,
2196        _matformer_config: Option<&MatformerSliceConfig>,
2197    ) -> Result<usize> {
2198        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2199        let text_elems = {
2200            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2201            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2202            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2203                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2204            } else {
2205                0
2206            };
2207            let norm = cfg.hidden_size;
2208            embed_tokens + lm_head + norm
2209        };
2210
2211        let patch_merger = {
2212            let cfg = &cfg.vision_config;
2213            let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);
2214
2215            let mlp0 = hidden_size * hidden_size + hidden_size;
2216            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
2217
2218            let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2219
2220            mlp0 + mlp2 + ln_q
2221        };
2222
2223        let patch_embed = {
2224            let cfg = &cfg.vision_config;
2225            let conv_cfg = Conv3dConfig {
2226                stride: cfg.patch_size,
2227                ..Default::default()
2228            };
2229            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
2230            cfg.in_channels * cfg.embed_dim / conv_cfg.groups
2231                * kernel_sizes[0]
2232                * kernel_sizes[1]
2233                * kernel_sizes[2]
2234        };
2235
2236        let encoder_layer = {
2237            let cfg = &cfg.vision_config;
2238            let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2239            let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2240
2241            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2242            let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
2243            let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
2244            let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;
2245
2246            let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
2247            let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;
2248
2249            norm1 + norm2 + fc1 + fc2 + qkv + out
2250        };
2251
2252        let elems =
2253            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
2254
2255        Ok(elems * dtype.size_in_bytes())
2256    }
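    // Hedged numeric sketch of the patch merger above (assumed embed_dim = 1280,
    // spatial_merge_size = 2, text hidden_size = 3584; none of these are read from a real
    // config here): the merger folds a 2x2 patch block into one token, so its MLP input
    // width is 1280 * 4 = 5120, giving a 5120x5120 first linear layer and a 5120x3584
    // second linear layer, plus their biases and the ln_q layernorm.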
2257
2258    fn layer_sizes_in_bytes(
2259        &self,
2260        config: &str,
2261        dtype: DType,
2262        weight_pack_factor: usize,
2263        _matformer_config: Option<&MatformerSliceConfig>,
2264    ) -> Result<Vec<usize>> {
2265        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2266        let per_layer_elems = {
2267            let input_layernorm = cfg.hidden_size;
2268            let post_attention_layernorm = cfg.hidden_size;
2269
2270            let size_in = cfg.hidden_size;
2271            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2272            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2273            let q_proj = size_in * size_q / weight_pack_factor + size_q;
2274            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
2275            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
2276            let o_proj = size_q * size_in / weight_pack_factor;
2277
2278            let h_size = cfg.hidden_size;
2279            let i_size = cfg.intermediate_size;
2280            let gate_proj = h_size * i_size / weight_pack_factor;
2281            let up_proj = h_size * i_size / weight_pack_factor;
2282            let down_proj = i_size * h_size / weight_pack_factor;
2283
2284            input_layernorm
2285                + post_attention_layernorm
2286                + q_proj
2287                + k_proj
2288                + v_proj
2289                + o_proj
2290                + gate_proj
2291                + up_proj
2292                + down_proj
2293        };
2294        Ok(vec![
2295            per_layer_elems * dtype.size_in_bytes();
2296            cfg.num_hidden_layers
2297        ])
2298    }
2299
2300    fn num_layers(&self, config: &str) -> Result<usize> {
2301        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2302        Ok(cfg.num_hidden_layers)
2303    }
2304
2305    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2306        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2307
2308        let cfg = ModelConfigMetadata {
2309            max_seq_len: cfg.max_position_embeddings,
2310            num_layers: cfg.num_hidden_layers,
2311            hidden_size: cfg.hidden_size,
2312            num_kv_heads: cfg.num_key_value_heads,
2313            num_attn_heads: cfg.num_attention_heads,
2314            sliding_window: cfg.sliding_window,
2315            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2316            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2317        };
2318
2319        Ok(Box::new(cfg))
2320    }
2321
2322    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2323        Some(vec![NonMappedSubModel::Vision])
2324    }
2325}
2326
2327// ======================== Idefics 3 loader
2328
2329/// [`VisionLoader`] for an Idefics 3 Vision model.
2330///
2331/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2332pub struct Idefics3Loader;
2333
2334pub struct Idefics3Prefixer;
2335
2336impl MultimodalPromptPrefixer for Idefics3Prefixer {
2337    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
2338        // The chat template inserts the image markers, so no prefixing is needed here
2339        prompt.to_string()
2340    }
2341}
2342
2343impl VisionModelLoader for Idefics3Loader {
2344    fn load(
2345        &self,
2346        config: &str,
2347        vb: ShardedVarBuilder,
2348        normal_loading_metadata: NormalLoadingMetadata,
2349        attention_mechanism: AttentionImplementation,
2350    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2351        let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2352        Ok(Box::new(Idefics3Model::new(
2353            &cfg,
2354            vb,
2355            self.is_gptx(config),
2356            normal_loading_metadata,
2357            attention_mechanism,
2358        )?))
2359    }
2360    fn is_gptx(&self, _config: &str) -> bool {
2361        true
2362    }
2363    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2364        let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2365        Ok(Box::new(cfg))
2366    }
2367    fn get_processor(
2368        &self,
2369        _model_config: &str,
2370        processor_config: Option<ProcessorConfig>,
2371        preprocessor_config: PreProcessorConfig,
2372        max_edge: Option<u32>,
2373    ) -> Arc<dyn Processor + Send + Sync> {
2374        Arc::new(Idefics3Processor::new(
2375            processor_config.unwrap_or_default(),
2376            preprocessor_config,
2377            max_edge,
2378        ))
2379    }
2380    fn supports_paged_attention(&self, _config: &str) -> bool {
2381        true
2382    }
2383    fn supports_prefix_cacher(&self, _config: &str) -> bool {
2384        true
2385    }
2386    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2387        Arc::new(Idefics3Prefixer)
2388    }
2389    fn modalities(&self, _config: &str) -> Result<Modalities> {
2390        Ok(Modalities {
2391            input: vec![SupportedModality::Text, SupportedModality::Vision],
2392            output: vec![SupportedModality::Text],
2393        })
2394    }
2395}
2396
2397impl IsqModelLoader for Idefics3Loader {
2398    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2399        Ok(vec![
2400            Regex::new(r"lm_head\.(weight|bias)$")?,
2401            // Attention
2402            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2403            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2404            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2405            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2406            // MLP
2407            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2408            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2409            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2410        ])
2411    }
2412    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
2413        Ok(vec![
2414            Regex::new(r"lm_head\.(weight|bias)$")?,
2415            // Attention
2416            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2417            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2418            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2419            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2420            // MLP
2421            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2422            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2423            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2424            // // Attention (vision)
2425            // Regex::new(
2426            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2427            // )?,
2428            // Regex::new(
2429            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
2430            // )?,
2431            // Regex::new(
2432            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
2433            // )?,
2434            // Regex::new(
2435            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.(weight|bias)$",
2436            // )?,
2437            // // MLP (vision)
2438            // Regex::new(r"model\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2439            // Regex::new(r"model\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
2440        ])
2441    }
2442}
2443
2444impl DeviceMappedModelLoader for Idefics3Loader {
2445    fn mapped_max_act_size_elems(
2446        &self,
2447        config: &str,
2448        params: &AutoDeviceMapParams,
2449    ) -> Result<usize> {
2450        let AutoDeviceMapParams::Vision {
2451            max_seq_len,
2452            max_batch_size,
2453            max_image_shape: _,
2454            max_num_images,
2455        } = params
2456        else {
2457            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2458        };
2459
2460        let cfg: Idefics3Config = serde_json::from_str(config)?;
2461
2462        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2463        let img_seq_len = (num_patches + 1) * max_num_images;
2464
2465        let max_text_attn = {
2466            // This model injects the vision information directly into the input embeddings
2467            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2468            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2469        };
2470
2471        Ok(max_text_attn)
2472    }
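    // Illustrative sizing for the image token count above (hypothetical vision config,
    // not read from a checkpoint): with image_size = 378 and patch_size = 14,
    // num_patches = (378/14)^2 = 729, so each image adds num_patches + 1 = 730 tokens
    // before multiplying by max_num_images.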
2473
2474    fn non_mapped_max_act_size_elems(
2475        &self,
2476        config: &str,
2477        params: &AutoDeviceMapParams,
2478    ) -> Result<usize> {
2479        let AutoDeviceMapParams::Vision {
2480            max_seq_len: _,
2481            max_batch_size,
2482            max_image_shape: _,
2483            max_num_images,
2484        } = params
2485        else {
2486            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2487        };
2488
2489        let cfg: Idefics3Config = serde_json::from_str(config)?;
2490
2491        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2492        let img_seq_len = num_patches + 1;
2493
2494        let max_vision_attn = {
2495            // do_image_splitting = true: budget roughly 5 sub-images per input image
2496            let images_factor = 5;
2497
2498            (max_batch_size * images_factor * max_num_images)
2499                * cfg.vision_config.num_attention_heads
2500                * img_seq_len
2501                * img_seq_len
2502        };
2503
2504        Ok(max_vision_attn)
2505    }
2506
2507    fn non_mapped_size_in_bytes(
2508        &self,
2509        config: &str,
2510        dtype: DType,
2511        weight_pack_factor: usize,
2512        _matformer_config: Option<&MatformerSliceConfig>,
2513    ) -> Result<usize> {
2514        let cfg: Idefics3Config = serde_json::from_str(config)?;
2515        let text_elems = {
2516            let cfg = &cfg.text_config;
2517
2518            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2519            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2520            let norm = cfg.hidden_size;
2521            embed_tokens + lm_head + norm
2522        };
2523
2524        let connector_elems = {
2525            let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
2526            let out_dim = cfg.text_config.hidden_size;
2527
2528            in_dim * out_dim
2529        };
2530
2531        let vision_transformer = {
2532            let cfg = &cfg.vision_config;
2533
2534            let post_layernorm = cfg.hidden_size;
2535
2536            let conv_config = Conv2dConfig {
2537                stride: cfg.patch_size,
2538                ..Default::default()
2539            };
2540            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2541                * cfg.patch_size
2542                * cfg.patch_size;
2543
2544            let num_patches_per_side = cfg.image_size / cfg.patch_size;
2545            let num_patches = num_patches_per_side.pow(2);
2546            let position_embedding = num_patches * cfg.hidden_size;
2547
2548            let layer_elems = {
2549                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2550                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2551
2552                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2553                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2554
2555                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2556                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2557                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2558                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2559
2560                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2561            };
2562
2563            post_layernorm
2564                + patch_embedding
2565                + position_embedding
2566                + layer_elems * cfg.num_hidden_layers
2567        };
2568
2569        let elems = text_elems + connector_elems + vision_transformer;
2570
2571        Ok(elems * dtype.size_in_bytes())
2572    }
2573
2574    fn layer_sizes_in_bytes(
2575        &self,
2576        config: &str,
2577        dtype: DType,
2578        weight_pack_factor: usize,
2579        _matformer_config: Option<&MatformerSliceConfig>,
2580    ) -> Result<Vec<usize>> {
2581        let cfg: Idefics3Config = serde_json::from_str(config)?;
2582        let cfg = cfg.text_config;
2583        let per_layer_elems = {
2584            let input_layernorm = cfg.hidden_size;
2585            let post_attention_layernorm = cfg.hidden_size;
2586
2587            let size_in = cfg.hidden_size;
2588            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2589            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2590            let q_proj = size_in * size_q / weight_pack_factor;
2591            let k_proj = size_in * size_kv / weight_pack_factor;
2592            let v_proj = size_in * size_kv / weight_pack_factor;
2593            let o_proj = size_q * size_in / weight_pack_factor;
2594
2595            let h_size = cfg.hidden_size;
2596            let i_size = cfg.intermediate_size;
2597            let gate_proj = h_size * i_size / weight_pack_factor;
2598            let up_proj = h_size * i_size / weight_pack_factor;
2599            let down_proj = i_size * h_size / weight_pack_factor;
2600
2601            input_layernorm
2602                + post_attention_layernorm
2603                + q_proj
2604                + k_proj
2605                + v_proj
2606                + o_proj
2607                + gate_proj
2608                + up_proj
2609                + down_proj
2610        };
2611        Ok(vec![
2612            per_layer_elems * dtype.size_in_bytes();
2613            cfg.num_hidden_layers
2614        ])
2615    }
2616
2617    fn num_layers(&self, config: &str) -> Result<usize> {
2618        let cfg: Idefics3Config = serde_json::from_str(config)?;
2619        Ok(cfg.text_config.num_hidden_layers)
2620    }
2621    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2622        let cfg: Idefics3Config = serde_json::from_str(config)?;
2623        let cfg = &cfg.text_config;
2624
2625        let cfg = ModelConfigMetadata {
2626            max_seq_len: cfg.max_position_embeddings,
2627            num_layers: cfg.num_hidden_layers,
2628            hidden_size: cfg.hidden_size,
2629            num_kv_heads: cfg.num_key_value_heads,
2630            num_attn_heads: cfg.num_attention_heads,
2631            sliding_window: None,
2632            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2633            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2634        };
2635
2636        Ok(Box::new(cfg))
2637    }
2638
2639    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2640        Some(vec![NonMappedSubModel::Vision])
2641    }
2642}
2643
2644// ======================== MiniCpm-O loader
2645
2646/// [`VisionLoader`] for a MiniCpm-O model.
2647///
2648/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2649pub struct MiniCpmOLoader;
2650
2651pub struct MiniCpmOPrefixer;
2652
2653impl MultimodalPromptPrefixer for MiniCpmOPrefixer {
2654    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2655        format!(
2656            "{}{prompt}",
2657            "(<image>./</image>)".repeat(image_indexes.len())
2658        )
2659    }
2660}
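// Quick illustration of the prefix above (hypothetical prompt): three images yield
// "(<image>./</image>)" repeated three times in front of the prompt, e.g.
//
//     let p = MiniCpmOPrefixer.prefix_image(vec![0, 1, 2], "Describe the scenes.");
//     assert!(p.starts_with("(<image>./</image>)(<image>./</image>)(<image>./</image>)"));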
2661
2662impl VisionModelLoader for MiniCpmOLoader {
2663    fn load(
2664        &self,
2665        config: &str,
2666        vb: ShardedVarBuilder,
2667        normal_loading_metadata: NormalLoadingMetadata,
2668        attention_mechanism: AttentionImplementation,
2669    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2670        let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2671        Ok(Box::new(MiniCpmOModel::new(
2672            &cfg,
2673            vb,
2674            self.is_gptx(config),
2675            normal_loading_metadata,
2676            attention_mechanism,
2677        )?))
2678    }
2679    fn is_gptx(&self, _config: &str) -> bool {
2680        true
2681    }
2682    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2683        let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2684        Ok(Box::new(cfg))
2685    }
2686    fn get_processor(
2687        &self,
2688        _model_config: &str,
2689        processor_config: Option<ProcessorConfig>,
2690        preprocessor_config: PreProcessorConfig,
2691        max_edge: Option<u32>,
2692    ) -> Arc<dyn Processor + Send + Sync> {
2693        Arc::new(MiniCpmOProcessor::new(
2694            processor_config.unwrap_or_default(),
2695            preprocessor_config,
2696            max_edge,
2697        ))
2698    }
2699    fn supports_paged_attention(&self, _config: &str) -> bool {
2700        true
2701    }
2702    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2703        Arc::new(MiniCpmOPrefixer)
2704    }
2705    fn modalities(&self, _config: &str) -> Result<Modalities> {
2706        Ok(Modalities {
2707            input: vec![SupportedModality::Text, SupportedModality::Vision],
2708            output: vec![SupportedModality::Text],
2709        })
2710    }
2711}
2712
2713impl IsqModelLoader for MiniCpmOLoader {
2714    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2715        Ok(vec![
2716            Regex::new(r"llm.lm_head\.(weight|bias)$")?,
2717            // Attention
2718            Regex::new(r"llm.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2719            Regex::new(r"llm.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2720            Regex::new(r"llm.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2721            Regex::new(r"llm.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2722            // MLP
2723            Regex::new(r"llm.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2724            Regex::new(r"llm.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2725            Regex::new(r"llm.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2726        ])
2727    }
2728    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2729        self.isq_layer_regexes(config)
2730    }
2731}
2732
2733impl DeviceMappedModelLoader for MiniCpmOLoader {
2734    fn mapped_max_act_size_elems(
2735        &self,
2736        config: &str,
2737        params: &AutoDeviceMapParams,
2738    ) -> Result<usize> {
2739        let AutoDeviceMapParams::Vision {
2740            max_seq_len,
2741            max_batch_size,
2742            max_image_shape: _,
2743            max_num_images,
2744        } = params
2745        else {
2746            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2747        };
2748
2749        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2750
2751        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2752        let img_seq_len = (num_patches + 1) * max_num_images;
2753
2754        let max_text_attn = {
2755            // This model injects the vision information directly into the input embeddings
2756            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2757            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2758        };
2759
2760        Ok(max_text_attn)
2761    }
2762
2763    fn non_mapped_max_act_size_elems(
2764        &self,
2765        config: &str,
2766        params: &AutoDeviceMapParams,
2767    ) -> Result<usize> {
2768        let AutoDeviceMapParams::Vision {
2769            max_seq_len: _,
2770            max_batch_size,
2771            max_image_shape: _,
2772            max_num_images,
2773        } = params
2774        else {
2775            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2776        };
2777
2778        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2779
2780        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2781        let img_seq_len = num_patches + 1;
2782
2783        let max_vision_attn = {
2784            // do_image_splitting = true: budget roughly 5 sub-images per input image
2785            let images_factor = 5;
2786
2787            (max_batch_size * images_factor * max_num_images)
2788                * cfg.vision_config.num_attention_heads
2789                * img_seq_len
2790                * img_seq_len
2791        };
2792
2793        Ok(max_vision_attn)
2794    }
2795
2796    fn non_mapped_size_in_bytes(
2797        &self,
2798        config: &str,
2799        dtype: DType,
2800        weight_pack_factor: usize,
2801        _matformer_config: Option<&MatformerSliceConfig>,
2802    ) -> Result<usize> {
2803        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2804        let text_elems = {
2805            let cfg = &cfg.text_config;
2806
2807            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2808            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2809            let norm = cfg.hidden_size;
2810            embed_tokens + lm_head + norm
2811        };
2812
2813        let vision_transformer = {
2814            let cfg = &cfg.vision_config;
2815
2816            let post_layernorm = cfg.hidden_size;
2817
2818            let conv_config = Conv2dConfig {
2819                stride: cfg.patch_size,
2820                ..Default::default()
2821            };
2822            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2823                * cfg.patch_size
2824                * cfg.patch_size;
2825
2826            let num_patches_per_side = cfg.image_size / cfg.patch_size;
2827            let num_patches = num_patches_per_side.pow(2);
2828            let position_embedding = num_patches * cfg.hidden_size;
2829
2830            let layer_elems = {
2831                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2832                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2833
2834                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2835                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2836
2837                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2838                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2839                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2840                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2841
2842                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2843            };
2844
2845            post_layernorm
2846                + patch_embedding
2847                + position_embedding
2848                + layer_elems * cfg.num_hidden_layers
2849        };
2850
2851        let elems = text_elems + vision_transformer;
2852
2853        Ok(elems * dtype.size_in_bytes())
2854    }
2855
2856    fn layer_sizes_in_bytes(
2857        &self,
2858        config: &str,
2859        dtype: DType,
2860        weight_pack_factor: usize,
2861        _matformer_config: Option<&MatformerSliceConfig>,
2862    ) -> Result<Vec<usize>> {
2863        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2864        let cfg = cfg.text_config;
2865        let per_layer_elems = {
2866            let input_layernorm = cfg.hidden_size;
2867            let post_attention_layernorm = cfg.hidden_size;
2868
2869            let size_in = cfg.hidden_size;
2870            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2871            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2872            let q_proj = size_in * size_q / weight_pack_factor;
2873            let k_proj = size_in * size_kv / weight_pack_factor;
2874            let v_proj = size_in * size_kv / weight_pack_factor;
2875            let o_proj = size_q * size_in / weight_pack_factor;
2876
2877            let h_size = cfg.hidden_size;
2878            let i_size = cfg.intermediate_size;
2879            let gate_proj = h_size * i_size / weight_pack_factor;
2880            let up_proj = h_size * i_size / weight_pack_factor;
2881            let down_proj = i_size * h_size / weight_pack_factor;
2882
2883            input_layernorm
2884                + post_attention_layernorm
2885                + q_proj
2886                + k_proj
2887                + v_proj
2888                + o_proj
2889                + gate_proj
2890                + up_proj
2891                + down_proj
2892        };
2893        Ok(vec![
2894            per_layer_elems * dtype.size_in_bytes();
2895            cfg.num_hidden_layers
2896        ])
2897    }
2898
2899    fn num_layers(&self, config: &str) -> Result<usize> {
2900        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2901        Ok(cfg.text_config.num_hidden_layers)
2902    }
2903    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2904        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2905        let cfg = &cfg.text_config;
2906
2907        let cfg = ModelConfigMetadata {
2908            max_seq_len: cfg.max_position_embeddings,
2909            num_layers: cfg.num_hidden_layers,
2910            hidden_size: cfg.hidden_size,
2911            num_kv_heads: cfg.num_key_value_heads,
2912            num_attn_heads: cfg.num_attention_heads,
2913            sliding_window: None,
2914            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2915            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2916        };
2917
2918        Ok(Box::new(cfg))
2919    }
2920}
2921
2922// ======================== Phi 4MM Loader
2923
2924/// [`VisionLoader`] for a Phi 4MM Vision model.
2925///
2926/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2927pub struct Phi4MMLoader;
2928
2929pub struct Phi4MMPrefixer;
2930
2931impl MultimodalPromptPrefixer for Phi4MMPrefixer {
2932    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2933        // Image indexing starts at 0.
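        // For example, image indexes [0, 1] prepend "<|image_1|><|image_2|>" to the prompt.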
2934
2935        format!(
2936            "{}{prompt}",
2937            image_indexes
2938                .into_iter()
2939                .map(|image_index| format!("<|image_{}|>", image_index + 1))
2940                .join("")
2941        )
2942    }
2943    fn prefix_audio(&self, audio_indexes: Vec<usize>, prompt: &str) -> String {
2944        // Audio indexing starts at 0.
2945
2946        format!(
2947            "{}{prompt}",
2948            audio_indexes
2949                .into_iter()
2950                .map(|audio_index| format!("<|audio_{}|>", audio_index + 1))
2951                .join("")
2952        )
2953    }
2954}
2955
2956impl VisionModelLoader for Phi4MMLoader {
2957    fn load(
2958        &self,
2959        config: &str,
2960        vb: ShardedVarBuilder,
2961        normal_loading_metadata: NormalLoadingMetadata,
2962        attention_mechanism: AttentionImplementation,
2963    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2964        let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
2965        Ok(Box::new(Phi4MMModel::new(
2966            &cfg,
2967            vb,
2968            self.is_gptx(config),
2969            normal_loading_metadata,
2970            attention_mechanism,
2971        )?))
2972    }
2973    fn is_gptx(&self, _config: &str) -> bool {
2974        true
2975    }
2976    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2977        let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
2978        Ok(Box::new(cfg))
2979    }
2980    fn get_processor(
2981        &self,
2982        _model_config: &str,
2983        processor_config: Option<ProcessorConfig>,
2984        preprocessor_config: PreProcessorConfig,
2985        _max_edge: Option<u32>,
2986    ) -> Arc<dyn Processor + Send + Sync> {
2987        Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
2988    }
2989    fn supports_paged_attention(&self, _config: &str) -> bool {
2990        true
2991    }
2992    fn supports_prefix_cacher(&self, _config: &str) -> bool {
2993        true
2994    }
2995    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2996        Arc::new(Phi4MMPrefixer)
2997    }
2998    fn modalities(&self, _config: &str) -> Result<Modalities> {
2999        Ok(Modalities {
3000            input: vec![
3001                SupportedModality::Text,
3002                SupportedModality::Vision,
3003                SupportedModality::Audio,
3004            ],
3005            output: vec![SupportedModality::Text],
3006        })
3007    }
3008}
3009
3010impl IsqModelLoader for Phi4MMLoader {
3011    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3012        Ok(vec![
3013            Regex::new(r"lm_head\.(weight|bias)$")?,
3014            // Attention
3015            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
3016            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3017            // MLP
3018            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
3019            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3020        ])
3021    }
3022    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3023        self.isq_layer_regexes(config)
3024    }
3025}
3026
3027impl DeviceMappedModelLoader for Phi4MMLoader {
3028    fn mapped_max_act_size_elems(
3029        &self,
3030        config: &str,
3031        params: &AutoDeviceMapParams,
3032    ) -> Result<usize> {
3033        // NOTE: max_num_images is factored in below, although it can only be one...
3034        let AutoDeviceMapParams::Vision {
3035            max_seq_len,
3036            max_batch_size,
3037            max_image_shape: _,
3038            max_num_images,
3039        } = params
3040        else {
3041            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3042        };
3043
3044        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3045
3046        let vcfg = &PHI4_MM_VISION_CFG;
3047
3048        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3049        let img_seq_len = (num_patches + 1) * max_num_images;
3050
3051        let max_text_attn = {
3052            // This model injects the vision information directly into the input embeddings
3053            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3054            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3055        };
3056
3057        Ok(max_text_attn)
3058    }
3059
3060    fn non_mapped_max_act_size_elems(
3061        &self,
3062        _config: &str,
3063        params: &AutoDeviceMapParams,
3064    ) -> Result<usize> {
3065        let AutoDeviceMapParams::Vision {
3066            max_seq_len: _,
3067            max_batch_size,
3068            max_image_shape,
3069            max_num_images,
3070        } = params
3071        else {
3072            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3073        };
3074
3075        let vcfg = &PHI4_MM_VISION_CFG;
3076
3077        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3078        let img_seq_len = num_patches + 1;
3079
3080        let max_batch_size = max_batch_size
3081            * (max_image_shape
3082                .0
3083                .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3084                * max_image_shape
3085                    .1
3086                    .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3087                + 1);
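        // Dynamic-HD preprocessing, as assumed by this estimate: the image is tiled into
        // DYHD_BASE_RESOLUTION-sized crops along each axis, and the trailing "+ 1"
        // accounts for the extra global thumbnail crop per image.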
3088
3089        let max_vision_attn = (max_batch_size * max_num_images)
3090            * vcfg.num_attention_heads
3091            * img_seq_len
3092            * img_seq_len;
3093        let max_qkv = 3
3094            * (max_batch_size
3095                * vcfg.num_attention_heads
3096                * img_seq_len
3097                * (vcfg.hidden_size / vcfg.num_attention_heads));
3098
3099        Ok(max_vision_attn + max_qkv)
3100    }
3101
3102    fn non_mapped_size_in_bytes(
3103        &self,
3104        config: &str,
3105        dtype: DType,
3106        weight_pack_factor: usize,
3107        _matformer_config: Option<&MatformerSliceConfig>,
3108    ) -> Result<usize> {
3109        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3110        let elems = {
3111            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3112            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3113            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3114                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3115            } else {
3116                0
3117            };
3118            let norm = cfg.hidden_size;
3119
3120            let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
3121                let projection_cls = img_embed
3122                    .projection_cls
3123                    .clone()
3124                    .unwrap_or("linear".to_string());
3125                let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
3126                let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
3127                let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;
3128
3129                let proj = match (projection_cls.as_str(), use_hd_transform) {
3130                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
3131                    ("mlp", true) => {
3132                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
3133                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3134                        a + b
3135                    }
3136                    ("mlp", false) => {
3137                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
3138                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3139                        a + b
3140                    }
3141                    _ => {
3142                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
3143                    }
3144                };
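                // With use_hd_transform the projector input width is image_dim_out * 4,
                // presumably because the HD transform concatenates 2x2 neighbouring patch
                // features before projecting (an assumption; the exact mechanism lives in
                // the Phi-4MM image embedding implementation).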
3145
3146                let (glb_gn, sub_gn) = if with_learnable_separator {
3147                    let glb_gn = image_dim_out * 4;
3148                    let sub_gn = image_dim_out * 4;
3149                    (glb_gn, sub_gn)
3150                } else {
3151                    (0, 0)
3152                };
3153
3154                let vision_transformer = {
3155                    let cfg = &PHI4_MM_VISION_CFG;
3156
3157                    let post_layernorm = cfg.hidden_size;
3158
3159                    let conv_config = Conv2dConfig {
3160                        stride: cfg.patch_size,
3161                        ..Default::default()
3162                    };
3163                    let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3164                        * cfg.patch_size
3165                        * cfg.patch_size;
3166
3167                    let num_patches_per_side = cfg.image_size / cfg.patch_size;
3168                    let num_patches = num_patches_per_side.pow(2);
3169                    let position_embedding = num_patches * cfg.hidden_size;
3170
3171                    let layer_elems = {
3172                        let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3173                        let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3174
3175                        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3176                        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3177
3178                        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3179                        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3180                        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3181                        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3182
3183                        layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3184                    };
3185
3186                    post_layernorm
3187                        + patch_embedding
3188                        + position_embedding
3189                        + layer_elems * cfg.num_hidden_layers
3190                };
3191
3192                proj + glb_gn + sub_gn + vision_transformer
3193            } else {
3194                0
3195            };
3196
3197            embed_tokens + lm_head + norm + image_embed
3198        };
3199
3200        Ok(elems * dtype.size_in_bytes())
3201    }
3202
3203    fn layer_sizes_in_bytes(
3204        &self,
3205        config: &str,
3206        dtype: DType,
3207        weight_pack_factor: usize,
3208        _matformer_config: Option<&MatformerSliceConfig>,
3209    ) -> Result<Vec<usize>> {
3210        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3211        let per_layer_elems = {
3212            let input_layernorm = cfg.hidden_size;
3213            let post_attention_layernorm = cfg.hidden_size;
3214
3215            let size_in = cfg.hidden_size;
3216            let head_dim = cfg.head_dim();
3217            let op_size =
3218                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
3219            let qkv_proj = size_in * op_size / weight_pack_factor;
3220            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;
3221
3222            let h_size = cfg.hidden_size;
3223            let i_size = cfg.intermediate_size;
3224            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
3225            let down_proj = h_size * i_size / weight_pack_factor;
3226
3227            input_layernorm
3228                + post_attention_layernorm
3229                + qkv_proj
3230                + o_proj
3231                + gate_up_proj
3232                + down_proj
3233        };
3234        Ok(vec![
3235            per_layer_elems * dtype.size_in_bytes();
3236            cfg.num_hidden_layers
3237        ])
3238    }
3239
3240    fn num_layers(&self, config: &str) -> Result<usize> {
3241        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3242        Ok(cfg.num_hidden_layers)
3243    }
3244
3245    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3246        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3247
3248        let cfg = ModelConfigMetadata {
3249            max_seq_len: cfg.max_position_embeddings,
3250            num_layers: cfg.num_hidden_layers,
3251            hidden_size: cfg.hidden_size,
3252            num_kv_heads: cfg.num_key_value_heads(),
3253            num_attn_heads: cfg.num_attention_heads,
3254            sliding_window: cfg.sliding_window,
3255            k_head_dim: cfg.head_dim(),
3256            v_head_dim: cfg.head_dim(),
3257        };
3258
3259        Ok(Box::new(cfg))
3260    }
3261
3262    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3263        Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
3264    }
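    // Listing the vision and audio towers here tells the automatic device mapper not to
    // split these sub-models layer-by-layer; they are expected to stay on the primary
    // (non-mapped) device alongside the embeddings.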
3265}
3266
3267// ======================== Qwen2_5VL Loader
3268
3269/// [`VisionLoader`] for a Qwen2_5VL model.
3270///
3271/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
3272pub struct Qwen2_5VLLoader;
3273
3274pub struct Qwen2_5VLPrefixer;
3275
3276impl MultimodalPromptPrefixer for Qwen2_5VLPrefixer {
3277    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
3278        format!(
3279            "{}{prompt}",
3280            format!(
3281                "{}{}{}",
3282                Qwen2_5VLProcessor::VISION_START,
3283                Qwen2_5VLProcessor::IMAGE_PAD,
3284                Qwen2_5VLProcessor::VISION_END
3285            )
3286            .repeat(image_indexes.len())
3287        )
3288    }
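    // A hedged illustration: assuming the processor constants expand to the usual
    // Qwen2.5-VL special tokens, two images prepend
    // "<|vision_start|><|image_pad|><|vision_end|>" twice in front of the prompt.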
3289}
3290
3291impl VisionModelLoader for Qwen2_5VLLoader {
3292    fn load(
3293        &self,
3294        config: &str,
3295        vb: ShardedVarBuilder,
3296        normal_loading_metadata: NormalLoadingMetadata,
3297        attention_mechanism: AttentionImplementation,
3298    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3299        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3300        Ok(Box::new(Qwen2_5VLModel::new(
3301            &cfg,
3302            vb,
3303            self.is_gptx(config),
3304            normal_loading_metadata,
3305            attention_mechanism,
3306        )?))
3307    }
3308    fn is_gptx(&self, _config: &str) -> bool {
3309        true
3310    }
3311    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3312        let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
3313        Ok(Box::new(config))
3314    }
3315    fn get_processor(
3316        &self,
3317        _model_config: &str,
3318        _processor_config: Option<ProcessorConfig>,
3319        _preprocessor_config: PreProcessorConfig,
3320        max_edge: Option<u32>,
3321    ) -> Arc<dyn Processor + Send + Sync> {
3322        Arc::new(Qwen2_5VLProcessor::new(max_edge))
3323    }
3324    fn supports_paged_attention(&self, _config: &str) -> bool {
3325        false
3326    }
3327    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3328        Arc::new(Qwen2_5VLPrefixer)
3329    }
3330    fn modalities(&self, _config: &str) -> Result<Modalities> {
3331        Ok(Modalities {
3332            input: vec![SupportedModality::Text, SupportedModality::Vision],
3333            output: vec![SupportedModality::Text],
3334        })
3335    }
3336}
3337
3338impl IsqModelLoader for Qwen2_5VLLoader {
3339    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3340        Ok(vec![
3341            Regex::new(r"lm_head\.(weight|bias)$")?,
3342            // Attention
3343            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3344            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3345            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3346            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3347            // MLP
3348            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3349            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3350            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3351        ])
3352    }
3353    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3354        self.isq_layer_regexes(config)
3355    }
3356}
3357
3358impl DeviceMappedModelLoader for Qwen2_5VLLoader {
3359    fn mapped_max_act_size_elems(
3360        &self,
3361        config: &str,
3362        params: &AutoDeviceMapParams,
3363    ) -> Result<usize> {
3364        let AutoDeviceMapParams::Vision {
3365            max_seq_len,
3366            max_batch_size,
3367            max_image_shape,
3368            max_num_images,
3369        } = params
3370        else {
3371            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3372        };
3373
3374        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3375
3376        let img_seq_len = {
3377            let cfg = &cfg.vision_config;
3378            let grid_t = max_num_images / cfg.temporal_patch_size;
3379            let grid_h = max_image_shape.0 / cfg.patch_size;
3380            let grid_w = max_image_shape.1 / cfg.patch_size;
3381            grid_t * grid_h * grid_w
3382        };
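        // Worked example under assumed values (patch_size = 14, temporal_patch_size = 2,
        // typical for Qwen2.5-VL): a 1120x1120 max image shape with max_num_images = 2
        // gives grid_t = 1 and grid_h = grid_w = 80, so img_seq_len = 6400 before the
        // multiplication by max_num_images below.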
3383        let img_seq_len = img_seq_len * max_num_images;
3384
3385        let max_text_attn = {
3386            // This model injects the vision information directly into the input embeddings
3387            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3388            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3389        };
3390
3391        Ok(max_text_attn)
3392    }
3393
3394    fn non_mapped_max_act_size_elems(
3395        &self,
3396        config: &str,
3397        params: &AutoDeviceMapParams,
3398    ) -> Result<usize> {
3399        let AutoDeviceMapParams::Vision {
3400            max_seq_len: _,
3401            max_batch_size,
3402            max_image_shape,
3403            max_num_images,
3404        } = params
3405        else {
3406            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3407        };
3408
3409        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3410
3411        let img_seq_len = {
3412            let cfg = &cfg.vision_config;
3413            let grid_t = max_num_images / cfg.temporal_patch_size;
3414            let grid_h = max_image_shape.0 / cfg.patch_size;
3415            let grid_w = max_image_shape.1 / cfg.patch_size;
3416            grid_t * grid_h * grid_w
3417        };
3418
3419        let max_vision_attn = {
3420            let cfg = &cfg.vision_config;
3421            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
3422        };
3423
3424        Ok(max_vision_attn)
3425    }
3426
3427    fn non_mapped_size_in_bytes(
3428        &self,
3429        config: &str,
3430        dtype: DType,
3431        weight_pack_factor: usize,
3432        _matformer_config: Option<&MatformerSliceConfig>,
3433    ) -> Result<usize> {
3434        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3435        let text_elems = {
3436            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3437            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3438            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3439                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3440            } else {
3441                0
3442            };
3443            let norm = cfg.hidden_size;
3444            embed_tokens + lm_head + norm
3445        };
3446
3447        let patch_merger = {
3448            let cfg = &cfg.vision_config;
3449            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
3450
3451            let mlp0 = hidden_size * hidden_size + hidden_size;
3452            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
3453
3454            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3455
3456            mlp0 + mlp2 + ln_q
3457        };
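        // The patch merger concatenates spatial_merge_size^2 neighbouring patch features
        // (hence hidden_size * spatial_merge_size.pow(2)) and pushes them through a
        // two-layer MLP (mlp0, mlp2); ln_q is the layer norm applied before merging.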
3458
3459        let patch_embed = {
3460            let cfg = &cfg.vision_config;
3461            let conv_cfg = Conv3dConfig {
3462                stride: cfg.patch_size,
3463                ..Default::default()
3464            };
3465            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
3466            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
3467                * kernel_sizes[0]
3468                * kernel_sizes[1]
3469                * kernel_sizes[2]
3470        };
3471
3472        let encoder_layer = {
3473            let cfg = &cfg.vision_config;
3474            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3475            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3476
3477            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
3478            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3479            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
3480
3481            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
3482            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3483
3484            norm1 + norm2 + fc1 + fc2 + qkv + out
3485        };
3486
3487        let elems =
3488            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
3489
3490        Ok(elems * dtype.size_in_bytes())
3491    }
3492
3493    fn layer_sizes_in_bytes(
3494        &self,
3495        config: &str,
3496        dtype: DType,
3497        weight_pack_factor: usize,
3498        _matformer_config: Option<&MatformerSliceConfig>,
3499    ) -> Result<Vec<usize>> {
3500        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3501        let per_layer_elems = {
3502            let input_layernorm = cfg.hidden_size;
3503            let post_attention_layernorm = cfg.hidden_size;
3504
3505            let size_in = cfg.hidden_size;
3506            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3507            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3508            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3509            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3510            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3511            let o_proj = size_q * size_in / weight_pack_factor;
3512
3513            let h_size = cfg.hidden_size;
3514            let i_size = cfg.intermediate_size;
3515            let gate_proj = h_size * i_size / weight_pack_factor;
3516            let up_proj = h_size * i_size / weight_pack_factor;
3517            let down_proj = i_size * h_size / weight_pack_factor;
3518
3519            input_layernorm
3520                + post_attention_layernorm
3521                + q_proj
3522                + k_proj
3523                + v_proj
3524                + o_proj
3525                + gate_proj
3526                + up_proj
3527                + down_proj
3528        };
3529        Ok(vec![
3530            per_layer_elems * dtype.size_in_bytes();
3531            cfg.num_hidden_layers
3532        ])
3533    }
3534
3535    fn num_layers(&self, config: &str) -> Result<usize> {
3536        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3537        Ok(cfg.num_hidden_layers)
3538    }
3539
3540    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3541        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3542
3543        let cfg = ModelConfigMetadata {
3544            max_seq_len: cfg.max_position_embeddings,
3545            num_layers: cfg.num_hidden_layers,
3546            hidden_size: cfg.hidden_size,
3547            num_kv_heads: cfg.num_key_value_heads,
3548            num_attn_heads: cfg.num_attention_heads,
3549            sliding_window: cfg.sliding_window,
3550            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3551            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3552        };
3553
3554        Ok(Box::new(cfg))
3555    }
3556
3557    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3558        Some(vec![NonMappedSubModel::Vision])
3559    }
3560}
3561
3562// ======================== Gemma 3 Loader
3563
3564/// [`VisionLoader`] for a Gemma 3 model.
3565///
3566/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
3567pub struct Gemma3Loader;
3568
3569pub struct Gemma3Prefixer;
3570
3571impl MultimodalPromptPrefixer for Gemma3Prefixer {
3572    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
3573        prompt.to_string()
3574    }
3575}
3576
3577impl VisionModelLoader for Gemma3Loader {
3578    fn load(
3579        &self,
3580        config: &str,
3581        vb: ShardedVarBuilder,
3582        normal_loading_metadata: NormalLoadingMetadata,
3583        attention_mechanism: AttentionImplementation,
3584    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3585        let cfg: Gemma3Config = serde_json::from_str(config)?;
3586        Ok(Box::new(Gemma3Model::new(
3587            &cfg,
3588            vb,
3589            self.is_gptx(config),
3590            normal_loading_metadata,
3591            attention_mechanism,
3592        )?))
3593    }
3594    fn is_gptx(&self, _config: &str) -> bool {
3595        true
3596    }
3597    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3598        let config: Gemma3Config = serde_json::from_str(config)?;
3599        Ok(Box::new(config))
3600    }
3601    fn get_processor(
3602        &self,
3603        config: &str,
3604        processor_config: Option<ProcessorConfig>,
3605        _preprocessor_config: PreProcessorConfig,
3606        _max_edge: Option<u32>,
3607    ) -> Arc<dyn Processor + Send + Sync> {
3608        let config: Gemma3Config = serde_json::from_str(config).unwrap();
3609        // Handle the text-only Gemma 3 (e.g. 1B) case: only enable image handling when the config has a vision tower.
3610        Arc::new(Gemma3Processor::new(
3611            processor_config.unwrap_or_default(),
3612            matches!(config, Gemma3Config::WithVision { .. }),
3613        ))
3614    }
3615    fn supports_paged_attention(&self, _config: &str) -> bool {
3616        true
3617    }
3618    fn supports_prefix_cacher(&self, _config: &str) -> bool {
3619        true
3620    }
3621    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3622        Arc::new(Gemma3Prefixer)
3623    }
3624    fn modalities(&self, _config: &str) -> Result<Modalities> {
3625        Ok(Modalities {
3626            input: vec![SupportedModality::Text, SupportedModality::Vision],
3627            output: vec![SupportedModality::Text],
3628        })
3629    }
3630}
3631
3632impl IsqModelLoader for Gemma3Loader {
3633    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3634        Ok(vec![
3635            Regex::new(r"lm_head\.(weight|bias)$")?,
3636            // Attention
3637            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3638            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3639            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3640            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3641            // MLP
3642            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3643            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3644            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3645        ])
3646    }
3647    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
3648        Ok(vec![
3649            Regex::new(r"lm_head\.(weight|bias)$")?,
3650            // Attention
3651            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3652            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3653            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3654            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3655            // MLP
3656            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3657            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3658            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3659        ])
3660    }
3661}
3662
3663impl DeviceMappedModelLoader for Gemma3Loader {
3664    fn mapped_max_act_size_elems(
3665        &self,
3666        config: &str,
3667        params: &AutoDeviceMapParams,
3668    ) -> Result<usize> {
3669        let AutoDeviceMapParams::Vision {
3670            max_seq_len,
3671            max_batch_size,
3672            max_image_shape: _,
3673            max_num_images,
3674        } = params
3675        else {
3676            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3677        };
3678
3679        let cfg: Gemma3Config = serde_json::from_str(config)?;
3680
3681        match cfg {
3682            Gemma3Config::Text(text_config) => Ok(max_batch_size
3683                * text_config.num_attention_heads
3684                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)),
3685            Gemma3Config::WithVision {
3686                text_config,
3687                vision_config,
3688                ..
3689            } => {
3690                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3691                let img_seq_len = (num_patches + 1) * max_num_images;
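                // For illustration (values assumed, not taken from this config): a
                // SigLIP tower with image_size = 896 and patch_size = 14 gives
                // (896 / 14)^2 = 4096 patches, i.e. 4097 positions per image before
                // any pooling.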
3692
3693                let max_text_attn = {
3694                    // This model injects the vision information directly into the input embeddings
3695                    let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3696                    max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
3697                };
3698                Ok(max_text_attn)
3699            }
3700        }
3701    }
3702
3703    fn non_mapped_max_act_size_elems(
3704        &self,
3705        config: &str,
3706        params: &AutoDeviceMapParams,
3707    ) -> Result<usize> {
3708        let AutoDeviceMapParams::Vision {
3709            max_seq_len: _,
3710            max_batch_size,
3711            max_image_shape: _,
3712            max_num_images,
3713        } = params
3714        else {
3715            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3716        };
3717
3718        let cfg: Gemma3Config = serde_json::from_str(config)?;
3719
3720        match cfg {
3721            Gemma3Config::WithVision { vision_config, .. } => {
3722                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3723                let img_seq_len = num_patches + 1;
3724
3725                let max_vision_attn = {
3726                    (max_batch_size * max_num_images)
3727                        * vision_config.num_attention_heads
3728                        * img_seq_len
3729                        * img_seq_len
3730                };
3731
3732                Ok(max_vision_attn)
3733            }
3734            Gemma3Config::Text(_) => Ok(0),
3735        }
3736    }
3737
3738    fn non_mapped_size_in_bytes(
3739        &self,
3740        config: &str,
3741        dtype: DType,
3742        weight_pack_factor: usize,
3743        _matformer_config: Option<&MatformerSliceConfig>,
3744    ) -> Result<usize> {
3745        let cfg: Gemma3Config = serde_json::from_str(config)?;
3746
3747        let text_elems = {
3748            let cfg = match &cfg {
3749                Gemma3Config::Text(cfg) => cfg,
3750                Gemma3Config::WithVision { text_config, .. } => text_config,
3751            };
3752            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3753            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3754            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3755                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3756            } else {
3757                0
3758            };
3759            let norm = cfg.hidden_size;
3760            embed_tokens + lm_head + norm
3761        };
3762
3763        let vision_transformer = if let Gemma3Config::WithVision {
3764            vision_config: cfg, ..
3765        } = &cfg
3766        {
3767            let post_layernorm = cfg.hidden_size;
3768
3769            let conv_config = Conv2dConfig {
3770                stride: cfg.patch_size,
3771                ..Default::default()
3772            };
3773            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3774                * cfg.patch_size
3775                * cfg.patch_size;
3776
3777            let num_patches_per_side = cfg.image_size / cfg.patch_size;
3778            let num_patches = num_patches_per_side.pow(2);
3779            let position_embedding = num_patches * cfg.hidden_size;
3780
3781            let layer_elems = {
3782                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3783                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3784
3785                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3786                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3787
3788                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3789                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3790                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3791                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3792
3793                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3794            };
3795
3796            post_layernorm
3797                + patch_embedding
3798                + position_embedding
3799                + layer_elems * cfg.num_hidden_layers
3800        } else {
3801            0
3802        };
3803
3804        let elems = text_elems + vision_transformer;
3805
3806        Ok(elems * dtype.size_in_bytes())
3807    }
3808
3809    fn layer_sizes_in_bytes(
3810        &self,
3811        config: &str,
3812        dtype: DType,
3813        weight_pack_factor: usize,
3814        _matformer_config: Option<&MatformerSliceConfig>,
3815    ) -> Result<Vec<usize>> {
3816        let cfg: Gemma3Config = serde_json::from_str(config)?;
3817
3818        let txt_cfg = match &cfg {
3819            Gemma3Config::Text(cfg) => cfg,
3820            Gemma3Config::WithVision { text_config, .. } => text_config,
3821        };
3822        let per_layer_elems = {
3823            let cfg = txt_cfg;
3824
3825            let input_layernorm = cfg.hidden_size;
3826            let post_attention_layernorm = cfg.hidden_size;
3827
3828            let size_in = cfg.hidden_size;
3829            let size_q = cfg.head_dim * cfg.num_attention_heads;
3830            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
3831            let q_proj =
3832                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
3833            let k_proj =
3834                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3835            let v_proj =
3836                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3837            let o_proj =
3838                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
3839
3840            let h_size = cfg.hidden_size;
3841            let i_size = cfg.intermediate_size;
3842            let gate_proj = h_size * i_size / weight_pack_factor;
3843            let up_proj = h_size * i_size / weight_pack_factor;
3844            let down_proj = i_size * h_size / weight_pack_factor;
3845
3846            input_layernorm
3847                + post_attention_layernorm
3848                + q_proj
3849                + k_proj
3850                + v_proj
3851                + o_proj
3852                + gate_proj
3853                + up_proj
3854                + down_proj
3855        };
3856        Ok(vec![
3857            per_layer_elems * dtype.size_in_bytes();
3858            txt_cfg.num_hidden_layers
3859        ])
3860    }
3861
3862    fn num_layers(&self, config: &str) -> Result<usize> {
3863        let cfg: Gemma3Config = serde_json::from_str(config)?;
3864
3865        let txt_cfg = match &cfg {
3866            Gemma3Config::Text(cfg) => cfg,
3867            Gemma3Config::WithVision { text_config, .. } => text_config,
3868        };
3869
3870        Ok(txt_cfg.num_hidden_layers)
3871    }
3872
3873    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3874        let cfg: Gemma3Config = serde_json::from_str(config)?;
3875
3876        let cfg = match &cfg {
3877            Gemma3Config::Text(cfg) => cfg,
3878            Gemma3Config::WithVision { text_config, .. } => text_config,
3879        };
3880
3881        let cfg = ModelConfigMetadata {
3882            max_seq_len: cfg.max_position_embeddings,
3883            num_layers: cfg.num_hidden_layers,
3884            hidden_size: cfg.hidden_size,
3885            num_kv_heads: cfg.num_key_value_heads,
3886            num_attn_heads: cfg.num_attention_heads,
3887            sliding_window: None, // None to be more forgiving; some configs do not set one
3888            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3889            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3890        };
3891
3892        Ok(Box::new(cfg))
3893    }
3894
3895    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3896        Some(vec![NonMappedSubModel::Vision])
3897    }
3898}
3899
3900// ======================== Mistral 3 Loader
3901
3902/// [`VisionLoader`] for a Mistral 3 model.
3903///
3904/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
3905pub struct Mistral3Loader;
3906
3907pub struct Mistral3Prefixer;
3908
3909impl MultimodalPromptPrefixer for Mistral3Prefixer {
3910    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
3911        prompt.to_string()
3912    }
3913}
3914
3915impl VisionModelLoader for Mistral3Loader {
3916    fn load(
3917        &self,
3918        config: &str,
3919        vb: ShardedVarBuilder,
3920        normal_loading_metadata: NormalLoadingMetadata,
3921        attention_mechanism: AttentionImplementation,
3922    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3923        let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
3924        Ok(Box::new(Mistral3Model::new(
3925            &cfg,
3926            vb,
3927            self.is_gptx(config),
3928            normal_loading_metadata,
3929            attention_mechanism,
3930        )?))
3931    }
3932    fn is_gptx(&self, _config: &str) -> bool {
3933        true
3934    }
3935    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3936        let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
3937        Ok(Box::new(cfg))
3938    }
3939    fn get_processor(
3940        &self,
3941        _model_config: &str,
3942        processor_config: Option<ProcessorConfig>,
3943        _preprocessor_config: PreProcessorConfig,
3944        _max_edge: Option<u32>,
3945    ) -> Arc<dyn Processor + Send + Sync> {
3946        Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
3947    }
3948    fn supports_paged_attention(&self, _config: &str) -> bool {
3949        true
3950    }
3951    fn supports_prefix_cacher(&self, _config: &str) -> bool {
3952        true
3953    }
3954    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3955        Arc::new(Mistral3Prefixer)
3956    }
3957    fn modalities(&self, _config: &str) -> Result<Modalities> {
3958        Ok(Modalities {
3959            input: vec![SupportedModality::Text, SupportedModality::Vision],
3960            output: vec![SupportedModality::Text],
3961        })
3962    }
3963}
3964
3965impl IsqModelLoader for Mistral3Loader {
3966    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3967        Ok(vec![
3968            Regex::new(r"lm_head\.(weight|bias)$")?,
3969            // Attention
3970            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3971            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3972            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3973            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3974            // MLP
3975            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3976            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3977            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3978        ])
3979    }
3980    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
3981        Ok(vec![
3982            Regex::new(r"lm_head\.(weight|bias)$")?,
3983            // Attention
3984            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3985            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3986            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3987            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3988            // MLP
3989            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3990            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3991            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3992        ])
3993    }
3994}
3995
3996#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
3997impl DeviceMappedModelLoader for Mistral3Loader {
3998    fn mapped_max_act_size_elems(
3999        &self,
4000        config: &str,
4001        params: &AutoDeviceMapParams,
4002    ) -> Result<usize> {
4003        let cfg: Mistral3Config = serde_json::from_str(config)?;
4004        let vcfg = &cfg.vision_config;
4005        let tcfg = &cfg.text_config;
4006
4007        let AutoDeviceMapParams::Vision {
4008            max_seq_len,
4009            max_batch_size,
4010            max_image_shape: (mut height, mut width),
4011            max_num_images,
4012        } = params
4013        else {
4014            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4015        };
4016
4017        let img_seq_len = {
4018            // Reshaping algorithm
4019
4020            // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/preprocessor_config.json#L29
4021            let (max_height, max_width) = (1540, 1540);
4022            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4023            if ratio > 1. {
4024                height = (height as f64 / ratio).floor() as usize;
4025                width = (width as f64 / ratio).floor() as usize;
4026            }
4027
4028            let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
4029            let num_width_tokens = (width - 1) / vcfg.patch_size + 1;
4030
4031            height = num_height_tokens * vcfg.patch_size;
4032            width = num_width_tokens * vcfg.patch_size;
4033
4034            let num_height_tokens = height / vcfg.patch_size;
4035            let num_width_tokens = width / vcfg.patch_size;
4036
4037            (num_width_tokens + 1) * num_height_tokens
4038        };
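        // Two notes on this estimate: the 1540x1540 cap mirrors the preprocessor config
        // linked above, and the "+ 1" per row is assumed to account for the image-break
        // token that the Pixtral-style processor appends at the end of each row of patches.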
4039
4040        // This model injects the vision information directly into the input embeddings
4041        let max_seq_len = img_seq_len * max_num_images + *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
4042        Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
4043    }
4044
4045    fn non_mapped_max_act_size_elems(
4046        &self,
4047        config: &str,
4048        params: &AutoDeviceMapParams,
4049    ) -> Result<usize> {
4050        let cfg: Mistral3Config = serde_json::from_str(config)?;
4051        let cfg = &cfg.vision_config;
4052
4053        let AutoDeviceMapParams::Vision {
4054            max_seq_len: _,
4055            max_batch_size,
4056            max_image_shape: (mut height, mut width),
4057            max_num_images,
4058        } = params
4059        else {
4060            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4061        };
4062
4063        let img_seq_len = {
4064            // Reshaping algorithm
4065
4066            // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/preprocessor_config.json#L29
4067            let (max_height, max_width) = (1540, 1540);
4068            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4069            if ratio > 1. {
4070                height = (height as f64 / ratio).floor() as usize;
4071                width = (width as f64 / ratio).floor() as usize;
4072            }
4073
4074            let num_height_tokens = (height - 1) / cfg.patch_size + 1;
4075            let num_width_tokens = (width - 1) / cfg.patch_size + 1;
4076
4077            height = num_height_tokens * cfg.patch_size;
4078            width = num_width_tokens * cfg.patch_size;
4079
4080            let num_height_tokens = height / cfg.patch_size;
4081            let num_width_tokens = width / cfg.patch_size;
4082
4083            (num_width_tokens + 1) * num_height_tokens
4084        };
4085
4086        Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
4087    }
4088
4089    fn non_mapped_size_in_bytes(
4090        &self,
4091        config: &str,
4092        dtype: DType,
4093        weight_pack_factor: usize,
4094        _matformer_config: Option<&MatformerSliceConfig>,
4095    ) -> Result<usize> {
4096        let cfg: Mistral3Config = serde_json::from_str(config)?;
4097
4098        let text_elems = {
4099            let cfg = &cfg.text_config;
4100
4101            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4102            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
4103            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4104                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4105            } else {
4106                0
4107            };
4108            let norm = cfg.hidden_size;
4109            embed_tokens + lm_head + norm
4110        };
4111
4112        let vision_elems = {
4113            let cfg = &cfg.vision_config;
4114
4115            let patch_embed = {
4116                let conv_cfg = Conv2dConfig {
4117                    stride: cfg.patch_size,
4118                    ..Default::default()
4119                };
4120                cfg.num_channels * cfg.hidden_size / conv_cfg.groups
4121                    * cfg.patch_size
4122                    * cfg.patch_size
4124            };
4125            let ln_pre = cfg.hidden_size;
4126            let vision_layer = {
4127                let attn_norm = cfg.hidden_size;
4128                let ffn_norm = cfg.hidden_size;
4129
4130                let gate = cfg.hidden_size * cfg.intermediate_size;
4131                let up = cfg.hidden_size * cfg.intermediate_size;
4132                let down = cfg.hidden_size * cfg.intermediate_size;
4133
4134                let q = cfg.hidden_size * cfg.hidden_size;
4135                let k = cfg.hidden_size * cfg.hidden_size;
4136                let v = cfg.hidden_size * cfg.hidden_size;
4137                let o = cfg.hidden_size * cfg.hidden_size;
4138
4139                attn_norm + ffn_norm + gate + up + down + q + k + v + o
4140            };
4141
4142            patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
4143        };
4144
4145        let elems = text_elems + vision_elems;
4146
4147        Ok(elems * dtype.size_in_bytes())
4148    }
4149
4150    fn layer_sizes_in_bytes(
4151        &self,
4152        config: &str,
4153        dtype: DType,
4154        weight_pack_factor: usize,
4155        _matformer_config: Option<&MatformerSliceConfig>,
4156    ) -> Result<Vec<usize>> {
4157        let cfg: Mistral3Config = serde_json::from_str(config)?;
4158        let cfg = &cfg.text_config;
4159
4160        let per_layer_elems = {
4161            let input_layernorm = cfg.hidden_size;
4162            let post_attention_layernorm = cfg.hidden_size;
4163
4164            let size_in = cfg.hidden_size;
4165            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
4166            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
4167            let q_proj = size_in * size_q / weight_pack_factor;
4168            let k_proj = size_in * size_kv / weight_pack_factor;
4169            let v_proj = size_in * size_kv / weight_pack_factor;
4170            let o_proj = size_q * size_in / weight_pack_factor;
4171
4172            let h_size = cfg.hidden_size;
4173            let i_size = cfg.intermediate_size;
4174            let gate_proj = h_size * i_size / weight_pack_factor;
4175            let up_proj = h_size * i_size / weight_pack_factor;
4176            let down_proj = i_size * h_size / weight_pack_factor;
4177
4178            input_layernorm
4179                + post_attention_layernorm
4180                + q_proj
4181                + k_proj
4182                + v_proj
4183                + o_proj
4184                + gate_proj
4185                + up_proj
4186                + down_proj
4187        };
4188        Ok(vec![
4189            per_layer_elems * dtype.size_in_bytes();
4190            cfg.num_hidden_layers
4191        ])
4192    }
4193
4194    fn num_layers(&self, config: &str) -> Result<usize> {
4195        let cfg: Mistral3Config = serde_json::from_str(config)?;
4196        let cfg = &cfg.text_config;
4197        Ok(cfg.num_hidden_layers)
4198    }
4199
4200    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4201        let cfg: Mistral3Config = serde_json::from_str(config)?;
4202        let cfg = &cfg.text_config;
4203
4204        let cfg = ModelConfigMetadata {
4205            max_seq_len: cfg.max_position_embeddings,
4206            num_layers: cfg.num_hidden_layers,
4207            hidden_size: cfg.hidden_size,
4208            num_kv_heads: cfg.num_key_value_heads,
4209            num_attn_heads: cfg.num_attention_heads,
4210            sliding_window: cfg.sliding_window,
4211            k_head_dim: cfg.head_dim(),
4212            v_head_dim: cfg.head_dim(),
4213        };
4214
4215        Ok(Box::new(cfg))
4216    }
4217
4218    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4219        Some(vec![NonMappedSubModel::Vision])
4220    }
4221}
4222
4223// ======================== Llama 4 Loader
4224
4225/// [`VisionLoader`] for a Llama 4 Vision model.
4226///
4227/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
4228pub struct VLlama4Loader;
4229
4230pub struct VLlama4Prefixer;
4231
4232impl MultimodalPromptPrefixer for VLlama4Prefixer {
4233    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
4234        format!(
4235            "{}{prompt}",
4236            llama4::IMAGE_TOKEN.repeat(image_indexes.len())
4237        )
4238    }
4239}
4240
4241impl VisionModelLoader for VLlama4Loader {
4242    fn load(
4243        &self,
4244        config: &str,
4245        vb: ShardedVarBuilder,
4246        normal_loading_metadata: NormalLoadingMetadata,
4247        attention_mechanism: AttentionImplementation,
4248    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
4249        let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4250        Ok(Box::new(Llama4Model::new(
4251            &cfg,
4252            vb,
4253            self.is_gptx(config),
4254            normal_loading_metadata,
4255            attention_mechanism,
4256        )?))
4257    }
4258    fn is_gptx(&self, _config: &str) -> bool {
4259        false
4260    }
4261    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4262        let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4263        Ok(Box::new(cfg))
4264    }
4265    fn get_processor(
4266        &self,
4267        _model_config: &str,
4268        processor_config: Option<ProcessorConfig>,
4269        _preprocessor_config: PreProcessorConfig,
4270        _max_edge: Option<u32>,
4271    ) -> Arc<dyn Processor + Send + Sync> {
4272        Arc::new(Llama4Processor::new(&processor_config.unwrap()))
4273    }
4274    fn supports_paged_attention(&self, _config: &str) -> bool {
4275        true
4276    }
4277    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4278        Arc::new(VLlama4Prefixer)
4279    }
4280    fn modalities(&self, _config: &str) -> Result<Modalities> {
4281        Ok(Modalities {
4282            input: vec![SupportedModality::Text, SupportedModality::Vision],
4283            output: vec![SupportedModality::Text],
4284        })
4285    }
4286}
4287
4288impl IsqModelLoader for VLlama4Loader {
4289    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4290        Ok(vec![
4291            Regex::new(r"lm_head\.(weight|bias)$")?,
4292            // Attention
4293            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4294            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4295            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4296            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4297            // FF MoE
4298            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$")?,
4299            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$")?,
4300            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$")?,
4301            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$")?,
4302            Regex::new(r"layers\.(\d+)\.feed_forward\.router\.(weight|bias)$")?,
4303            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.gate_proj\.(weight|bias)$")?,
4304            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.up_proj\.(weight|bias)$")?,
4305            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.down_proj\.(weight|bias)$")?,
4306            // FF MLP
4307            Regex::new(r"layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$")?,
4308            Regex::new(r"layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$")?,
4309            Regex::new(r"layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$")?,
4310        ])
4311    }
4312    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4313        Ok(vec![
4314            Regex::new(r"lm_head\.(weight|bias)$")?,
4315            // Attention
4316            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4317            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4318            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4319            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4320            // FF MoE
4321            Regex::new(
4322                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_up_proj\.(weight|bias)$",
4323            )?,
4324            Regex::new(
4325                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
4326            )?,
4327            Regex::new(
4328                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.up_proj\.(weight|bias)$",
4329            )?,
4330            Regex::new(
4331                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.down_proj\.(weight|bias)$",
4332            )?,
4333            Regex::new(
4334                r"language_model\.model\.layers\.(\d+)\.feed_forward\.router\.(weight|bias)$",
4335            )?,
4336            Regex::new(
4337                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.gate_proj\.(weight|bias)$",
4338            )?,
4339            Regex::new(
4340                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.up_proj\.(weight|bias)$",
4341            )?,
4342            Regex::new(
4343                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.down_proj\.(weight|bias)$",
4344            )?,
4345            // FF MLP
4346            Regex::new(
4347                r"language_model\.model\.layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$",
4348            )?,
4349            Regex::new(
4350                r"language_model\.model\.layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$",
4351            )?,
4352            Regex::new(
4353                r"language_model\.model\.layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$",
4354            )?,
4355        ])
4356    }
4357}
4358
4359impl VLlama4Loader {
4360    /// This incorporates the max batch size!
4361    /// Returns (pixels max batch size, num text image tokens)
4362    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
4363    fn run_dummy_processing(
4364        &self,
4365        cfg: &Llama4Config,
4366        height: usize,
4367        width: usize,
4368        max_num_images: usize,
4369        max_batch_size: usize,
4370    ) -> Result<(usize, usize)> {
4371        let cfg = &cfg.vision_config;
4372
4373        let img_processor =
4374            Llama4ImageProcessor::new(Some(cfg.patch_size), Some(cfg.pixel_shuffle_ratio));
4375        let image = DynamicImage::new(width as u32, height as u32, ColorType::Rgb8);
4376        let res = img_processor.preprocess(
4377            vec![image; max_num_images],
4378            vec![],
4379            &PreProcessorConfig::default(),
4380            &Device::Cpu,
4381            (max_batch_size, max_num_images),
4382        )?;
4383
4384        let pixels_batch_size = res.pixel_values.dim(0)?;
4385        let pixels_max_batch_size = pixels_batch_size * max_batch_size;
4386
4387        let (image_h, image_w) = (
4388            res.pixel_values.dim(D::Minus2).unwrap(),
4389            res.pixel_values.dim(D::Minus1).unwrap(),
4390        );
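        // Each image chunk yields (H / patch_size) * (W / patch_size) patches, reduced by the processor's
        // downsample ratio; each remaining patch becomes one text-side image token.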
4391        let num_patches_per_chunk = (image_h / img_processor.patch_size)
4392            * (image_w / img_processor.patch_size)
4393            / img_processor.downsample_ratio;
4394
4395        Ok((
4396            pixels_max_batch_size,
4397            num_patches_per_chunk * pixels_max_batch_size,
4398        ))
4399    }
4400}
4401
4402impl DeviceMappedModelLoader for VLlama4Loader {
4403    fn mapped_max_act_size_elems(
4404        &self,
4405        config: &str,
4406        params: &AutoDeviceMapParams,
4407    ) -> Result<usize> {
4408        let AutoDeviceMapParams::Vision {
4409            max_seq_len,
4410            max_batch_size,
4411            max_image_shape: (height, width),
4412            max_num_images,
4413        } = params
4414        else {
4415            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4416        };
4417
4418        let cfg: Llama4Config = serde_json::from_str(config)?;
4419
4420        let (_pixels_batch_size, num_text_image_toks) =
4421            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4422
4423        let max_seq_len = max_seq_len.min(&ATTENTION_CHUNK_SIZE) + num_text_image_toks;
4424
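        // Peak mapped activation: the text model's attention score matrix
        // [batch, num_heads, seq_len, seq_len], with image tokens folded into the effective sequence length.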
4425        Ok(max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len)
4426    }
4427    fn non_mapped_max_act_size_elems(
4428        &self,
4429        config: &str,
4430        params: &AutoDeviceMapParams,
4431    ) -> Result<usize> {
4432        let AutoDeviceMapParams::Vision {
4433            max_seq_len: _,
4434            max_batch_size,
4435            max_image_shape: (height, width),
4436            max_num_images,
4437        } = params
4438        else {
4439            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4440        };
4441
4442        let cfg: Llama4Config = serde_json::from_str(config)?;
4443
4444        let (pixels_batch_size, _num_text_image_toks) =
4445            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4446        let max_seq_len = cfg.vision_config.num_patches();
4447
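        // Peak non-mapped activation: vision encoder attention scores over all image chunks,
        // [batch * chunks, num_heads, num_patches, num_patches].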
4448        Ok((max_batch_size * pixels_batch_size)
4449            * cfg.vision_config.num_attention_heads
4450            * max_seq_len
4451            * max_seq_len)
4452    }
4453
4454    fn non_mapped_size_in_bytes(
4455        &self,
4456        config: &str,
4457        dtype: DType,
4458        weight_pack_factor: usize,
4459        _matformer_config: Option<&MatformerSliceConfig>,
4460    ) -> Result<usize> {
4461        let cfg: Llama4Config = serde_json::from_str(config)?;
4462        let tcfg = &cfg.text_config;
4463
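        // Non-mapped text weights: token embeddings, the lm_head (when embeddings are untied), and the final norm.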
4464        let text_elems = {
4465            let embed_tokens = tcfg.hidden_size * tcfg.vocab_size / weight_pack_factor;
4466            let lm_head = if !tcfg.tie_word_embeddings {
4467                tcfg.hidden_size * tcfg.vocab_size
4468            } else {
4469                0
4470            };
4471            let norm = tcfg.hidden_size;
4472            embed_tokens + lm_head + norm
4473        };
4474
4475        let vision_elems = {
4476            let cfg = &cfg.vision_config;
4477
4478            let num_patches = cfg.num_patches();
4479
4480            let unfold_elems =
4481                (cfg.num_channels * cfg.patch_size * cfg.patch_size) * cfg.hidden_size;
4482            let class_embedding_elems = cfg.hidden_size;
4483            let positional_embedding_vlm_elems = num_patches * cfg.hidden_size;
4484            let layernorm_pre_elems = cfg.hidden_size;
4485            let layernorm_post_elems = cfg.hidden_size;
4486
4487            let pixel_shuffle_elems = cfg.intermediate_size * cfg.projector_input_dim
4488                / weight_pack_factor
4489                + cfg.projector_input_dim * cfg.projector_output_dim / weight_pack_factor;
4490
4491            let encoder_layer = {
4492                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
4493                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
4494
4495                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
4496                let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4497                    / weight_pack_factor
4498                    + cfg.num_attention_heads * head_dim;
4499                let k_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4500                    / weight_pack_factor
4501                    + cfg.num_attention_heads * head_dim;
4502                let v_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4503                    / weight_pack_factor
4504                    + cfg.num_attention_heads * head_dim;
4505                let o_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4506                    / weight_pack_factor
4507                    + cfg.num_attention_heads * head_dim;
4508
4509                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
4510                    + cfg.intermediate_size;
4511                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
4512                    + cfg.hidden_size;
4513
4514                input_layernorm
4515                    + post_attention_layernorm
4516                    + q_proj
4517                    + k_proj
4518                    + v_proj
4519                    + o_proj
4520                    + fc1
4521                    + fc2
4522            };
4523
4524            unfold_elems
4525                + class_embedding_elems
4526                + positional_embedding_vlm_elems
4527                + layernorm_post_elems
4528                + layernorm_pre_elems
4529                + pixel_shuffle_elems
4530                + encoder_layer * cfg.num_hidden_layers
4531        };
4532
4533        let elems = text_elems + vision_elems;
4534
4535        Ok(elems * dtype.size_in_bytes())
4536    }
4537
4538    fn layer_sizes_in_bytes(
4539        &self,
4540        config: &str,
4541        dtype: DType,
4542        weight_pack_factor: usize,
4543        _matformer_config: Option<&MatformerSliceConfig>,
4544    ) -> Result<Vec<usize>> {
4545        let cfg: Llama4Config = serde_json::from_str(config)?;
4546        let tcfg = &cfg.text_config;
4547
4548        let mut per_layer_elems = Vec::new();
4549
4550        for layer_idx in 0..tcfg.num_hidden_layers {
4551            let input_layernorm = tcfg.hidden_size;
4552            let post_attention_layernorm = tcfg.hidden_size;
4553
4554            let size_in = tcfg.hidden_size;
4555            let size_q = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_attention_heads;
4556            let size_kv = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_key_value_heads;
4557            let q_proj = size_in * size_q / weight_pack_factor;
4558            let k_proj = size_in * size_kv / weight_pack_factor;
4559            let v_proj = size_in * size_kv / weight_pack_factor;
4560            let o_proj = size_q * size_in / weight_pack_factor;
4561
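            // Llama 4 interleaves MoE and dense MLP layers, so per-layer sizes depend on the layer index.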
4562            let use_moe = tcfg.moe_layers().contains(&layer_idx);
4563            let moe_block = if use_moe {
4564                let h_size = tcfg.hidden_size;
4565                let i_size = tcfg.intermediate_size;
4566                let gate_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4567                let up_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4568                let down_proj = tcfg.num_local_experts * i_size * h_size / weight_pack_factor;
4569
4570                gate_proj + up_proj + down_proj
4571            } else {
4572                let h_size = tcfg.hidden_size;
4573                let i_size = tcfg.intermediate_size_mlp;
4574                let gate_proj = h_size * i_size / weight_pack_factor;
4575                let up_proj = h_size * i_size / weight_pack_factor;
4576                let down_proj = i_size * h_size / weight_pack_factor;
4577
4578                gate_proj + up_proj + down_proj
4579            };
4580
4581            per_layer_elems.push(
4582                input_layernorm
4583                    + post_attention_layernorm
4584                    + q_proj
4585                    + k_proj
4586                    + v_proj
4587                    + o_proj
4588                    + moe_block,
4589            );
4590        }
4591
4592        Ok(per_layer_elems
4593            .into_iter()
4594            .map(|x| x * dtype.size_in_bytes())
4595            .collect())
4596    }
4597
4598    fn num_layers(&self, config: &str) -> Result<usize> {
4599        let cfg: Llama4Config = serde_json::from_str(config)?;
4600        Ok(cfg.text_config.num_hidden_layers)
4601    }
4602
4603    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4604        let cfg: Llama4Config = serde_json::from_str(config)?;
4605        let cfg = &cfg.text_config;
4606
4607        let cfg = ModelConfigMetadata {
4608            max_seq_len: cfg.max_position_embeddings,
4609            num_layers: cfg.num_hidden_layers,
4610            hidden_size: cfg.hidden_size,
4611            num_kv_heads: cfg.num_attention_heads,
4612            num_attn_heads: cfg.num_attention_heads,
4613            sliding_window: None,
4614            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4615            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4616        };
4617
4618        Ok(Box::new(cfg))
4619    }
4620
4621    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4622        Some(vec![NonMappedSubModel::Vision])
4623    }
4624}
4625
4626// ======================== Gemma 3n Loader
4627
4628/// [`VisionLoader`] for a Gemma 3n model.
4629///
4630/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
4631pub struct Gemma3nLoader;
4632
4633#[allow(dead_code)]
4634pub struct Gemma3nPrefixer;
4635
4636impl MultimodalPromptPrefixer for Gemma3nPrefixer {
4637    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
4638        prompt.to_string()
4639    }
4640}
4641
4642impl VisionModelLoader for Gemma3nLoader {
4643    fn load(
4644        &self,
4645        config: &str,
4646        vb: ShardedVarBuilder,
4647        normal_loading_metadata: NormalLoadingMetadata,
4648        attention_mechanism: AttentionImplementation,
4649    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
4650        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4651        Ok(Box::new(Gemma3nModel::new(
4652            &cfg,
4653            vb,
4654            self.is_gptx(config),
4655            normal_loading_metadata,
4656            attention_mechanism,
4657        )?))
4658    }
4659    fn is_gptx(&self, _config: &str) -> bool {
4660        true
4661    }
4662    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4663        let config: Gemma3nConfig = serde_json::from_str(config)?;
4664        Ok(Box::new(config))
4665    }
4666    fn get_processor(
4667        &self,
4668        _config: &str,
4669        processor_config: Option<ProcessorConfig>,
4670        _preprocessor_config: PreProcessorConfig,
4671        _max_edge: Option<u32>,
4672    ) -> Arc<dyn Processor + Send + Sync> {
4673        // Fall back to a default processor config if none was provided
4674        Arc::new(Gemma3nProcessor::new(
4675            processor_config.unwrap_or_default(),
4676            true,
4677        ))
4678    }
4679    fn supports_paged_attention(&self, _config: &str) -> bool {
4680        false
4681    }
4682    fn supports_prefix_cacher(&self, _config: &str) -> bool {
4683        true
4684    }
4685    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4686        Arc::new(Gemma3Prefixer)
4687    }
4688    fn modalities(&self, _config: &str) -> Result<Modalities> {
4689        Ok(Modalities {
4690            input: vec![
4691                SupportedModality::Text,
4692                SupportedModality::Vision,
4693                SupportedModality::Audio,
4694            ],
4695            output: vec![SupportedModality::Text],
4696        })
4697    }
4698}
4699
4700impl IsqModelLoader for Gemma3nLoader {
4701    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4702        Ok(vec![
4703            Regex::new(r"lm_head\.(weight|bias)$")?,
4704            // Language model attention
4705            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4706            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4707            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4708            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4709            // Language model MLP
4710            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4711            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4712            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4713            // Audio conformer attention layers
4714            Regex::new(r"conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$")?,
4715            Regex::new(r"conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$")?,
4716            Regex::new(r"conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$")?,
4717            Regex::new(
4718                r"conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4719            )?,
4720            Regex::new(r"conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4721            // Audio conformer FFW layers
4722            Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$")?,
4723            Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$")?,
4724            Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$")?,
4725            Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$")?,
4726            // Audio conformer conv1d layers
4727            Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$")?,
4728            Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$")?,
4729            // Audio subsample projection
4730            Regex::new(r"subsample_conv_projection\.input_proj_linear\.(weight|bias)$")?,
4731            // Multimodal embedders
4732            Regex::new(r"embed_vision\.embedding_projection\.(weight|bias)$")?,
4733            Regex::new(r"embed_audio\.embedding_projection\.(weight|bias)$")?,
4734        ])
4735    }
4736    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4737        Ok(vec![
4738            Regex::new(r"lm_head\.(weight|bias)$")?,
4739            // Language model attention
4740            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4741            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4742            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4743            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4744            // Language model MLP
4745            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4746            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4747            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4748            // Projections
4749            Regex::new(r"model\.language_model\.per_layer_model_projection\.(weight|bias)$")?,
4750            Regex::new(r"model\.language_model\.altup_projections\.(\d+)\.(weight|bias)$")?,
4751            Regex::new(r"model\.language_model\.altup_unembed_projections\.(\d+)\.(weight|bias)$")?,
4752            // Audio conformer attention layers
4753            Regex::new(
4754                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$",
4755            )?,
4756            Regex::new(
4757                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$",
4758            )?,
4759            Regex::new(
4760                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$",
4761            )?,
4762            Regex::new(
4763                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4764            )?,
4765            Regex::new(r"model\.audio_tower\.conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4766            // Audio conformer FFW layers
4767            Regex::new(
4768                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$",
4769            )?,
4770            Regex::new(
4771                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$",
4772            )?,
4773            Regex::new(
4774                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$",
4775            )?,
4776            Regex::new(
4777                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$",
4778            )?,
4779            // Audio conformer conv1d layers
4780            Regex::new(
4781                r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$",
4782            )?,
4783            Regex::new(
4784                r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$",
4785            )?,
4786            // Audio subsample projection
4787            Regex::new(
4788                r"model\.audio_tower\.subsample_conv_projection\.input_proj_linear\.(weight|bias)$",
4789            )?,
4790            // Multimodal embedders
4791            Regex::new(r"model\.embed_vision\.embedding_projection\.(weight|bias)$")?,
4792            Regex::new(r"model\.embed_audio\.embedding_projection\.(weight|bias)$")?,
4793        ])
4794    }
4795}
4796
4797impl DeviceMappedModelLoader for Gemma3nLoader {
4798    fn mapped_max_act_size_elems(
4799        &self,
4800        config: &str,
4801        params: &AutoDeviceMapParams,
4802    ) -> Result<usize> {
4803        let AutoDeviceMapParams::Vision {
4804            max_seq_len,
4805            max_batch_size,
4806            max_image_shape: _,
4807            max_num_images,
4808        } = params
4809        else {
4810            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4811        };
4812
4813        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4814        let text_cfg = &cfg.text_config;
4815
4816        // Gemma3n is an "inject into the prompt" model, similar to Gemma3
4817        // We need to account for vision and audio tokens in the sequence length
4818
4819        let mut total_seq_len = *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
4820
4821        // Add vision tokens
4822        {
4823            // Vision tokens are injected into the prompt
4824            // MSFA outputs fixed 16x16 features regardless of input size
4825            let msfa_spatial_size = 16; // Fixed from vision.rs line 1115
4826            let vision_tokens_per_image = msfa_spatial_size * msfa_spatial_size; // 256 tokens
4827            total_seq_len += vision_tokens_per_image * max_num_images;
4828        }
4829
4830        // Add audio tokens
4831        {
4832            // Audio tokens are injected into the prompt
4833            // From config field audio_soft_tokens_per_image (typically 188)
4834            let audio_tokens = cfg.audio_soft_tokens_per_image;
4835            total_seq_len += audio_tokens;
4836        }
4837
4838        // Calculate max attention size for text model with all injected tokens
4839        let max_text_attn =
4840            max_batch_size * text_cfg.num_attention_heads * total_seq_len * total_seq_len;
4841
4842        Ok(max_text_attn)
4843    }
4844
4845    fn non_mapped_max_act_size_elems(
4846        &self,
4847        config: &str,
4848        params: &AutoDeviceMapParams,
4849    ) -> Result<usize> {
4850        let AutoDeviceMapParams::Vision {
4851            max_seq_len: _,
4852            max_batch_size,
4853            max_image_shape: _,
4854            max_num_images,
4855        } = params
4856        else {
4857            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4858        };
4859
4860        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4861
4862        // Calculate max activation sizes for each modality
4863        let mut max_activation = 0;
4864
4865        // Vision activation size
4866        {
4867            // Vision is Gemma3n's MobileNetV5 architecture with Multi-Query Attention
4868            // The peak activation is in the Multi-Query Attention layers
4869
4870            // From the architecture: stages 3 and 4 have MMQA blocks
4871            // Input images are 768x768 (from inputs_processor.rs)
4872            // Stage 3: 640 channels at 48x48 (768/16 downsampling), MMQA with num_heads=12, kv_dim=64
4873            // Stage 4: 1280 channels at 24x24 (768/32 downsampling), MMQA with num_heads=16, kv_dim=96
4874            // MSFA output: 2048 channels at fixed 16x16
4875
4876            let vision_tower_act = {
4877                // Peak is during MMQA attention computation in stage 4
4878                // Stage 4 has higher memory usage than Stage 3 due to more heads (16 vs 12)
4879                // From vision.rs: Stage 4 has num_heads=16, kv_dim=96, kv_stride=1
4880                let num_heads = 16; // Stage 4 configuration
4881                let spatial_size = 24; // 768 / 32 = 24 (input 768x768, stage 4 has 32x downsampling)
4882                let seq_len = spatial_size * spatial_size;
4883
4884                // Attention scores: [B * num_images, num_heads, seq_len, seq_len]
4885                max_batch_size * max_num_images * num_heads * seq_len * seq_len
4886            };
4887
4888            // Vision embedder activations
4889            let vision_embed_act = {
4890                // MSFA output: 2048 channels at fixed 16x16 spatial (from vision.rs line 1115)
4891                let msfa_channels = 2048; // MSFA_OUT_CHANNELS from vision.rs
4892                let spatial_size = 16; // Fixed output resolution from MSFA
4893                let vision_features =
4894                    max_batch_size * max_num_images * msfa_channels * spatial_size * spatial_size;
4895
4896                // After embedding projection to text hidden size
4897                let projected = max_batch_size
4898                    * max_num_images
4899                    * spatial_size
4900                    * spatial_size
4901                    * cfg.text_config.hidden_size;
4902
4903                vision_features.max(projected)
4904            };
4905
4906            max_activation = max_activation.max(vision_tower_act).max(vision_embed_act);
4907        }
4908
4909        // Audio activation size
4910        {
4911            let audio_cfg = &cfg.audio_config;
4912
4913            // Calculate max audio sequence length based on config
4914            // Audio uses conformer with subsampling and reduction
4915
4916            // A rough estimate of max_audio_frames
4917            let max_audio_frames = 1280;
4918
4919            let subsample_factor: usize = audio_cfg
4920                .sscp_conv_stride_size
4921                .iter()
4922                .map(|stride| stride[0]) // Time dimension stride
4923                .product();
4924            let audio_seq_after_subsample = max_audio_frames / subsample_factor;
4925
4926            // Audio encoder activations
4927            let audio_encoder_act = {
4928                // Conformer FFW layers have expansion factor from config
4929                let intermediate_size = audio_cfg.hidden_size * 4; // FFW expansion factor
4930
4931                // Peak is in the FFW layers before reduction
4932                max_batch_size * audio_seq_after_subsample * intermediate_size
4933            };
4934
4935            // Audio attention activations
4936            let audio_attn_act = {
4937                // Attention uses chunked processing with specific context sizes
4938                let chunk_size = audio_cfg.conf_attention_chunk_size;
4939                let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
4940                    + audio_cfg.conf_attention_context_right;
4941
4942                // Peak is attention scores: [B, num_heads, num_chunks, chunk_size, context_size]
4943                let num_chunks = audio_seq_after_subsample.div_ceil(chunk_size);
4944
4945                max_batch_size
4946                    * audio_cfg.conf_num_attention_heads
4947                    * num_chunks
4948                    * chunk_size
4949                    * context_size
4950            };
4951
4952            max_activation = max_activation.max(audio_encoder_act).max(audio_attn_act);
4953        }
4954
4955        Ok(max_activation)
4956    }
4957
4958    fn non_mapped_size_in_bytes(
4959        &self,
4960        config: &str,
4961        dtype: DType,
4962        weight_pack_factor: usize,
4963        matformer_config: Option<&MatformerSliceConfig>,
4964    ) -> Result<usize> {
4965        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4966
4967        // Apply matformer slicing if configured
4968        let text_cfg = if let Some(matformer_cfg) = matformer_config {
4969            use crate::device_map::DummyDeviceMapper;
4970            use crate::vision_models::gemma3n::text::handle_matformer_slicing;
4971
4972            let dummy_mapper = DummyDeviceMapper {
4973                nm_device: Device::Cpu,
4974            };
4975            let (adjusted_cfg, _, _, _, _) = handle_matformer_slicing(
4976                &cfg.text_config,
4977                &Some(matformer_cfg.clone()),
4978                &dummy_mapper,
4979            )?;
4980            adjusted_cfg
4981        } else {
4982            cfg.text_config.clone()
4983        };
4984
4985        let text_cfg = &text_cfg;
4986
4987        // Text components that are not device-mapped
4988        let text_elems = {
4989            // Embeddings
4990            let embed_tokens = text_cfg.hidden_size * text_cfg.vocab_size;
4991            let embed_tokens_per_layer = text_cfg.num_hidden_layers
4992                * text_cfg.hidden_size_per_layer_input
4993                * text_cfg.vocab_size_per_layer_input;
4994
4995            // LM head (if not tied)
4996            let lm_head = if !text_cfg.tie_word_embeddings || weight_pack_factor != 1 {
4997                text_cfg.hidden_size * text_cfg.vocab_size / weight_pack_factor
4998            } else {
4999                0
5000            };
5001
5002            // Final layer norm
5003            let norm = text_cfg.hidden_size;
5004
5005            // AltUp projections (not device-mapped)
5006            let altup_projections =
5007                (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
5008                    / weight_pack_factor;
5009            let altup_unembed_projections =
5010                (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
5011                    / weight_pack_factor;
5012
5013            // Per-layer model projection
5014            let per_layer_model_projection = text_cfg.num_hidden_layers
5015                * text_cfg.hidden_size
5016                * text_cfg.hidden_size_per_layer_input
5017                / weight_pack_factor;
5018            let per_layer_projection_norm = text_cfg.hidden_size;
5019
5020            embed_tokens
5021                + embed_tokens_per_layer
5022                + lm_head
5023                + norm
5024                + altup_projections
5025                + altup_unembed_projections
5026                + per_layer_model_projection
5027                + per_layer_projection_norm
5028        };
5029
5030        // Vision components
5031        let vision_elems = {
5032            let vision_cfg = &cfg.vision_config;
5033            // Vision tower - calculated from actual Gemma3n architecture
5034            // NOTE: Vision tower uses only Conv2d layers, NOT Arc<dyn QuantMethod>,
5035            // so NONE of these should be divided by weight_pack_factor
5036            let vision_tower_elems = {
5037                use crate::vision_models::gemma3n::vision::{
5038                    gemma3n_mobilenet_def, make_divisible, BlockType, INPUT_CHANNELS,
5039                    MSFA_EXPANSION_RATIO, MSFA_IN_CHANNELS, MSFA_OUT_CHANNELS, STEM_KERNEL_SIZE,
5040                    STEM_OUT_CHANNELS,
5041                };
5042
5043                // Stem: ConvNormAct (Conv2d + RMSNorm)
5044                let stem_conv =
5045                    INPUT_CHANNELS * STEM_OUT_CHANNELS * STEM_KERNEL_SIZE * STEM_KERNEL_SIZE;
5046                let stem_norm = STEM_OUT_CHANNELS; // RMSNorm weight
5047
5048                // Track input channels through the network
5049                let mut in_chs = STEM_OUT_CHANNELS;
5050                let mut total_elems = stem_conv + stem_norm;
5051
5052                // Process all stages from gemma3n_mobilenet_def
5053                let block_defs = gemma3n_mobilenet_def();
5054
5055                for stage_blocks in block_defs.iter() {
5056                    for block_type in stage_blocks.iter() {
5057                        match block_type {
5058                            BlockType::EdgeResidual {
5059                                out_channels,
5060                                kernel_size,
5061                                stride: _,
5062                                expand_ratio,
5063                                ..
5064                            } => {
5065                                #[allow(clippy::cast_precision_loss)]
5066                                let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
5067                                // EdgeResidual: all Conv2d layers, not quantizable
5068                                total_elems += in_chs * mid_chs * kernel_size * kernel_size; // conv_exp (Conv2d)
5069                                total_elems += mid_chs; // bn1 weight
5070                                total_elems += mid_chs * out_channels; // conv_pwl (Conv2d)
5071                                total_elems += out_channels; // bn2 weight
5072                                in_chs = *out_channels;
5073                            }
5074                            BlockType::UniversalInvertedResidual {
5075                                out_channels,
5076                                start_kernel_size,
5077                                mid_kernel_size,
5078                                stride: _,
5079                                expand_ratio,
5080                                ..
5081                            } => {
5082                                #[allow(clippy::cast_precision_loss)]
5083                                let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
5084                                // UniversalInvertedResidual: all Conv2d layers, not quantizable
5085                                if *expand_ratio != 1.0 {
5086                                    total_elems += in_chs * mid_chs; // expand conv (Conv2d)
5087                                    total_elems += mid_chs; // expand norm
5088                                }
5089                                if *start_kernel_size > 0 {
5090                                    total_elems += mid_chs * start_kernel_size * start_kernel_size; // depthwise start (Conv2d)
5091                                    total_elems += mid_chs; // norm
5092                                }
5093                                if *mid_kernel_size > 0 {
5094                                    total_elems += mid_chs * mid_kernel_size * mid_kernel_size; // depthwise mid (Conv2d)
5095                                    total_elems += mid_chs; // norm
5096                                }
5097                                total_elems += mid_chs * out_channels; // project conv (Conv2d)
5098                                total_elems += out_channels; // project norm
5099                                total_elems += out_channels; // layer scale gamma
5100                                in_chs = *out_channels;
5101                            }
5102                            BlockType::MultiQueryAttention {
5103                                num_heads,
5104                                kv_dim,
5105                                kv_stride: _,
5106                                ..
5107                            } => {
5108                                // MMQA: all Conv2d layers, not quantizable
5109                                let dw_kernel_size = 3; // Default dw_kernel_size for MMQA
5110                                total_elems += in_chs; // norm weight
5111                                total_elems += in_chs * num_heads * kv_dim; // query_proj (Conv2d)
5112                                total_elems += in_chs * kv_dim; // key_proj (Conv2d)
5113                                total_elems += in_chs * dw_kernel_size * dw_kernel_size; // key_dw_conv (Conv2d)
5114                                total_elems += *kv_dim; // value_down_conv (Conv2d)
5115                                total_elems += 1; // value_norm weight
5116                                total_elems += *kv_dim; // value_proj (Conv2d)
5117                                total_elems += num_heads * kv_dim * in_chs; // output_proj (Conv2d)
5118                                total_elems += in_chs; // layer scale
5119                            }
5120                        }
5121                    }
5122                }
5123
5124                // Multi-scale fusion adapter (msfa) - also uses Conv2d layers
5125                let msfa_in = MSFA_IN_CHANNELS.iter().sum::<usize>();
5126                let msfa_out = MSFA_OUT_CHANNELS;
5127                #[allow(clippy::cast_precision_loss)]
5128                let msfa_mid = make_divisible(msfa_in as f64 * MSFA_EXPANSION_RATIO, 8);
5129
5130                // MSFA FFN (UIR with expansion_ratio) - Conv2d layers, not quantizable
5131                total_elems += msfa_in * msfa_mid; // expand (Conv2d)
5132                total_elems += msfa_mid; // expand norm
5133                total_elems += msfa_mid * msfa_out; // project (Conv2d)
5134                total_elems += msfa_out; // project norm
5135                total_elems += msfa_out; // final norm
5136
5137                total_elems
5138            };
5139
5140            // Vision multimodal embedder components
5141            let embed_vision_elems = {
5142                // Embedding layer (not quantizable)
5143                let embedding = vision_cfg.vocab_size * vision_cfg.hidden_size;
5144
5145                // Normalization layers (not quantizable)
5146                let hard_norm = vision_cfg.hidden_size;
5147                let soft_norm = vision_cfg.hidden_size;
5148
5149                // Projection from vision to text hidden size (IS Arc<dyn QuantMethod>, so quantizable)
5150                let projection = vision_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5151
5152                // Post-projection norm (not quantizable)
5153                let post_norm = text_cfg.hidden_size;
5154
5155                embedding + hard_norm + soft_norm + projection + post_norm
5156            };
5157
5158            vision_tower_elems + embed_vision_elems
5159        };
5160
5161        // Audio components - based on actual audio.rs structure
5162        let audio_elems = {
5163            let audio_cfg = &cfg.audio_config;
5164
5165            // SubSampleConvProjection components
5166            let subsample_conv_projection_elems = {
5167                // Conv blocks (Conv2d layers - NOT quantizable)
5168                let mut conv_elems = 0;
5169
5170                // conv_0: Conv2d from 1 channel to first channel size
5171                let in_ch_0 = 1;
5172                let out_ch_0 = audio_cfg.sscp_conv_channel_size[0];
5173                let kernel_0 = &audio_cfg.sscp_conv_kernel_size[0];
5174                conv_elems += in_ch_0 * out_ch_0 * kernel_0[0] * kernel_0[1];
5175
5176                // conv_1: Conv2d from first to second channel size
5177                let in_ch_1 = out_ch_0;
5178                let out_ch_1 = audio_cfg.sscp_conv_channel_size[1];
5179                let kernel_1 = &audio_cfg.sscp_conv_kernel_size[1];
5180                conv_elems += in_ch_1 * out_ch_1 * kernel_1[0] * kernel_1[1];
5181
5182                // CumulativeGroupNorm for each conv block (weight only, no bias by default)
5183                let norm_0 = out_ch_0; // norm weight for conv_0
5184                let norm_1 = out_ch_1; // norm weight for conv_1
5185
5186                // input_proj_linear (Arc<dyn QuantMethod> - IS quantizable)
5187                let mut f_out = audio_cfg.input_feat_size;
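                // Standard conv output-size formula along the feature axis:
                // f_out = (f_in + pad_left + pad_right - kernel) / stride + 1 (integer division).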
5188                for i in 0..2 {
5189                    let kernel_w = audio_cfg.sscp_conv_kernel_size[i][1];
5190                    let stride_w = audio_cfg.sscp_conv_stride_size[i][1];
5191                    let pad_left = 1;
5192                    let pad_right = 1;
5193                    f_out = (f_out + pad_left + pad_right + stride_w - kernel_w) / stride_w;
5194                }
5195                let input_proj_in_features = out_ch_1 * f_out;
5196                let input_proj_linear =
5197                    input_proj_in_features * audio_cfg.hidden_size / weight_pack_factor;
5198
5199                conv_elems + norm_0 + norm_1 + input_proj_linear
5200            };
5201
5202            // Conformer blocks
5203            let conformer_elems = {
5204                let mut total = 0;
5205
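                // Each conformer block contributes: attention (with relative position embedding),
                // two feed-forward layers (ffw_layer_start / ffw_layer_end), a lightweight conv
                // module, and a final block norm.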
5206                for _ in 0..audio_cfg.conf_num_hidden_layers {
5207                    // ConformerAttention
5208                    let attention_elems = {
5209                        // Norms (NOT quantizable)
5210                        let pre_attn_norm = audio_cfg.hidden_size;
5211                        let post_norm = audio_cfg.hidden_size;
5212
5213                        // Attention projections (Arc<dyn QuantMethod> - IS quantizable)
5214                        let q_proj =
5215                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5216                        let k_proj =
5217                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5218                        let v_proj =
5219                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5220                        let post =
5221                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5222
5223                        // RelativePositionEmbedding
5224                        let pos_proj =
5225                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5226                        let per_dim_scale =
5227                            audio_cfg.hidden_size / audio_cfg.conf_num_attention_heads; // head_dim
5228                        let inv_timescales = audio_cfg.hidden_size / 2; // num_timescales
5229                        let pos_indices = audio_cfg.conf_attention_context_left
5230                            + audio_cfg.conf_attention_context_right
5231                            + 1;
5232
5233                        // Local causal masks (precomputed tensors)
5234                        let chunk_size = audio_cfg.conf_attention_chunk_size;
5235                        let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
5236                            + audio_cfg.conf_attention_context_right;
5237                        let local_causal_valid_mask = chunk_size * context_size; // U8 tensor
5238                        let invalid_logits_tensor = 1; // single f32 value
5239
5240                        pre_attn_norm
5241                            + post_norm
5242                            + q_proj
5243                            + k_proj
5244                            + v_proj
5245                            + post
5246                            + pos_proj
5247                            + per_dim_scale
5248                            + inv_timescales
5249                            + pos_indices
5250                            + local_causal_valid_mask
5251                            + invalid_logits_tensor
5252                    };
5253
5254                    // ConformerFeedForward (start and end)
5255                    let ffw_elems = {
5256                        // Each FFW has:
5257                        // - pre_layer_norm (NOT quantizable)
5258                        // - ffw_layer_1 (Arc<dyn QuantMethod> - IS quantizable)
5259                        // - ffw_layer_2 (Arc<dyn QuantMethod> - IS quantizable)
5260                        // - post_layer_norm (NOT quantizable)
5261                        let intermediate_size = audio_cfg.hidden_size * 4;
5262
5263                        let ffw_start = {
5264                            let pre_norm = audio_cfg.hidden_size;
5265                            let layer_1 =
5266                                audio_cfg.hidden_size * intermediate_size / weight_pack_factor;
5267                            let layer_2 =
5268                                intermediate_size * audio_cfg.hidden_size / weight_pack_factor;
5269                            let post_norm = audio_cfg.hidden_size;
5270                            pre_norm + layer_1 + layer_2 + post_norm
5271                        };
5272
5273                        let ffw_end = ffw_start; // Same structure
5274
5275                        ffw_start + ffw_end
5276                    };
5277
5278                    // ConformerLightConv1d
5279                    let lconv1d_elems = {
5280                        // Norms (NOT quantizable)
5281                        let pre_layer_norm = audio_cfg.hidden_size;
5282                        let conv_norm = audio_cfg.hidden_size;
5283
5284                        // Linear layers (Arc<dyn QuantMethod> - IS quantizable)
5285                        let linear_start = audio_cfg.hidden_size * (audio_cfg.hidden_size * 2)
5286                            / weight_pack_factor;
5287                        let linear_end =
5288                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5289
5290                        // depthwise_conv1d (Conv1d - NOT quantizable)
5291                        let depthwise = audio_cfg.hidden_size * audio_cfg.conf_conv_kernel_size;
5292
5293                        pre_layer_norm + conv_norm + linear_start + linear_end + depthwise
5294                    };
5295
5296                    // Final norm for conformer block (NOT quantizable)
5297                    let block_norm = audio_cfg.hidden_size;
5298
5299                    total += attention_elems + ffw_elems + lconv1d_elems + block_norm;
5300                }
5301
5302                total
5303            };
5304
5305            // Audio multimodal embedder (embed_audio)
5306            let embed_audio_elems = {
5307                // Embedding layer (ScaledEmbedding - NOT quantizable)
5308                let embedding = audio_cfg.vocab_size * audio_cfg.hidden_size;
5309
5310                // RMS norms (NOT quantizable)
5311                let hard_embedding_norm = audio_cfg.hidden_size; // with scale
5312                let soft_embedding_norm = audio_cfg.hidden_size; // with scale
5313                let embedding_post_projection_norm = text_cfg.hidden_size; // without scale
5314
5315                // Projection (Arc<dyn QuantMethod> - IS quantizable)
5316                let embedding_projection =
5317                    audio_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5318
5319                embedding
5320                    + hard_embedding_norm
5321                    + soft_embedding_norm
5322                    + embedding_post_projection_norm
5323                    + embedding_projection
5324            };
5325
5326            subsample_conv_projection_elems + conformer_elems + embed_audio_elems
5327        };
5328
5329        let vision_dtype = if dtype == DType::F16 {
5330            // f16 -> f32 for vision model in particular.
5331            DType::F32
5332        } else {
5333            dtype
5334        };
5335
5336        let total_size_in_bytes = text_elems * dtype.size_in_bytes()
5337            + vision_elems * vision_dtype.size_in_bytes()
5338            + audio_elems * dtype.size_in_bytes();
5339
5340        Ok(total_size_in_bytes)
5341    }
5342
5343    fn layer_sizes_in_bytes(
5344        &self,
5345        config: &str,
5346        dtype: DType,
5347        weight_pack_factor: usize,
5348        matformer_config: Option<&MatformerSliceConfig>,
5349    ) -> Result<Vec<usize>> {
5350        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5351
5352        // Apply matformer slicing if configured
5353        let (text_cfg, _layer_rename_map, _layers_skipped) = if let Some(matformer_cfg) =
5354            matformer_config
5355        {
5356            use crate::device_map::DummyDeviceMapper;
5357            use crate::vision_models::gemma3n::text::handle_matformer_slicing;
5358
5359            let dummy_mapper = DummyDeviceMapper {
5360                nm_device: Device::Cpu,
5361            };
5362            let (adjusted_cfg, _, _, layer_rename_map, layers_skipped) = handle_matformer_slicing(
5363                &cfg.text_config,
5364                &Some(matformer_cfg.clone()),
5365                &dummy_mapper,
5366            )?;
5367            (adjusted_cfg, layer_rename_map, layers_skipped)
5368        } else {
5369            (cfg.text_config.clone(), None, None)
5370        };
5371
5372        let text_cfg = &text_cfg;
5373
5374        // When matformer slicing is applied, we only include the layers that are kept
5375        let mut layer_sizes = Vec::new();
5376
5377        // Note: We don't need orig_intermediate_sizes anymore since the adjusted config
5378        // already has the correct intermediate sizes after matformer slicing
5379
5380        for layer_idx in 0..text_cfg.num_hidden_layers {
5381            let per_layer_elems = {
5382                // Layer norms
5383                let input_layernorm = text_cfg.hidden_size;
5384                let post_attention_layernorm = text_cfg.hidden_size;
5385                let pre_feedforward_layernorm = text_cfg.hidden_size;
5386                let post_feedforward_layernorm = text_cfg.hidden_size;
5387                let post_per_layer_input_norm = text_cfg.hidden_size;
5388
5389                // Attention components
5390                let size_in = text_cfg.hidden_size;
5391                let size_q = text_cfg.num_attention_heads * text_cfg.head_dim;
5392                let size_kv = text_cfg.num_key_value_heads * text_cfg.head_dim;
5393
5394                let q_proj = size_in * size_q / weight_pack_factor;
5395                let k_proj = size_in * size_kv / weight_pack_factor;
5396                let v_proj = size_in * size_kv / weight_pack_factor;
5397                let o_proj = size_q * size_in / weight_pack_factor;
5398
5399                // Q, K, V norms
5400                let q_norm = text_cfg.head_dim;
5401                let k_norm = text_cfg.head_dim;
5402                let v_norm = text_cfg.head_dim; // No bias for v_norm
5403
5404                // MLP components - use the adjusted intermediate sizes from matformer
5405                let intermediate_size = match &text_cfg.intermediate_size {
5406                    IntermediateSize::Single(size) => *size,
5407                    IntermediateSize::PerLayer(sizes) => sizes[layer_idx],
5408                    IntermediateSize::Matformer(sizes, _) => sizes[layer_idx],
5409                };
5410                let gate_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
5411                let up_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
5412                let down_proj = intermediate_size * text_cfg.hidden_size / weight_pack_factor;
5413
5414                // AltUp components (per layer)
5415                let altup_elems = {
5416                    let correct_output_scale = text_cfg.hidden_size;
5417                    let correction_coefs = text_cfg.altup_num_inputs * text_cfg.altup_num_inputs;
5418                    let prediction_coefs =
5419                        text_cfg.altup_num_inputs * text_cfg.altup_num_inputs.pow(2);
5420                    let modality_router = text_cfg.hidden_size * text_cfg.altup_num_inputs;
5421                    let router_norm = text_cfg.hidden_size;
5422
5423                    correct_output_scale
5424                        + correction_coefs
5425                        + prediction_coefs
5426                        + modality_router
5427                        + router_norm
5428                };
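                // Illustrative arithmetic (example values, not from any specific
                // Gemma 3n config): with hidden_size = 2048 and altup_num_inputs = 4,
                // altup_elems = 2048 + 16 + 64 + 8192 + 2048 = 12_368 elements.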
5429
5430                // Laurel block components
5431                let laurel_elems = {
5432                    let left = text_cfg.hidden_size * text_cfg.laurel_rank;
5433                    let right = text_cfg.laurel_rank * text_cfg.hidden_size;
5434                    let post_norm = text_cfg.hidden_size;
5435
5436                    left + right + post_norm
5437                };
5438
5439                // Per-layer input components
5440                let per_layer_input_gate =
5441                    text_cfg.hidden_size * text_cfg.hidden_size_per_layer_input;
5442                let per_layer_projection =
5443                    text_cfg.hidden_size_per_layer_input * text_cfg.hidden_size;
5444
5445                input_layernorm
5446                    + post_attention_layernorm
5447                    + pre_feedforward_layernorm
5448                    + post_feedforward_layernorm
5449                    + post_per_layer_input_norm
5450                    + q_proj
5451                    + k_proj
5452                    + v_proj
5453                    + o_proj
5454                    + q_norm
5455                    + k_norm
5456                    + v_norm
5457                    + gate_proj
5458                    + up_proj
5459                    + down_proj
5460                    + altup_elems
5461                    + laurel_elems
5462                    + per_layer_input_gate
5463                    + per_layer_projection
5464            };
5465
5466            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
5467        }
5468
5469        Ok(layer_sizes)
5470    }
5471
5472    fn num_layers(&self, config: &str) -> Result<usize> {
5473        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5474        Ok(cfg.text_config.num_hidden_layers)
5475    }
5476
5477    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
5478        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5479        let cfg = cfg.text_config;
5480
5481        let cfg = ModelConfigMetadata {
5482            max_seq_len: cfg.max_position_embeddings,
5483            num_layers: cfg.num_hidden_layers,
5484            hidden_size: cfg.hidden_size,
5485            num_kv_heads: cfg.num_key_value_heads,
5486            num_attn_heads: cfg.num_attention_heads,
5487            sliding_window: None, // None to be more forgiving; some configs do not use a sliding window
5488            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5489            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5490        };
5491
5492        Ok(Box::new(cfg))
5493    }
5494
5495    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
5496        Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
5497    }
5498}
5499
5500// ======================== Qwen3VL Loader
5501
5502/// [`VisionLoader`] for a Qwen3VL model.
5503///
5504/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
5505pub struct Qwen3VLLoader;
5506
5507pub struct Qwen3VLPrefixer;
5508
5509impl MultimodalPromptPrefixer for Qwen3VLPrefixer {
5510    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
5511        format!(
5512            "{}{prompt}",
5513            format!(
5514                "{}{}{}",
5515                Qwen3VLProcessor::VISION_START,
5516                Qwen3VLProcessor::IMAGE_PAD,
5517                Qwen3VLProcessor::VISION_END
5518            )
5519            .repeat(image_indexes.len())
5520        )
5521    }
5522}
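
// A minimal illustrative check of the prefixer above; a sketch relying only on items
// already referenced in this module (Qwen3VLPrefixer, Qwen3VLProcessor token constants,
// and the MultimodalPromptPrefixer trait). Two images should expand to two
// VISION_START/IMAGE_PAD/VISION_END groups placed before the prompt.
#[cfg(test)]
mod qwen3vl_prefixer_example {
    use super::*;

    #[test]
    fn prefix_repeats_per_image() {
        let prefixer = Qwen3VLPrefixer;
        let out = prefixer.prefix_image(vec![0, 1], "Describe both images.");
        let group = format!(
            "{}{}{}",
            Qwen3VLProcessor::VISION_START,
            Qwen3VLProcessor::IMAGE_PAD,
            Qwen3VLProcessor::VISION_END
        );
        assert_eq!(out, format!("{group}{group}Describe both images."));
    }
}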
5523
5524impl VisionModelLoader for Qwen3VLLoader {
5525    fn load(
5526        &self,
5527        config: &str,
5528        vb: ShardedVarBuilder,
5529        normal_loading_metadata: NormalLoadingMetadata,
5530        attention_mechanism: AttentionImplementation,
5531    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
5532        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5533        Ok(Box::new(Qwen3VLModel::new(
5534            &cfg,
5535            vb,
5536            self.is_gptx(config),
5537            normal_loading_metadata,
5538            attention_mechanism,
5539        )?))
5540    }
5541    fn is_gptx(&self, _config: &str) -> bool {
5542        true
5543    }
5544    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
5545        let config: Qwen3VLConfig = serde_json::from_str(config)?;
5546        Ok(Box::new(config))
5547    }
5548    fn get_processor(
5549        &self,
5550        _model_config: &str,
5551        _processor_config: Option<ProcessorConfig>,
5552        _preprocessor_config: PreProcessorConfig,
5553        max_edge: Option<u32>,
5554    ) -> Arc<dyn Processor + Send + Sync> {
5555        Arc::new(Qwen3VLProcessor::new(max_edge))
5556    }
5557    fn supports_paged_attention(&self, _config: &str) -> bool {
5558        true
5559    }
5560    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
5561        Arc::new(Qwen3VLPrefixer)
5562    }
5563    fn modalities(&self, _config: &str) -> Result<Modalities> {
5564        Ok(Modalities {
5565            input: vec![SupportedModality::Text, SupportedModality::Vision],
5566            output: vec![SupportedModality::Text],
5567        })
5568    }
5569}
5570
5571impl IsqModelLoader for Qwen3VLLoader {
5572    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
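        // Example tensor names these patterns are meant to match (illustrative):
        //   lm_head.weight
        //   model.language_model.layers.0.self_attn.q_proj.weight
        //   model.language_model.layers.0.mlp.down_proj.weight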
5573        Ok(vec![
5574            Regex::new(r"lm_head\.(weight|bias)$")?,
5575            // Attention
5576            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
5577            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
5578            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
5579            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
5580            // MLP
5581            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
5582            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
5583            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
5584        ])
5585    }
5586    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
5587        self.isq_layer_regexes(config)
5588    }
5589}
5590
5591impl DeviceMappedModelLoader for Qwen3VLLoader {
5592    fn mapped_max_act_size_elems(
5593        &self,
5594        config: &str,
5595        params: &AutoDeviceMapParams,
5596    ) -> Result<usize> {
5597        let AutoDeviceMapParams::Vision {
5598            max_seq_len,
5599            max_batch_size,
5600            max_image_shape,
5601            max_num_images,
5602        } = params
5603        else {
5604            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
5605        };
5606
5607        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5608
5609        // For images, grid_t=1. After spatial merging, grid_h and grid_w are reduced.
5610        let img_seq_len = {
5611            let cfg = &cfg.vision_config;
5612            // grid_t is 1 for images (temporal dimension is for video only)
5613            let grid_t = 1;
5614            // After patch embedding and spatial merge, the effective grid dimensions are reduced
5615            let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
5616            let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
5617            grid_t * grid_h * grid_w * max_num_images
5618        };
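        // Illustrative arithmetic (example values only): with patch_size = 16 and
        // spatial_merge_size = 2, a 1024x1024 image yields grid_h = grid_w =
        // (1024 / 16) / 2 = 32, i.e. 1024 merged tokens per image.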
5619
5620        let max_text_attn = {
5621            let cfg = &cfg.text_config;
5622            // This model injects the vision information directly into the input embeddings
5623            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
5624            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
5625        };
5626
5627        Ok(max_text_attn)
5628    }
5629
5630    fn non_mapped_max_act_size_elems(
5631        &self,
5632        config: &str,
5633        params: &AutoDeviceMapParams,
5634    ) -> Result<usize> {
5635        let AutoDeviceMapParams::Vision {
5636            max_seq_len: _,
5637            max_batch_size,
5638            max_image_shape,
5639            max_num_images,
5640        } = params
5641        else {
5642            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
5643        };
5644
5645        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5646
5647        // For the vision encoder, before spatial merging
5648        let img_seq_len = {
5649            let cfg = &cfg.vision_config;
5650            // grid_t is 1 for images
5651            let grid_t = 1;
5652            let grid_h = max_image_shape.0 / cfg.patch_size;
5653            let grid_w = max_image_shape.1 / cfg.patch_size;
5654            grid_t * grid_h * grid_w
5655        };
5656
5657        let max_vision_attn = {
5658            let cfg = &cfg.vision_config;
5659            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
5660        };
5661
5662        Ok(max_vision_attn)
5663    }
5664
5665    fn non_mapped_size_in_bytes(
5666        &self,
5667        config: &str,
5668        dtype: DType,
5669        weight_pack_factor: usize,
5670        _matformer_config: Option<&MatformerSliceConfig>,
5671    ) -> Result<usize> {
5672        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5673        let tie = cfg.tie_word_embeddings;
5674        let text_elems = {
5675            let cfg = &cfg.text_config;
5676            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
5677            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
5678            let lm_head = if !tie || weight_pack_factor != 1 {
5679                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
5680            } else {
5681                0
5682            };
5683            let norm = cfg.hidden_size;
5684            embed_tokens + lm_head + norm
5685        };
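        // Illustrative arithmetic (example values only): with vocab_size = 151_936,
        // hidden_size = 4096, tied embeddings, and weight_pack_factor = 1, embed_tokens
        // counts ~622M elements and lm_head contributes 0 because the weights are reused.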
5686
5687        let patch_merger = {
5688            let cfg = &cfg.vision_config;
5689            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
5690
5691            let mlp0 = hidden_size * hidden_size + hidden_size;
5692            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
5693
5694            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
5695
5696            mlp0 + mlp2 + ln_q
5697        };
5698
5699        let patch_embed = {
5700            let cfg = &cfg.vision_config;
5701            let conv_cfg = Conv3dConfig {
5702                stride: cfg.patch_size,
5703                ..Default::default()
5704            };
5705            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
5706            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
5707                * kernel_sizes[0]
5708                * kernel_sizes[1]
5709                * kernel_sizes[2]
5710        };
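        // Illustrative arithmetic (example values only): with in_chans = 3,
        // hidden_size = 1152, temporal_patch_size = 2, patch_size = 16, and groups = 1,
        // the Conv3d patch embedding has 3 * 1152 * 2 * 16 * 16 = 1_769_472 weight elements.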
5711
5712        let encoder_layer = {
5713            let cfg = &cfg.vision_config;
5714            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
5715            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
5716
5717            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
5718            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
5719            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
5720
5721            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
5722            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
5723
5724            norm1 + norm2 + fc1 + fc2 + qkv + out
5725        };
5726
5727        let elems =
5728            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
5729
5730        Ok(elems * dtype.size_in_bytes())
5731    }
5732
5733    fn layer_sizes_in_bytes(
5734        &self,
5735        config: &str,
5736        dtype: DType,
5737        weight_pack_factor: usize,
5738        _matformer_config: Option<&MatformerSliceConfig>,
5739    ) -> Result<Vec<usize>> {
5740        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5741        let per_layer_elems = {
5742            let cfg = &cfg.text_config;
5743            let input_layernorm = cfg.hidden_size;
5744            let post_attention_layernorm = cfg.hidden_size;
5745
5746            let size_in = cfg.hidden_size;
5747            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
5748            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
5749            let q_proj = size_in * size_q / weight_pack_factor + size_q;
5750            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
5751            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
5752            let o_proj = size_q * size_in / weight_pack_factor;
5753
5754            let h_size = cfg.hidden_size;
5755            let i_size = cfg.intermediate_size;
5756            let gate_proj = h_size * i_size / weight_pack_factor;
5757            let up_proj = h_size * i_size / weight_pack_factor;
5758            let down_proj = i_size * h_size / weight_pack_factor;
5759
5760            input_layernorm
5761                + post_attention_layernorm
5762                + q_proj
5763                + k_proj
5764                + v_proj
5765                + o_proj
5766                + gate_proj
5767                + up_proj
5768                + down_proj
5769        };
5770        Ok(vec![
5771            per_layer_elems * dtype.size_in_bytes();
5772            cfg.text_config.num_hidden_layers
5773        ])
5774    }
5775
5776    fn num_layers(&self, config: &str) -> Result<usize> {
5777        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5778        let cfg = &cfg.text_config;
5779        Ok(cfg.num_hidden_layers)
5780    }
5781
5782    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
5783        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5784        let cfg = &cfg.text_config;
5785
5786        let cfg = ModelConfigMetadata {
5787            max_seq_len: cfg.max_position_embeddings,
5788            num_layers: cfg.num_hidden_layers,
5789            hidden_size: cfg.hidden_size,
5790            num_kv_heads: cfg.num_key_value_heads,
5791            num_attn_heads: cfg.num_attention_heads,
5792            sliding_window: cfg.sliding_window,
5793            k_head_dim: cfg.head_dim,
5794            v_head_dim: cfg.head_dim,
5795        };
5796
5797        Ok(Box::new(cfg))
5798    }
5799
5800    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
5801        Some(vec![NonMappedSubModel::Vision])
5802    }
5803}
5804
5805// ======================== Qwen3VLMoE Loader
5806
5807/// [`VisionLoader`] for a Qwen3VLMoE model.
5808///
5809/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
5810pub struct Qwen3VLMoELoader;
5811
5812pub struct Qwen3VLMoEPrefixer;
5813
5814impl MultimodalPromptPrefixer for Qwen3VLMoEPrefixer {
5815    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
5816        format!(
5817            "{}{prompt}",
5818            format!(
5819                "{}{}{}",
5820                Qwen3VLMoEProcessor::VISION_START,
5821                Qwen3VLMoEProcessor::IMAGE_PAD,
5822                Qwen3VLMoEProcessor::VISION_END
5823            )
5824            .repeat(image_indexes.len())
5825        )
5826    }
5827}
5828
5829impl VisionModelLoader for Qwen3VLMoELoader {
5830    fn load(
5831        &self,
5832        config: &str,
5833        vb: ShardedVarBuilder,
5834        normal_loading_metadata: NormalLoadingMetadata,
5835        attention_mechanism: AttentionImplementation,
5836    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
5837        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
5838        Ok(Box::new(Qwen3VLMoEModel::new(
5839            &cfg,
5840            vb,
5841            self.is_gptx(config),
5842            normal_loading_metadata,
5843            attention_mechanism,
5844        )?))
5845    }
5846    fn is_gptx(&self, _config: &str) -> bool {
5847        true
5848    }
5849    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
5850        let config: Qwen3VLMoEConfig = serde_json::from_str(config)?;
5851        Ok(Box::new(config))
5852    }
5853    fn get_processor(
5854        &self,
5855        _model_config: &str,
5856        _processor_config: Option<ProcessorConfig>,
5857        _preprocessor_config: PreProcessorConfig,
5858        max_edge: Option<u32>,
5859    ) -> Arc<dyn Processor + Send + Sync> {
5860        Arc::new(Qwen3VLMoEProcessor::new(max_edge))
5861    }
5862    fn supports_paged_attention(&self, _config: &str) -> bool {
5863        true
5864    }
5865    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
5866        Arc::new(Qwen3VLMoEPrefixer)
5867    }
5868    fn modalities(&self, _config: &str) -> Result<Modalities> {
5869        Ok(Modalities {
5870            input: vec![SupportedModality::Text, SupportedModality::Vision],
5871            output: vec![SupportedModality::Text],
5872        })
5873    }
5874}
5875
5876impl IsqModelLoader for Qwen3VLMoELoader {
5877    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
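        // Example tensor names these patterns are meant to match (illustrative):
        //   model.language_model.layers.0.self_attn.q_proj.weight
        //   model.language_model.layers.3.mlp.gate.weight
        //   model.language_model.layers.3.mlp.experts.17.down_proj.weight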
5878        Ok(vec![
5879            Regex::new(r"lm_head\.(weight|bias)$")?,
5880            // Attention
5881            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
5882            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
5883            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
5884            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
5885            // MLP (dense layers)
5886            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
5887            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
5888            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
5889            // MoE router
5890            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate\.(weight|bias)$")?,
5891            // MoE experts, unpacked into individual expert tensors
5892            Regex::new(
5893                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
5894            )?,
5895            Regex::new(
5896                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$",
5897            )?,
5898            Regex::new(
5899                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$",
5900            )?,
5901        ])
5902    }
5903    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
5904        self.isq_layer_regexes(config)
5905    }
5906}
5907
5908impl DeviceMappedModelLoader for Qwen3VLMoELoader {
5909    fn mapped_max_act_size_elems(
5910        &self,
5911        config: &str,
5912        params: &AutoDeviceMapParams,
5913    ) -> Result<usize> {
5914        let AutoDeviceMapParams::Vision {
5915            max_seq_len,
5916            max_batch_size,
5917            max_image_shape,
5918            max_num_images,
5919        } = params
5920        else {
5921            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
5922        };
5923
5924        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
5925
5926        // For images, grid_t=1. After spatial merging, grid_h and grid_w are reduced.
5927        let img_seq_len = {
5928            let cfg = &cfg.vision_config;
5929            // grid_t is 1 for images (temporal dimension is for video only)
5930            let grid_t = 1;
5931            // After patch embedding and spatial merge, the effective grid dimensions are reduced
5932            let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
5933            let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
5934            grid_t * grid_h * grid_w * max_num_images
5935        };
5936
5937        let max_text_attn = {
5938            let cfg = &cfg.text_config;
5939            // This model injects the vision information directly into the input embeddings
5940            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
5941            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
5942        };
5943
5944        Ok(max_text_attn)
5945    }
5946
5947    fn non_mapped_max_act_size_elems(
5948        &self,
5949        config: &str,
5950        params: &AutoDeviceMapParams,
5951    ) -> Result<usize> {
5952        let AutoDeviceMapParams::Vision {
5953            max_seq_len: _,
5954            max_batch_size,
5955            max_image_shape,
5956            max_num_images,
5957        } = params
5958        else {
5959            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
5960        };
5961
5962        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
5963
5964        // For the vision encoder, before spatial merging
5965        let img_seq_len = {
5966            let cfg = &cfg.vision_config;
5967            // grid_t is 1 for images
5968            let grid_t = 1;
5969            let grid_h = max_image_shape.0 / cfg.patch_size;
5970            let grid_w = max_image_shape.1 / cfg.patch_size;
5971            grid_t * grid_h * grid_w
5972        };
5973
5974        let max_vision_attn = {
5975            let cfg = &cfg.vision_config;
5976            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
5977        };
5978
5979        Ok(max_vision_attn)
5980    }
5981
5982    fn non_mapped_size_in_bytes(
5983        &self,
5984        config: &str,
5985        dtype: DType,
5986        weight_pack_factor: usize,
5987        _matformer_config: Option<&MatformerSliceConfig>,
5988    ) -> Result<usize> {
5989        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
5990        let tie = cfg.tie_word_embeddings;
5991        let text_elems = {
5992            let cfg = &cfg.text_config;
5993            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
5994            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
5995            let lm_head = if !tie || weight_pack_factor != 1 {
5996                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
5997            } else {
5998                0
5999            };
6000            let norm = cfg.hidden_size;
6001            embed_tokens + lm_head + norm
6002        };
6003
6004        let patch_merger = {
6005            let cfg = &cfg.vision_config;
6006            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
6007
6008            let mlp0 = hidden_size * hidden_size + hidden_size;
6009            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
6010
6011            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6012
6013            mlp0 + mlp2 + ln_q
6014        };
6015
6016        let patch_embed = {
6017            let cfg = &cfg.vision_config;
6018            let conv_cfg = Conv3dConfig {
6019                stride: cfg.patch_size,
6020                ..Default::default()
6021            };
6022            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
6023            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
6024                * kernel_sizes[0]
6025                * kernel_sizes[1]
6026                * kernel_sizes[2]
6027        };
6028
6029        let encoder_layer = {
6030            let cfg = &cfg.vision_config;
6031            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6032            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6033
6034            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
6035            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
6036            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
6037
6038            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
6039            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
6040
6041            norm1 + norm2 + fc1 + fc2 + qkv + out
6042        };
6043
6044        let elems =
6045            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
6046
6047        Ok(elems * dtype.size_in_bytes())
6048    }
6049
6050    fn layer_sizes_in_bytes(
6051        &self,
6052        config: &str,
6053        dtype: DType,
6054        weight_pack_factor: usize,
6055        _matformer_config: Option<&MatformerSliceConfig>,
6056    ) -> Result<Vec<usize>> {
6057        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6058        let text_cfg = &cfg.text_config;
6059
6060        let mut layer_sizes = Vec::with_capacity(text_cfg.num_hidden_layers);
6061
6062        for layer_idx in 0..text_cfg.num_hidden_layers {
6063            let input_layernorm = text_cfg.hidden_size;
6064            let post_attention_layernorm = text_cfg.hidden_size;
6065
6066            let size_in = text_cfg.hidden_size;
6067            let size_q = (text_cfg.hidden_size / text_cfg.num_attention_heads)
6068                * text_cfg.num_attention_heads;
6069            let size_kv = (text_cfg.hidden_size / text_cfg.num_attention_heads)
6070                * text_cfg.num_key_value_heads;
6071            let q_proj = size_in * size_q / weight_pack_factor + size_q;
6072            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
6073            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
6074            let o_proj = size_q * size_in / weight_pack_factor;
6075
6076            // Check if this is a MoE layer
6077            let is_moe = !text_cfg.mlp_only_layers.contains(&layer_idx)
6078                && (text_cfg.num_experts > 0
6079                    && (layer_idx + 1) % text_cfg.decoder_sparse_step == 0);
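            // Illustrative example (not a specific config): with mlp_only_layers empty,
            // num_experts > 0, and decoder_sparse_step = 2, layers with layer_idx
            // 1, 3, 5, ... are MoE layers; with decoder_sparse_step = 1, every layer is.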
6080
6081            let mlp_elems = if is_moe {
6082                // MoE layer: gate + experts
6083                let gate = text_cfg.hidden_size * text_cfg.num_experts;
6084                let per_expert = {
6085                    let h_size = text_cfg.hidden_size;
6086                    let i_size = text_cfg.moe_intermediate_size;
6087                    let gate_proj = h_size * i_size / weight_pack_factor;
6088                    let up_proj = h_size * i_size / weight_pack_factor;
6089                    let down_proj = i_size * h_size / weight_pack_factor;
6090                    gate_proj + up_proj + down_proj
6091                };
6092                gate + per_expert * text_cfg.num_experts
6093            } else {
6094                // Dense MLP layer
6095                let h_size = text_cfg.hidden_size;
6096                let i_size = text_cfg.intermediate_size;
6097                let gate_proj = h_size * i_size / weight_pack_factor;
6098                let up_proj = h_size * i_size / weight_pack_factor;
6099                let down_proj = i_size * h_size / weight_pack_factor;
6100                gate_proj + up_proj + down_proj
6101            };
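            // Illustrative arithmetic (example values only): with hidden_size = 2048,
            // moe_intermediate_size = 768, num_experts = 128, and weight_pack_factor = 1:
            // gate = 262_144 and per_expert = 3 * 2048 * 768 = 4_718_592, so a MoE layer
            // counts roughly 604M elements before packing.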
6102
6103            let per_layer_elems = input_layernorm
6104                + post_attention_layernorm
6105                + q_proj
6106                + k_proj
6107                + v_proj
6108                + o_proj
6109                + mlp_elems;
6110
6111            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
6112        }
6113
6114        Ok(layer_sizes)
6115    }
6116
6117    fn num_layers(&self, config: &str) -> Result<usize> {
6118        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6119        let cfg = &cfg.text_config;
6120        Ok(cfg.num_hidden_layers)
6121    }
6122
6123    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
6124        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6125        let cfg = &cfg.text_config;
6126
6127        let cfg = ModelConfigMetadata {
6128            max_seq_len: cfg.max_position_embeddings,
6129            num_layers: cfg.num_hidden_layers,
6130            hidden_size: cfg.hidden_size,
6131            num_kv_heads: cfg.num_key_value_heads,
6132            num_attn_heads: cfg.num_attention_heads,
6133            sliding_window: cfg.sliding_window,
6134            k_head_dim: cfg.head_dim,
6135            v_head_dim: cfg.head_dim,
6136        };
6137
6138        Ok(Box::new(cfg))
6139    }
6140
6141    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
6142        Some(vec![NonMappedSubModel::Vision])
6143    }
6144}