mistralrs_core/pipeline/loaders/vision_loaders.rs

use std::any::Any;
use std::sync::Arc;
use std::{fmt::Debug, str::FromStr};

use anyhow::Result;
use candle_core::{DType, Device, Tensor, D};
use candle_nn::Conv2dConfig;
use image::{ColorType, DynamicImage};
use itertools::Itertools;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::ShardedVarBuilder;

#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};

use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
use crate::amoe::AnyMoeBaseModelMixin;
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::DeviceMapper;
use crate::layers::Conv3dConfig;
use crate::matformer::MatformerSliceConfig;
use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
use crate::pipeline::isq::IsqModelLoader;
use crate::pipeline::loaders::AutoDeviceMapParams;
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::{
    EitherCache, IsqModel, Modalities, MultimodalPromptPrefixer, Processor, ProcessorCreator,
    SupportedModality,
};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::vision_models::clip::ClipConfig;
use crate::vision_models::gemma3::config::Gemma3Config;
use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
use crate::vision_models::gemma3n::config::{Gemma3nConfig, IntermediateSize};
use crate::vision_models::gemma3n::{Gemma3nModel, Gemma3nProcessor};
use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
use crate::vision_models::idefics2_input_processor::Idefics2Processor;
use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
use crate::vision_models::image_processor::ImagePreProcessor;
use crate::vision_models::inputs_processor::Phi4MMProcessor;
use crate::vision_models::llama4::{
    self, Llama4Config, Llama4ImageProcessor, Llama4Model, Llama4Processor,
};
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::llava15::Model as LLaVA;
use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
use crate::vision_models::llava_next::Model as LLaVANext;
use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
use crate::vision_models::phi3_inputs_processor::Phi3Processor;
use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::qwen2_5_vl::{
    Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
};
use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
use crate::vision_models::{minicpmo, phi4};

pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin {
    // pixel_values and pixel_attention_mask only specified for prompt seqs
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>, // pixel attention mask, or image sizes, or anything else
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
    /// For a prompt without images. Requires batch size of 1!
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
}

pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> bool;
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_processor(
        &self,
        model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync>;
    fn supports_paged_attention(&self, config: &str) -> bool;
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        // Default is false, specific model must override.
        false
    }
    fn modalities(&self, config: &str) -> Result<Modalities>;
    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}
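
// A hedged sketch of what the default `get_device_for_tensor` closure returns for a
// couple of tensor names; the names below are illustrative, not taken from any
// particular checkpoint:
//
//     let pick = loader.get_device_for_tensor(config, mapper, /* loading_isq */ false)?;
//     // "model.layers.12.self_attn.q_proj.weight" matches `\.layers\.(\d+)\.`,
//     // so the closure returns DeviceForLoadTensor::Idx(12) (clamped to `num_layers`).
//     // "model.embed_tokens.weight" has no layer index, so it falls back to
//     // DeviceForLoadTensor::Base.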

#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
/// The architecture to load the vision model as.
pub enum VisionLoaderType {
    #[serde(rename = "phi3v")]
    Phi3V,
    #[serde(rename = "idefics2")]
    Idefics2,
    #[serde(rename = "llava_next")]
    LLaVANext,
    #[serde(rename = "llava")]
    LLaVA,
    #[serde(rename = "vllama")]
    VLlama,
    #[serde(rename = "qwen2vl")]
    Qwen2VL,
    #[serde(rename = "idefics3")]
    Idefics3,
    #[serde(rename = "minicpmo")]
    MiniCpmO,
    #[serde(rename = "phi4mm")]
    Phi4MM,
    #[serde(rename = "qwen2_5vl")]
    Qwen2_5VL,
    #[serde(rename = "gemma3")]
    Gemma3,
    #[serde(rename = "mistral3")]
    Mistral3,
    #[serde(rename = "llama4")]
    Llama4,
    #[serde(rename = "gemma3n")]
    Gemma3n,
}

// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
impl VisionLoaderType {
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "Phi3VForCausalLM" => Ok(Self::Phi3V),
            "Idefics2ForConditionalGeneration" => Ok(Self::Idefics2),
            "LlavaNextForConditionalGeneration" => Ok(Self::LLaVANext),
            "LlavaForConditionalGeneration" => Ok(Self::LLaVA),
            "MllamaForConditionalGeneration" => Ok(Self::VLlama),
            "Qwen2VLForConditionalGeneration" => Ok(Self::Qwen2VL),
            "Idefics3ForConditionalGeneration" => Ok(Self::Idefics3),
            "MiniCPMO" => Ok(Self::MiniCpmO),
            "Phi4MMForCausalLM" => Ok(Self::Phi4MM),
            "Qwen2_5_VLForConditionalGeneration" => Ok(Self::Qwen2_5VL),
            "Gemma3ForConditionalGeneration" | "Gemma3ForCausalLM" => Ok(Self::Gemma3),
            "Mistral3ForConditionalGeneration" => Ok(Self::Mistral3),
            "Llama4ForConditionalGeneration" => Ok(Self::Llama4),
            "Gemma3nForConditionalGeneration" => Ok(Self::Gemma3n),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for VisionLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "phi3v" => Ok(Self::Phi3V),
            "idefics2" => Ok(Self::Idefics2),
            "llava_next" => Ok(Self::LLaVANext),
            "llava" => Ok(Self::LLaVA),
            "vllama" => Ok(Self::VLlama),
            "qwen2vl" => Ok(Self::Qwen2VL),
            "idefics3" => Ok(Self::Idefics3),
            "minicpmo" => Ok(Self::MiniCpmO),
            "phi4mm" => Ok(Self::Phi4MM),
            "qwen2_5vl" => Ok(Self::Qwen2_5VL),
            "gemma3" => Ok(Self::Gemma3),
            "mistral3" => Ok(Self::Mistral3),
            "llama4" => Ok(Self::Llama4),
            "gemma3n" => Ok(Self::Gemma3n),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`, `llama4`, `gemma3n`.")),
        }
    }
}

impl std::fmt::Display for VisionLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            VisionLoaderType::Phi3V => "phi3v",
            VisionLoaderType::Idefics2 => "idefics2",
            VisionLoaderType::LLaVANext => "llava_next",
            VisionLoaderType::LLaVA => "llava",
            VisionLoaderType::VLlama => "vllama",
            VisionLoaderType::Qwen2VL => "qwen2vl",
            VisionLoaderType::Idefics3 => "idefics3",
            VisionLoaderType::MiniCpmO => "minicpmo",
            VisionLoaderType::Phi4MM => "phi4mm",
            VisionLoaderType::Qwen2_5VL => "qwen2_5vl",
            VisionLoaderType::Gemma3 => "gemma3",
            VisionLoaderType::Mistral3 => "mistral3",
            VisionLoaderType::Llama4 => "llama4",
            VisionLoaderType::Gemma3n => "gemma3n",
        };
        write!(f, "{name}")
    }
}
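
// Usage sketch (assumes only the items defined above): `FromStr` and `Display` are
// inverses over the serde names, so a loader type round-trips through its string form.
//
//     use std::str::FromStr;
//     let tp = VisionLoaderType::from_str("qwen2_5vl").unwrap();
//     assert_eq!(tp, VisionLoaderType::Qwen2_5VL);
//     assert_eq!(tp.to_string(), "qwen2_5vl");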

#[derive(Deserialize)]
struct AutoVisionLoaderConfig {
    architectures: Vec<String>,
}

/// Automatically selects a VisionModelLoader implementation based on the JSON `architectures` field.
pub struct AutoVisionLoader;

impl AutoVisionLoader {
    fn get_loader(config: &str) -> Result<Box<dyn VisionModelLoader>> {
        let auto_cfg: AutoVisionLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one architecture in config");
        }

        let name = &auto_cfg.architectures[0];
        let tp = VisionLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        // Delegate to the concrete loader
        Ok(match tp {
            VisionLoaderType::Phi3V => Box::new(Phi3VLoader),
            VisionLoaderType::Idefics2 => Box::new(Idefics2Loader),
            VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader),
            VisionLoaderType::LLaVA => Box::new(LLaVALoader),
            VisionLoaderType::VLlama => Box::new(VLlamaLoader),
            VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
            VisionLoaderType::Idefics3 => Box::new(Idefics3Loader),
            VisionLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
            VisionLoaderType::Phi4MM => Box::new(Phi4MMLoader),
            VisionLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
            VisionLoaderType::Gemma3 => Box::new(Gemma3Loader),
            VisionLoaderType::Mistral3 => Box::new(Mistral3Loader),
            VisionLoaderType::Llama4 => Box::new(VLlama4Loader),
            VisionLoaderType::Gemma3n => Box::new(Gemma3nLoader),
        })
    }
}
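
// Illustrative sketch of the selection step `get_loader` performs: the single entry of
// the JSON `architectures` field is mapped to a `VisionLoaderType` via
// `from_causal_lm_name`, and the matching concrete loader is boxed. The JSON snippet
// below is a hypothetical, minimal config fragment.
//
//     let cfg = r#"{"architectures": ["Gemma3ForConditionalGeneration"]}"#;
//     let auto: AutoVisionLoaderConfig = serde_json::from_str(cfg)?;
//     let tp = VisionLoaderType::from_causal_lm_name(&auto.architectures[0])?;
//     assert_eq!(tp, VisionLoaderType::Gemma3);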

impl VisionModelLoader for AutoVisionLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }

    fn is_gptx(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader get_loader")
            .is_gptx(config)
    }

    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }

    fn get_processor(
        &self,
        model_config: &str,
        proc_cfg: Option<ProcessorConfig>,
        preproc_cfg: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Self::get_loader(model_config)
            .expect("AutoVisionLoader get_loader")
            .get_processor(model_config, proc_cfg, preproc_cfg, max_edge)
    }

    fn supports_paged_attention(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_paged_attention(config)
    }

    fn modalities(&self, config: &str) -> Result<Modalities> {
        Self::get_loader(config)?.modalities(config)
    }

    fn supports_prefix_cacher(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_prefix_cacher(config)
    }

    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .prefixer(config)
    }

    fn get_device_for_tensor(
        &self,
        config: &str,
        mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        Self::get_loader(config)?.get_device_for_tensor(config, mapper, loading_isq)
    }
}

impl IsqModelLoader for AutoVisionLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
}

impl DeviceMappedModelLoader for AutoVisionLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}
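
// `bias_if!` counts a bias term only when the condition holds: `bias_if!(true, cfg.hidden_size)`
// contributes `cfg.hidden_size` elements, while `bias_if!(false, cfg.hidden_size)` contributes 0
// (see the layer-norm counts in the Idefics 2 sizing code below).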

fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
    let pre_layer_norm = cfg.hidden_size;
    let final_layer_norm = cfg.hidden_size;

    let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
    let num_positions = num_patches + 1;

    let class_embedding = cfg.hidden_size;

    let position_ids = num_positions;
    let position_embedding = num_positions * cfg.hidden_size;

    let conv2dconfig = Conv2dConfig {
        stride: cfg.patch_size,
        ..Default::default()
    };
    let patch_embedding =
        cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;

    let encoder_layer_elems = {
        let layer_norm1 = cfg.hidden_size;
        let layer_norm2 = cfg.hidden_size;

        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

        layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
    };

    pre_layer_norm
        + final_layer_norm
        + class_embedding
        + position_ids
        + position_embedding
        + patch_embedding
        + cfg.num_hidden_layers * encoder_layer_elems
}
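
// Rough accounting sketch (symbolic, not a claim about any specific checkpoint): with
// hidden size `h`, intermediate size `i`, `p = image_size / patch_size`, and
// `L = num_hidden_layers`, the count above expands to
//
//     2*h                                   // pre/final layer norms
//   + h                                     // class embedding
//   + (p^2 + 1) * (h + 1)                   // position embedding + position ids
//   + num_channels * h * patch_size^2       // patch embedding conv
//   + L * (2*h + 4*(h*h + h) + (h*i + i) + (i*h + h)),
//
// so the total is dominated by the per-layer attention projections and MLP.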

// ======================== Phi 3 loader

/// [`VisionLoader`] for a Phi 3 Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Phi3VLoader;

pub struct Phi3VPrefixer;

impl MultimodalPromptPrefixer for Phi3VPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        // Image indexing starts at 0.
        format!(
            "{}{prompt}",
            image_indexes
                .into_iter()
                .map(|image_index| format!("<|image_{}|>", image_index + 1))
                .join("")
        )
    }
}
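
// Illustrative example of the resulting prefix (the prompt text is hypothetical):
// `prefix_image(vec![0, 1], "Describe the images.")` yields
// "<|image_1|><|image_2|>Describe the images.", since `image_indexes` is 0-based but
// the rendered `<|image_N|>` tags are 1-based.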

impl VisionModelLoader for Phi3VLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(Phi3::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi3Processor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Phi3VPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Phi3VLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi3VLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        // NOTE: we ignore max_num_images although it can only be one...
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        // NOTE: we ignore max_num_images although it can only be one...
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = {
                let projection_cls = cfg
                    .embd_layer
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator =
                    cfg.embd_layer.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
                let image_dim_out = cfg.img_processor.image_dim_out;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);

                proj + glb_gn + sub_gn + clip_vit
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Idefics 2 loader

/// [`VisionLoader`] for an Idefics 2 Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Idefics2Loader;

pub struct Idefics2Prefixer;

impl MultimodalPromptPrefixer for Idefics2Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        // Chat template does it
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(Idefics2::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics2Processor::new(
            processor_config.unwrap(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Idefics2Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Idefics2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Idefics2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            // do_image_splitting = true
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let text_elems = {
            let tie_word_embeddings = cfg.tie_word_embeddings;
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let tcfg = &cfg.text_config;
            let vcfg = &cfg.vision_config;
            let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let down_proj = tcfg.intermediate_size * tcfg.hidden_size;

            let perceiver_elems = {
                let tcfg = &cfg.text_config;
                let pcfg = &cfg.perceiver_config;

                let n_latents = pcfg.resampler_n_latents;
                let hidden_size = tcfg.hidden_size;
                let depth = pcfg.resampler_depth;

                let norm = tcfg.hidden_size;
                let latents = n_latents * hidden_size;

                let layer_elems = {
                    let input_latents_norm = hidden_size;
                    let input_context_norm = hidden_size;
                    let post_attn_norm = hidden_size;

                    let num_heads = pcfg.resampler_n_heads;
                    let head_dim = pcfg.resampler_head_dim;
                    let num_key_value_heads = pcfg.num_key_value_heads;

                    let q_proj = hidden_size * num_heads * head_dim;
                    let k_proj = hidden_size * num_key_value_heads * head_dim;
                    let v_proj = hidden_size * num_key_value_heads * head_dim;
                    let o_proj = num_heads * head_dim * hidden_size;

                    let gate_proj = hidden_size * hidden_size * 4;
                    let up_proj = hidden_size * hidden_size * 4;
                    let down_proj = hidden_size * 4 * hidden_size;

                    input_latents_norm
                        + input_context_norm
                        + post_attn_norm
                        + q_proj
                        + k_proj
                        + v_proj
                        + o_proj
                        + gate_proj
                        + up_proj
                        + down_proj
                };

                norm + latents + layer_elems * depth
            };

            gate_proj + up_proj + down_proj + perceiver_elems
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm + patch_embedding + position_embedding + layer_elems
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== LLaVANext Loader

/// [`VisionLoader`] for an LLaVANext Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct LLaVANextLoader;

pub struct LLaVANextPrefixer;

impl MultimodalPromptPrefixer for LLaVANextPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVANextLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(LLaVANext::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVANextProcessor::new(model_config))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(LLaVANextPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for LLaVANextLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVANextLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}
1364
1365// ======================== LLaVA Loader
1366
1367/// [`VisionLoader`] for an LLaVA Vision model.
1368///
1369/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
1370pub struct LLaVALoader;
1371
1372pub struct LLaVAPrefixer;
1373
1374impl MultimodalPromptPrefixer for LLaVAPrefixer {
1375    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1376        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
1377    }
1378}
1379
1380impl VisionModelLoader for LLaVALoader {
1381    fn load(
1382        &self,
1383        config: &str,
1384        vb: ShardedVarBuilder,
1385        normal_loading_metadata: NormalLoadingMetadata,
1386        attention_mechanism: AttentionImplementation,
1387    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1388        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1389        Ok(Box::new(LLaVA::new(
1390            &cfg,
1391            vb,
1392            self.is_gptx(config),
1393            normal_loading_metadata,
1394            attention_mechanism,
1395        )?))
1396    }
1397    fn is_gptx(&self, _config: &str) -> bool {
1398        false
1399    }
1400    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1401        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1402        Ok(Box::new(cfg))
1403    }
1404    fn get_processor(
1405        &self,
1406        model_config: &str,
1407        _processor_config: Option<ProcessorConfig>,
1408        _preprocessor_config: PreProcessorConfig,
1409        _max_edge: Option<u32>,
1410    ) -> Arc<dyn Processor + Send + Sync> {
1411        Arc::new(LLaVAProcessor::new(model_config))
1412    }
1413    fn supports_paged_attention(&self, _config: &str) -> bool {
1414        true
1415    }
1416    fn supports_prefix_cacher(&self, _config: &str) -> bool {
1417        true
1418    }
1419    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1420        Arc::new(LLaVAPrefixer)
1421    }
1422    fn modalities(&self, _config: &str) -> Result<Modalities> {
1423        Ok(Modalities {
1424            input: vec![SupportedModality::Text, SupportedModality::Vision],
1425            output: vec![SupportedModality::Text],
1426        })
1427    }
1428}
1429
1430impl IsqModelLoader for LLaVALoader {
1431    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1432        Ok(vec![
1433            Regex::new(r"lm_head\.(weight|bias)$")?,
1434            // Attention
1435            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1436            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1437            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1438            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1439            // MLP
1440            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1441            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1442            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1443        ])
1444    }
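    // For example, `language_model.model.layers.10.self_attn.q_proj.weight` is matched by the
    // q_proj pattern above (the patterns are not anchored at the start of the tensor name),
    // and `lm_head.weight` is matched by the first pattern, so both become ISQ candidates.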
1445    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1446        Ok(vec![
1447            Regex::new(r"lm_head\.(weight|bias)$")?,
1448            // Attention
1449            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1450            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1451            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1452            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1453            // MLP
1454            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1455            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1456            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1457        ])
1458    }
1459}
1460
1461impl DeviceMappedModelLoader for LLaVALoader {
1462    fn mapped_max_act_size_elems(
1463        &self,
1464        config: &str,
1465        params: &AutoDeviceMapParams,
1466    ) -> Result<usize> {
1467        let AutoDeviceMapParams::Vision {
1468            max_seq_len,
1469            max_batch_size,
1470            max_image_shape: _,
1471            max_num_images,
1472        } = params
1473        else {
1474            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1475        };
1476
1477        let config: LLaVAConfig = serde_json::from_str(config)?;
1478
1479        let img_seq_len =
1480            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1481        let img_seq_len = img_seq_len * max_num_images;
1482
1483        let max_text_attn = {
1484            let cfg = &config.text_config;
1485            // This model injects the vision information directly into the input embeddings
1486            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
1487
1488            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1489        };
1490
1491        Ok(max_text_attn)
1492    }
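    // The dominant mapped activation is the attention-score tensor of shape
    // [batch, num_heads, seq_len, seq_len], where seq_len here already includes the image
    // tokens: img_seq_len + min(max_seq_len, ATTENTION_CHUNK_SIZE).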
1493
1494    fn non_mapped_max_act_size_elems(
1495        &self,
1496        config: &str,
1497        params: &AutoDeviceMapParams,
1498    ) -> Result<usize> {
1499        let AutoDeviceMapParams::Vision {
1500            max_seq_len: _,
1501            max_batch_size,
1502            max_image_shape: _,
1503            max_num_images,
1504        } = params
1505        else {
1506            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1507        };
1508
1509        let config: LLaVAConfig = serde_json::from_str(config)?;
1510
1511        let img_seq_len =
1512            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1513
1514        let max_vision_attn = {
1515            (max_batch_size * max_num_images)
1516                * config.vision_config.num_attention_heads
1517                * img_seq_len
1518                * img_seq_len
1519        };
1520
1521        Ok(max_vision_attn)
1522    }
1523
1524    fn non_mapped_size_in_bytes(
1525        &self,
1526        config: &str,
1527        dtype: DType,
1528        weight_pack_factor: usize,
1529        _matformer_config: Option<&MatformerSliceConfig>,
1530    ) -> Result<usize> {
1531        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1532        let text_elems = {
1533            let cfg = &cfg.text_config;
1534            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1535            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1536            let norm = cfg.hidden_size;
1537            embed_tokens + lm_head + norm
1538        };
1539
1540        let image_newline = cfg.text_config.hidden_size;
1541        let mmproj = {
1542            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
1543                + cfg.text_config.hidden_size;
1544            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
1545                + cfg.text_config.hidden_size;
1546
1547            linear_1 + linear_2
1548        };
1549        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());
1550
1551        let elems = text_elems + image_newline + mmproj + vision_tower;
1552        Ok(elems * dtype.size_in_bytes())
1553    }
1554
1555    fn layer_sizes_in_bytes(
1556        &self,
1557        config: &str,
1558        dtype: DType,
1559        weight_pack_factor: usize,
1560        _matformer_config: Option<&MatformerSliceConfig>,
1561    ) -> Result<Vec<usize>> {
1562        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1563        let per_layer_elems = {
1564            let cfg = &cfg.text_config;
1565            let input_layernorm = cfg.hidden_size;
1566            let post_attention_layernorm = cfg.hidden_size;
1567
1568            let size_in = cfg.hidden_size;
1569            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1570            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1571            let q_proj = size_in * size_q / weight_pack_factor;
1572            let k_proj = size_in * size_kv / weight_pack_factor;
1573            let v_proj = size_in * size_kv / weight_pack_factor;
1574            let o_proj = size_q * size_in / weight_pack_factor;
1575
1576            let h_size = cfg.hidden_size;
1577            let i_size = cfg.intermediate_size;
1578            let gate_proj = h_size * i_size / weight_pack_factor;
1579            let up_proj = h_size * i_size / weight_pack_factor;
1580            let down_proj = i_size * h_size / weight_pack_factor;
1581
1582            input_layernorm
1583                + post_attention_layernorm
1584                + q_proj
1585                + k_proj
1586                + v_proj
1587                + o_proj
1588                + gate_proj
1589                + up_proj
1590                + down_proj
1591        };
1592        Ok(vec![
1593            per_layer_elems * dtype.size_in_bytes();
1594            cfg.text_config.num_hidden_layers
1595        ])
1596    }
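    // Rough sanity check with hypothetical Vicuna-7B-like text settings (hidden_size = 4096,
    // 32 attention heads, 32 KV heads, intermediate_size = 11008, weight_pack_factor = 1):
    // the attention projections give 4 * 4096 * 4096 = 67_108_864 elements, the MLP gives
    // 3 * 4096 * 11008 = 135_266_304, plus 2 * 4096 for the norms, i.e. ~202.4M elements
    // or roughly 386 MiB per layer at F16.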
1597
1598    fn num_layers(&self, config: &str) -> Result<usize> {
1599        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1600        Ok(cfg.text_config.num_hidden_layers)
1601    }
1602
1603    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1604        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1605        let cfg = &cfg.text_config;
1606
1607        let cfg = ModelConfigMetadata {
1608            max_seq_len: cfg.max_position_embeddings,
1609            num_layers: cfg.num_hidden_layers,
1610            hidden_size: cfg.hidden_size,
1611            num_kv_heads: cfg.num_key_value_heads,
1612            num_attn_heads: cfg.num_attention_heads,
1613            sliding_window: cfg.sliding_window,
1614            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1615            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1616        };
1617
1618        Ok(Box::new(cfg))
1619    }
1620
1621    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1622        Some(vec![NonMappedSubModel::Vision])
1623    }
1624}
1625
1626// ======================== MLlama Loader
1627
1628/// [`VisionLoader`] for a Llama Vision model.
1629///
1630/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
1631pub struct VLlamaLoader;
1632
1633pub struct VLlamaPrefixer;
1634
1635impl MultimodalPromptPrefixer for VLlamaPrefixer {
1636    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1637        format!("{}{prompt}", "<|image|>".repeat(image_indexes.len()))
1638    }
1639}
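// For two images this turns the prompt "Describe both" into "<|image|><|image|>Describe both".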
1640
1641impl VisionModelLoader for VLlamaLoader {
1642    fn load(
1643        &self,
1644        config: &str,
1645        vb: ShardedVarBuilder,
1646        normal_loading_metadata: NormalLoadingMetadata,
1647        attention_mechanism: AttentionImplementation,
1648    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1649        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1650        Ok(Box::new(MLlamaModel::new(
1651            &cfg,
1652            vb,
1653            self.is_gptx(config),
1654            normal_loading_metadata,
1655            attention_mechanism,
1656        )?))
1657    }
1658    fn is_gptx(&self, _config: &str) -> bool {
1659        true
1660    }
1661    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1662        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1663        Ok(Box::new(cfg))
1664    }
1665    fn get_processor(
1666        &self,
1667        _model_config: &str,
1668        _processor_config: Option<ProcessorConfig>,
1669        _preprocessor_config: PreProcessorConfig,
1670        _max_edge: Option<u32>,
1671    ) -> Arc<dyn Processor + Send + Sync> {
1672        Arc::new(MLlamaProcessor::new())
1673    }
1674    fn supports_paged_attention(&self, _config: &str) -> bool {
1675        false
1676    }
1677    fn supports_prefix_cacher(&self, _config: &str) -> bool {
1678        true
1679    }
1680    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1681        Arc::new(VLlamaPrefixer)
1682    }
1683    fn modalities(&self, _config: &str) -> Result<Modalities> {
1684        Ok(Modalities {
1685            input: vec![SupportedModality::Text, SupportedModality::Vision],
1686            output: vec![SupportedModality::Text],
1687        })
1688    }
1689}
1690
1691impl IsqModelLoader for VLlamaLoader {
1692    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
1693        let config: MLlamaConfig = serde_json::from_str(config)?;
1694        let cross_attn_layers = &config.text_config.cross_attention_layers;
1695        let transformer_layers =
1696            (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
1697        let mut text_regexes = Vec::new();
1698        for layer in transformer_layers {
1699            text_regexes.extend(vec![
1700                // Attention text
1701                Regex::new(&format!(
1702                    r"language_model.model.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
1703                ))?,
1704                Regex::new(&format!(
1705                    r"language_model.model.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
1706                ))?,
1707                Regex::new(&format!(
1708                    r"language_model.model.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
1709                ))?,
1710                Regex::new(&format!(
1711                    r"language_model.model.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
1712                ))?,
1713                // MLP text
1714                Regex::new(&format!(
1715                    r"language_model.model.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
1716                ))?,
1717                Regex::new(&format!(
1718                    r"language_model.model.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
1719                ))?,
1720                Regex::new(&format!(
1721                    r"language_model.model.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
1722                ))?,
1723            ]);
1724        }
1725        let vision_regexes = vec![
1726            // Vision attention (transformer)
1727            Regex::new(
1728                r"vision_model.transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1729            )?,
1730            Regex::new(
1731                r"vision_model.transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1732            )?,
1733            Regex::new(
1734                r"vision_model.transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1735            )?,
1736            Regex::new(
1737                r"vision_model.transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1738            )?,
1739            // Vision attention (global transformer)
1740            Regex::new(
1741                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1742            )?,
1743            Regex::new(
1744                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1745            )?,
1746            Regex::new(
1747                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1748            )?,
1749            Regex::new(
1750                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1751            )?,
1752            // MLP vision
1753            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1754            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1755        ];
1756
1757        Ok([text_regexes, vision_regexes].concat())
1758    }
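    // Cross-attention layers are intentionally excluded above, so only the self-attention
    // decoder layers receive ISQ. For a hypothetical Llama-3.2-11B-Vision-style config with
    // 40 text layers and cross_attention_layers = [3, 8, 13, 18, 23, 28, 33, 38], this covers
    // the remaining 32 layers.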
1759    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1760        self.isq_layer_regexes(config)
1761    }
1762}
1763
1764impl DeviceMappedModelLoader for VLlamaLoader {
1765    fn mapped_max_act_size_elems(
1766        &self,
1767        config: &str,
1768        params: &AutoDeviceMapParams,
1769    ) -> Result<usize> {
1770        let AutoDeviceMapParams::Vision {
1771            max_seq_len,
1772            max_batch_size,
1773            max_image_shape: _,
1774            max_num_images,
1775        } = params
1776        else {
1777            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1778        };
1779
1780        let config: MLlamaConfig = serde_json::from_str(config)?;
1781
1782        let img_seq_len = {
1783            let cfg = &config.vision_config;
1784            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1785            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1786            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1787        };
1788        let img_seq_len = img_seq_len * max_num_images;
1789
1790        let max_cross_text_attn = {
1791            let cfg = &config.text_config;
1792            max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
1793        };
1794
1795        let max_self_text_attn = {
1796            let cfg = &config.text_config;
1797            max_batch_size * cfg.num_attention_heads * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)
1798        };
1799
1800        Ok(max_self_text_attn.max(max_cross_text_attn))
1801    }
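    // Worked example with hypothetical vision settings (image_size = 560, patch_size = 14,
    // max_num_tiles = 4): num_patches = 40^2 + 1 = 1601, num_padding_patches = 7, and
    // img_seq_len = 4 * (1601 + 7) = 6432 per image before scaling by max_num_images.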
1802
1803    fn non_mapped_max_act_size_elems(
1804        &self,
1805        config: &str,
1806        params: &AutoDeviceMapParams,
1807    ) -> Result<usize> {
1808        let AutoDeviceMapParams::Vision {
1809            max_seq_len: _,
1810            max_batch_size,
1811            max_image_shape: _,
1812            max_num_images,
1813        } = params
1814        else {
1815            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1816        };
1817
1818        let config: MLlamaConfig = serde_json::from_str(config)?;
1819
1820        let img_seq_len = {
1821            let cfg = &config.vision_config;
1822            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1823            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1824            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1825        };
1826        let max_vision_attn = {
1827            let cfg = &config.vision_config;
1828            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
1829        };
1830
1831        Ok(max_vision_attn)
1832    }
1833
1834    fn non_mapped_size_in_bytes(
1835        &self,
1836        config: &str,
1837        dtype: DType,
1838        weight_pack_factor: usize,
1839        _matformer_config: Option<&MatformerSliceConfig>,
1840    ) -> Result<usize> {
1841        let config: MLlamaConfig = serde_json::from_str(config)?;
1842        let text_elems = {
1843            let cfg = &config.text_config;
1844            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1845            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1846            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1847                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1848            } else {
1849                0
1850            };
1851            let norm = cfg.hidden_size;
1852            embed_tokens + lm_head + norm
1853        };
1854
1855        let vision_elems = {
1856            let cfg = &config.vision_config;
1857
1858            let conv_cfg = Conv2dConfig {
1859                stride: cfg.patch_size,
1860                ..Default::default()
1861            };
1862            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
1863                * cfg.patch_size
1864                * cfg.patch_size;
1865
1866            let class_embedding = cfg.hidden_size;
1867
1868            let gated_positional_embedding = {
1869                let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1870                let embedding = num_patches * cfg.hidden_size;
1871                let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
1872                    * (cfg.max_num_tiles * num_patches * cfg.hidden_size);
1873
1874                embedding + tile_embedding
1875            };
1876
1877            let pre_tile_positional_embedding =
1878                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1879            let post_tile_positional_embedding =
1880                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1881
1882            let layernorm_pre = cfg.hidden_size;
1883            let layernorm_post = cfg.hidden_size;
1884
1885            let encoder_layer = {
1886                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1887                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
1888
1889                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1890                let q_proj =
1891                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1892                let k_proj =
1893                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1894                let v_proj =
1895                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1896                let o_proj =
1897                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1898
1899                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
1900                    + cfg.intermediate_size;
1901                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
1902                    + cfg.hidden_size;
1903
1904                input_layernorm
1905                    + post_attention_layernorm
1906                    + q_proj
1907                    + k_proj
1908                    + v_proj
1909                    + o_proj
1910                    + fc1
1911                    + fc2
1912            };
1913
1914            patch_embedding
1915                + class_embedding
1916                + gated_positional_embedding
1917                + pre_tile_positional_embedding
1918                + post_tile_positional_embedding
1919                + layernorm_pre
1920                + layernorm_post
1921                + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
1922        };
1923
1924        let elems = text_elems + vision_elems;
1925        Ok(elems * dtype.size_in_bytes())
1926    }
1927
1928    fn layer_sizes_in_bytes(
1929        &self,
1930        config: &str,
1931        dtype: DType,
1932        weight_pack_factor: usize,
1933        _matformer_config: Option<&MatformerSliceConfig>,
1934    ) -> Result<Vec<usize>> {
1935        let config: MLlamaConfig = serde_json::from_str(config)?;
1936        let cfg = &config.text_config;
1937
1938        let mut layer_sizes = Vec::new();
1939
1940        for i in 0..cfg.num_hidden_layers {
1941            let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
1942                // Cross-attention layers are not quantized with ISQ, so no weight packing applies
1943                1
1944            } else {
1945                weight_pack_factor
1946            };
1947
1948            let per_layer_elems = {
1949                let input_layernorm = cfg.hidden_size;
1950                let post_attention_layernorm = cfg.hidden_size;
1951
1952                let size_in = cfg.hidden_size;
1953                let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1954                let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1955                let q_proj = size_in * size_q / weight_pack_factor;
1956                let k_proj = size_in * size_kv / weight_pack_factor;
1957                let v_proj = size_in * size_kv / weight_pack_factor;
1958                let o_proj = size_q * size_in / weight_pack_factor;
1959
1960                let h_size = cfg.hidden_size;
1961                let i_size = cfg.intermediate_size;
1962                let gate_proj = h_size * i_size / weight_pack_factor;
1963                let up_proj = h_size * i_size / weight_pack_factor;
1964                let down_proj = i_size * h_size / weight_pack_factor;
1965
1966                input_layernorm
1967                    + post_attention_layernorm
1968                    + q_proj
1969                    + k_proj
1970                    + v_proj
1971                    + o_proj
1972                    + gate_proj
1973                    + up_proj
1974                    + down_proj
1975            };
1976
1977            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
1978        }
1979
1980        Ok(layer_sizes)
1981    }
1982
1983    fn num_layers(&self, config: &str) -> Result<usize> {
1984        let config: MLlamaConfig = serde_json::from_str(config)?;
1985        Ok(config.text_config.num_hidden_layers)
1986    }
1987
1988    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1989        let cfg: MLlamaConfig = serde_json::from_str(config)?;
1990        let cfg = &cfg.text_config;
1991
1992        let cfg = ModelConfigMetadata {
1993            max_seq_len: cfg.max_position_embeddings,
1994            num_layers: cfg.num_hidden_layers,
1995            hidden_size: cfg.hidden_size,
1996            num_kv_heads: cfg.num_key_value_heads,
1997            num_attn_heads: cfg.num_attention_heads,
1998            sliding_window: None,
1999            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2000            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2001        };
2002
2003        Ok(Box::new(cfg))
2004    }
2005
2006    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2007        Some(vec![NonMappedSubModel::Vision])
2008    }
2009}
2010
2011// ======================== Qwen2VL Loader
2012
2013/// [`VisionLoader`] for a Qwen2-VL model.
2014///
2015/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2016pub struct Qwen2VLLoader;
2017
2018pub struct Qwen2VLPrefixer;
2019
2020impl MultimodalPromptPrefixer for Qwen2VLPrefixer {
2021    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2022        format!(
2023            "{}{prompt}",
2024            format!(
2025                "{}{}{}",
2026                Qwen2VLProcessor::VISION_START,
2027                Qwen2VLProcessor::IMAGE_PAD,
2028                Qwen2VLProcessor::VISION_END
2029            )
2030            .repeat(image_indexes.len())
2031        )
2032    }
2033}
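// Assuming the usual Qwen2-VL special tokens (`<|vision_start|>`, `<|image_pad|>`,
// `<|vision_end|>`), one image yields the prefix "<|vision_start|><|image_pad|><|vision_end|>";
// the input processor later expands the pad token to match the number of vision tokens.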
2034
2035impl VisionModelLoader for Qwen2VLLoader {
2036    fn load(
2037        &self,
2038        config: &str,
2039        vb: ShardedVarBuilder,
2040        normal_loading_metadata: NormalLoadingMetadata,
2041        attention_mechanism: AttentionImplementation,
2042    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2043        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2044        Ok(Box::new(Qwen2VLModel::new(
2045            &cfg,
2046            vb,
2047            self.is_gptx(config),
2048            normal_loading_metadata,
2049            attention_mechanism,
2050        )?))
2051    }
2052    fn is_gptx(&self, _config: &str) -> bool {
2053        true
2054    }
2055    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2056        let config: Qwen2VLConfig = serde_json::from_str(config)?;
2057        Ok(Box::new(config))
2058    }
2059    fn get_processor(
2060        &self,
2061        _model_config: &str,
2062        _processor_config: Option<ProcessorConfig>,
2063        _preprocessor_config: PreProcessorConfig,
2064        max_edge: Option<u32>,
2065    ) -> Arc<dyn Processor + Send + Sync> {
2066        Arc::new(Qwen2VLProcessor::new(max_edge))
2067    }
2068    fn supports_paged_attention(&self, _config: &str) -> bool {
2069        false
2070    }
2071    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2072        Arc::new(Qwen2VLPrefixer)
2073    }
2074    fn modalities(&self, _config: &str) -> Result<Modalities> {
2075        Ok(Modalities {
2076            input: vec![SupportedModality::Text, SupportedModality::Vision],
2077            output: vec![SupportedModality::Text],
2078        })
2079    }
2080}
2081
2082impl IsqModelLoader for Qwen2VLLoader {
2083    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2084        Ok(vec![
2085            Regex::new(r"lm_head\.(weight|bias)$")?,
2086            // Attention
2087            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2088            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2089            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2090            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2091            // MLP
2092            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2093            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2094            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2095        ])
2096    }
2097    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2098        self.isq_layer_regexes(config)
2099    }
2100}
2101
2102impl DeviceMappedModelLoader for Qwen2VLLoader {
2103    fn mapped_max_act_size_elems(
2104        &self,
2105        config: &str,
2106        params: &AutoDeviceMapParams,
2107    ) -> Result<usize> {
2108        let AutoDeviceMapParams::Vision {
2109            max_seq_len,
2110            max_batch_size,
2111            max_image_shape,
2112            max_num_images,
2113        } = params
2114        else {
2115            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2116        };
2117
2118        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2119
2120        let img_seq_len = {
2121            let cfg = &cfg.vision_config;
2122            let grid_t = max_num_images / cfg.temporal_patch_size;
2123            let grid_h = max_image_shape.0 / cfg.patch_size;
2124            let grid_w = max_image_shape.1 / cfg.patch_size;
2125            grid_t * grid_h * grid_w
2126        };
2127        let img_seq_len = img_seq_len * max_num_images;
2128
2129        let max_text_attn = {
2130            // This model injects the vision information directly into the input embeddings
2131            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2132            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
2133        };
2134
2135        Ok(max_text_attn)
2136    }
2137
2138    fn non_mapped_max_act_size_elems(
2139        &self,
2140        config: &str,
2141        params: &AutoDeviceMapParams,
2142    ) -> Result<usize> {
2143        let AutoDeviceMapParams::Vision {
2144            max_seq_len: _,
2145            max_batch_size,
2146            max_image_shape,
2147            max_num_images,
2148        } = params
2149        else {
2150            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2151        };
2152
2153        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2154
2155        let img_seq_len = {
2156            let cfg = &cfg.vision_config;
2157            let grid_t = max_num_images / cfg.temporal_patch_size;
2158            let grid_h = max_image_shape.0 / cfg.patch_size;
2159            let grid_w = max_image_shape.1 / cfg.patch_size;
2160            grid_t * grid_h * grid_w
2161        };
2162
2163        let max_vision_attn = {
2164            let cfg = &cfg.vision_config;
2165            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
2166        };
2167
2168        Ok(max_vision_attn)
2169    }
2170
2171    fn non_mapped_size_in_bytes(
2172        &self,
2173        config: &str,
2174        dtype: DType,
2175        weight_pack_factor: usize,
2176        _matformer_config: Option<&MatformerSliceConfig>,
2177    ) -> Result<usize> {
2178        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2179        let text_elems = {
2180            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2181            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2182            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2183                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2184            } else {
2185                0
2186            };
2187            let norm = cfg.hidden_size;
2188            embed_tokens + lm_head + norm
2189        };
2190
2191        let patch_merger = {
2192            let cfg = &cfg.vision_config;
2193            let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);
2194
2195            let mlp0 = hidden_size * hidden_size + hidden_size;
2196            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
2197
2198            let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2199
2200            mlp0 + mlp2 + ln_q
2201        };
2202
2203        let patch_embed = {
2204            let cfg = &cfg.vision_config;
2205            let conv_cfg = Conv3dConfig {
2206                stride: cfg.patch_size,
2207                ..Default::default()
2208            };
2209            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
2210            cfg.in_channels * cfg.embed_dim / conv_cfg.groups
2211                * kernel_sizes[0]
2212                * kernel_sizes[1]
2213                * kernel_sizes[2]
2214        };
2215
2216        let encoder_layer = {
2217            let cfg = &cfg.vision_config;
2218            let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2219            let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2220
2221            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2222            let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
2223            let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
2224            let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;
2225
2226            let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
2227            let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;
2228
2229            norm1 + norm2 + fc1 + fc2 + qkv + out
2230        };
2231
2232        let elems =
2233            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
2234
2235        Ok(elems * dtype.size_in_bytes())
2236    }
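    // The patch embedding is a Conv3d over [temporal_patch_size, patch_size, patch_size]
    // blocks, so its weight holds in_channels * embed_dim * t * p * p elements (no bias is
    // counted here).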
2237
2238    fn layer_sizes_in_bytes(
2239        &self,
2240        config: &str,
2241        dtype: DType,
2242        weight_pack_factor: usize,
2243        _matformer_config: Option<&MatformerSliceConfig>,
2244    ) -> Result<Vec<usize>> {
2245        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2246        let per_layer_elems = {
2247            let input_layernorm = cfg.hidden_size;
2248            let post_attention_layernorm = cfg.hidden_size;
2249
2250            let size_in = cfg.hidden_size;
2251            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2252            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2253            let q_proj = size_in * size_q / weight_pack_factor + size_q;
2254            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
2255            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
2256            let o_proj = size_q * size_in / weight_pack_factor;
2257
2258            let h_size = cfg.hidden_size;
2259            let i_size = cfg.intermediate_size;
2260            let gate_proj = h_size * i_size / weight_pack_factor;
2261            let up_proj = h_size * i_size / weight_pack_factor;
2262            let down_proj = i_size * h_size / weight_pack_factor;
2263
2264            input_layernorm
2265                + post_attention_layernorm
2266                + q_proj
2267                + k_proj
2268                + v_proj
2269                + o_proj
2270                + gate_proj
2271                + up_proj
2272                + down_proj
2273        };
2274        Ok(vec![
2275            per_layer_elems * dtype.size_in_bytes();
2276            cfg.num_hidden_layers
2277        ])
2278    }
2279
2280    fn num_layers(&self, config: &str) -> Result<usize> {
2281        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2282        Ok(cfg.num_hidden_layers)
2283    }
2284
2285    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2286        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2287
2288        let cfg = ModelConfigMetadata {
2289            max_seq_len: cfg.max_position_embeddings,
2290            num_layers: cfg.num_hidden_layers,
2291            hidden_size: cfg.hidden_size,
2292            num_kv_heads: cfg.num_key_value_heads,
2293            num_attn_heads: cfg.num_attention_heads,
2294            sliding_window: cfg.sliding_window,
2295            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2296            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2297        };
2298
2299        Ok(Box::new(cfg))
2300    }
2301
2302    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2303        Some(vec![NonMappedSubModel::Vision])
2304    }
2305}
2306
2307// ======================== Idefics 3 loader
2308
2309/// [`VisionLoader`] for an Idefics 3 Vision model.
2310///
2311/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2312pub struct Idefics3Loader;
2313
2314pub struct Idefics3Prefixer;
2315
2316impl MultimodalPromptPrefixer for Idefics3Prefixer {
2317    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
2318        // The chat template inserts the image placeholders, so no prefix is needed here
2319        prompt.to_string()
2320    }
2321}
2322
2323impl VisionModelLoader for Idefics3Loader {
2324    fn load(
2325        &self,
2326        config: &str,
2327        vb: ShardedVarBuilder,
2328        normal_loading_metadata: NormalLoadingMetadata,
2329        attention_mechanism: AttentionImplementation,
2330    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2331        let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2332        Ok(Box::new(Idefics3Model::new(
2333            &cfg,
2334            vb,
2335            self.is_gptx(config),
2336            normal_loading_metadata,
2337            attention_mechanism,
2338        )?))
2339    }
2340    fn is_gptx(&self, _config: &str) -> bool {
2341        true
2342    }
2343    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2344        let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2345        Ok(Box::new(cfg))
2346    }
2347    fn get_processor(
2348        &self,
2349        _model_config: &str,
2350        processor_config: Option<ProcessorConfig>,
2351        preprocessor_config: PreProcessorConfig,
2352        max_edge: Option<u32>,
2353    ) -> Arc<dyn Processor + Send + Sync> {
2354        Arc::new(Idefics3Processor::new(
2355            processor_config.unwrap_or_default(),
2356            preprocessor_config,
2357            max_edge,
2358        ))
2359    }
2360    fn supports_paged_attention(&self, _config: &str) -> bool {
2361        true
2362    }
2363    fn supports_prefix_cacher(&self, _config: &str) -> bool {
2364        true
2365    }
2366    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2367        Arc::new(Idefics3Prefixer)
2368    }
2369    fn modalities(&self, _config: &str) -> Result<Modalities> {
2370        Ok(Modalities {
2371            input: vec![SupportedModality::Text, SupportedModality::Vision],
2372            output: vec![SupportedModality::Text],
2373        })
2374    }
2375}
2376
2377impl IsqModelLoader for Idefics3Loader {
2378    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2379        Ok(vec![
2380            Regex::new(r"lm_head\.(weight|bias)$")?,
2381            // Attention
2382            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2383            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2384            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2385            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2386            // MLP
2387            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2388            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2389            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2390        ])
2391    }
2392    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
2393        Ok(vec![
2394            Regex::new(r"lm_head\.(weight|bias)$")?,
2395            // Attention
2396            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2397            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2398            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2399            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2400            // MLP
2401            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2402            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2403            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2404            // // Attention (vision)
2405            // Regex::new(
2406            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2407            // )?,
2408            // Regex::new(
2409            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
2410            // )?,
2411            // Regex::new(
2412            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
2413            // )?,
2414            // Regex::new(
2415            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.(weight|bias)$",
2416            // )?,
2417            // // MLP (vision)
2418            // Regex::new(r"model\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2419            // Regex::new(r"model\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
2420        ])
2421    }
2422}
2423
2424impl DeviceMappedModelLoader for Idefics3Loader {
2425    fn mapped_max_act_size_elems(
2426        &self,
2427        config: &str,
2428        params: &AutoDeviceMapParams,
2429    ) -> Result<usize> {
2430        let AutoDeviceMapParams::Vision {
2431            max_seq_len,
2432            max_batch_size,
2433            max_image_shape: _,
2434            max_num_images,
2435        } = params
2436        else {
2437            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2438        };
2439
2440        let cfg: Idefics3Config = serde_json::from_str(config)?;
2441
2442        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2443        let img_seq_len = (num_patches + 1) * max_num_images;
2444
2445        let max_text_attn = {
2446            // This model injects the vision information directly into the input embeddings
2447            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2448            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2449        };
2450
2451        Ok(max_text_attn)
2452    }
2453
2454    fn non_mapped_max_act_size_elems(
2455        &self,
2456        config: &str,
2457        params: &AutoDeviceMapParams,
2458    ) -> Result<usize> {
2459        let AutoDeviceMapParams::Vision {
2460            max_seq_len: _,
2461            max_batch_size,
2462            max_image_shape: _,
2463            max_num_images,
2464        } = params
2465        else {
2466            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2467        };
2468
2469        let cfg: Idefics3Config = serde_json::from_str(config)?;
2470
2471        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2472        let img_seq_len = num_patches + 1;
2473
2474        let max_vision_attn = {
2475            // do_image_splitting = true
2476            let images_factor = 5;
2477
2478            (max_batch_size * images_factor * max_num_images)
2479                * cfg.vision_config.num_attention_heads
2480                * img_seq_len
2481                * img_seq_len
2482        };
2483
2484        Ok(max_vision_attn)
2485    }
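    // `images_factor = 5` is an allowance for image splitting: with do_image_splitting the
    // preprocessor can emit several sub-image crops plus a resized global image per input
    // image, so the vision tower may see roughly 5x as many frames.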
2486
2487    fn non_mapped_size_in_bytes(
2488        &self,
2489        config: &str,
2490        dtype: DType,
2491        weight_pack_factor: usize,
2492        _matformer_config: Option<&MatformerSliceConfig>,
2493    ) -> Result<usize> {
2494        let cfg: Idefics3Config = serde_json::from_str(config)?;
2495        let text_elems = {
2496            let cfg = &cfg.text_config;
2497
2498            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2499            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2500            let norm = cfg.hidden_size;
2501            embed_tokens + lm_head + norm
2502        };
2503
2504        let connector_elems = {
2505            let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
2506            let out_dim = cfg.text_config.hidden_size;
2507
2508            in_dim * out_dim
2509        };
2510
2511        let vision_transformer = {
2512            let cfg = &cfg.vision_config;
2513
2514            let post_layernorm = cfg.hidden_size;
2515
2516            let conv_config = Conv2dConfig {
2517                stride: cfg.patch_size,
2518                ..Default::default()
2519            };
2520            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2521                * cfg.patch_size
2522                * cfg.patch_size;
2523
2524            let num_patches_per_side = cfg.image_size / cfg.patch_size;
2525            let num_patches = num_patches_per_side.pow(2);
2526            let position_embedding = num_patches * cfg.hidden_size;
2527
2528            let layer_elems = {
2529                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2530                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2531
2532                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2533                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2534
2535                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2536                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2537                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2538                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2539
2540                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2541            };
2542
2543            post_layernorm
2544                + patch_embedding
2545                + position_embedding
2546                + layer_elems * cfg.num_hidden_layers
2547        };
2548
2549        let elems = text_elems + connector_elems + vision_transformer;
2550
2551        Ok(elems * dtype.size_in_bytes())
2552    }
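    // The connector elements above count a single linear projection from the pixel-shuffled
    // vision features (vision hidden_size * scale_factor^2) to the text hidden size; no bias
    // term is included.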
2553
2554    fn layer_sizes_in_bytes(
2555        &self,
2556        config: &str,
2557        dtype: DType,
2558        weight_pack_factor: usize,
2559        _matformer_config: Option<&MatformerSliceConfig>,
2560    ) -> Result<Vec<usize>> {
2561        let cfg: Idefics3Config = serde_json::from_str(config)?;
2562        let cfg = cfg.text_config;
2563        let per_layer_elems = {
2564            let input_layernorm = cfg.hidden_size;
2565            let post_attention_layernorm = cfg.hidden_size;
2566
2567            let size_in = cfg.hidden_size;
2568            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2569            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2570            let q_proj = size_in * size_q / weight_pack_factor;
2571            let k_proj = size_in * size_kv / weight_pack_factor;
2572            let v_proj = size_in * size_kv / weight_pack_factor;
2573            let o_proj = size_q * size_in / weight_pack_factor;
2574
2575            let h_size = cfg.hidden_size;
2576            let i_size = cfg.intermediate_size;
2577            let gate_proj = h_size * i_size / weight_pack_factor;
2578            let up_proj = h_size * i_size / weight_pack_factor;
2579            let down_proj = i_size * h_size / weight_pack_factor;
2580
2581            input_layernorm
2582                + post_attention_layernorm
2583                + q_proj
2584                + k_proj
2585                + v_proj
2586                + o_proj
2587                + gate_proj
2588                + up_proj
2589                + down_proj
2590        };
2591        Ok(vec![
2592            per_layer_elems * dtype.size_in_bytes();
2593            cfg.num_hidden_layers
2594        ])
2595    }
2596
2597    fn num_layers(&self, config: &str) -> Result<usize> {
2598        let cfg: Idefics3Config = serde_json::from_str(config)?;
2599        Ok(cfg.text_config.num_hidden_layers)
2600    }
2601    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2602        let cfg: Idefics3Config = serde_json::from_str(config)?;
2603        let cfg = &cfg.text_config;
2604
2605        let cfg = ModelConfigMetadata {
2606            max_seq_len: cfg.max_position_embeddings,
2607            num_layers: cfg.num_hidden_layers,
2608            hidden_size: cfg.hidden_size,
2609            num_kv_heads: cfg.num_key_value_heads,
2610            num_attn_heads: cfg.num_attention_heads,
2611            sliding_window: None,
2612            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2613            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2614        };
2615
2616        Ok(Box::new(cfg))
2617    }
2618
2619    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2620        Some(vec![NonMappedSubModel::Vision])
2621    }
2622}
2623
2624// ======================== MiniCpm-O loader
2625
2626/// [`VisionLoader`] for a MiniCpm-O model.
2627///
2628/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2629pub struct MiniCpmOLoader;
2630
2631pub struct MiniCpmOPrefixer;
2632
2633impl MultimodalPromptPrefixer for MiniCpmOPrefixer {
2634    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2635        format!(
2636            "{}{prompt}",
2637            "(<image>./</image>)".repeat(image_indexes.len())
2638        )
2639    }
2640}
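// For a single image this prepends the placeholder "(<image>./</image>)" to the prompt.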
2641
2642impl VisionModelLoader for MiniCpmOLoader {
2643    fn load(
2644        &self,
2645        config: &str,
2646        vb: ShardedVarBuilder,
2647        normal_loading_metadata: NormalLoadingMetadata,
2648        attention_mechanism: AttentionImplementation,
2649    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2650        let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2651        Ok(Box::new(MiniCpmOModel::new(
2652            &cfg,
2653            vb,
2654            self.is_gptx(config),
2655            normal_loading_metadata,
2656            attention_mechanism,
2657        )?))
2658    }
2659    fn is_gptx(&self, _config: &str) -> bool {
2660        true
2661    }
2662    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2663        let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2664        Ok(Box::new(cfg))
2665    }
2666    fn get_processor(
2667        &self,
2668        _model_config: &str,
2669        processor_config: Option<ProcessorConfig>,
2670        preprocessor_config: PreProcessorConfig,
2671        max_edge: Option<u32>,
2672    ) -> Arc<dyn Processor + Send + Sync> {
2673        Arc::new(MiniCpmOProcessor::new(
2674            processor_config.unwrap_or_default(),
2675            preprocessor_config,
2676            max_edge,
2677        ))
2678    }
2679    fn supports_paged_attention(&self, _config: &str) -> bool {
2680        true
2681    }
2682    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2683        Arc::new(MiniCpmOPrefixer)
2684    }
2685    fn modalities(&self, _config: &str) -> Result<Modalities> {
2686        Ok(Modalities {
2687            input: vec![SupportedModality::Text, SupportedModality::Vision],
2688            output: vec![SupportedModality::Text],
2689        })
2690    }
2691}
2692
2693impl IsqModelLoader for MiniCpmOLoader {
2694    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2695        Ok(vec![
2696            Regex::new(r"llm.lm_head\.(weight|bias)$")?,
2697            // Attention
2698            Regex::new(r"llm.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2699            Regex::new(r"llm.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2700            Regex::new(r"llm.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2701            Regex::new(r"llm.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2702            // MLP
2703            Regex::new(r"llm.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2704            Regex::new(r"llm.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2705            Regex::new(r"llm.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2706        ])
2707    }
2708    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2709        self.isq_layer_regexes(config)
2710    }
2711}
2712
2713impl DeviceMappedModelLoader for MiniCpmOLoader {
2714    fn mapped_max_act_size_elems(
2715        &self,
2716        config: &str,
2717        params: &AutoDeviceMapParams,
2718    ) -> Result<usize> {
2719        let AutoDeviceMapParams::Vision {
2720            max_seq_len,
2721            max_batch_size,
2722            max_image_shape: _,
2723            max_num_images,
2724        } = params
2725        else {
2726            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2727        };
2728
2729        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2730
2731        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2732        let img_seq_len = (num_patches + 1) * max_num_images;
2733
2734        let max_text_attn = {
2735            // This model injects the vision information directly into the input embeddings
2736            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2737            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2738        };
2739
2740        Ok(max_text_attn)
2741    }
2742
2743    fn non_mapped_max_act_size_elems(
2744        &self,
2745        config: &str,
2746        params: &AutoDeviceMapParams,
2747    ) -> Result<usize> {
2748        let AutoDeviceMapParams::Vision {
2749            max_seq_len: _,
2750            max_batch_size,
2751            max_image_shape: _,
2752            max_num_images,
2753        } = params
2754        else {
2755            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2756        };
2757
2758        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2759
2760        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2761        let img_seq_len = num_patches + 1;
2762
2763        let max_vision_attn = {
2764            // do_image_splitting = true
2765            let images_factor = 5;
2766
2767            (max_batch_size * images_factor * max_num_images)
2768                * cfg.vision_config.num_attention_heads
2769                * img_seq_len
2770                * img_seq_len
2771        };
2772
2773        Ok(max_vision_attn)
2774    }
2775
2776    fn non_mapped_size_in_bytes(
2777        &self,
2778        config: &str,
2779        dtype: DType,
2780        weight_pack_factor: usize,
2781        _matformer_config: Option<&MatformerSliceConfig>,
2782    ) -> Result<usize> {
2783        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2784        let text_elems = {
2785            let cfg = &cfg.text_config;
2786
2787            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2788            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2789            let norm = cfg.hidden_size;
2790            embed_tokens + lm_head + norm
2791        };
2792
2793        let vision_transformer = {
2794            let cfg = &cfg.vision_config;
2795
2796            let post_layernorm = cfg.hidden_size;
2797
2798            let conv_config = Conv2dConfig {
2799                stride: cfg.patch_size,
2800                ..Default::default()
2801            };
2802            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2803                * cfg.patch_size
2804                * cfg.patch_size;
2805
2806            let num_patches_per_side = cfg.image_size / cfg.patch_size;
2807            let num_patches = num_patches_per_side.pow(2);
2808            let position_embedding = num_patches * cfg.hidden_size;
2809
2810            let layer_elems = {
2811                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2812                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2813
2814                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2815                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2816
2817                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2818                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2819                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2820                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2821
2822                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2823            };
2824
2825            post_layernorm
2826                + patch_embedding
2827                + position_embedding
2828                + layer_elems * cfg.num_hidden_layers
2829        };
2830
2831        let elems = text_elems + vision_transformer;
2832
2833        Ok(elems * dtype.size_in_bytes())
2834    }
2835
2836    fn layer_sizes_in_bytes(
2837        &self,
2838        config: &str,
2839        dtype: DType,
2840        weight_pack_factor: usize,
2841        _matformer_config: Option<&MatformerSliceConfig>,
2842    ) -> Result<Vec<usize>> {
2843        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2844        let cfg = cfg.text_config;
2845        let per_layer_elems = {
2846            let input_layernorm = cfg.hidden_size;
2847            let post_attention_layernorm = cfg.hidden_size;
2848
2849            let size_in = cfg.hidden_size;
2850            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2851            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2852            let q_proj = size_in * size_q / weight_pack_factor;
2853            let k_proj = size_in * size_kv / weight_pack_factor;
2854            let v_proj = size_in * size_kv / weight_pack_factor;
2855            let o_proj = size_q * size_in / weight_pack_factor;
2856
2857            let h_size = cfg.hidden_size;
2858            let i_size = cfg.intermediate_size;
2859            let gate_proj = h_size * i_size / weight_pack_factor;
2860            let up_proj = h_size * i_size / weight_pack_factor;
2861            let down_proj = i_size * h_size / weight_pack_factor;
2862
2863            input_layernorm
2864                + post_attention_layernorm
2865                + q_proj
2866                + k_proj
2867                + v_proj
2868                + o_proj
2869                + gate_proj
2870                + up_proj
2871                + down_proj
2872        };
2873        Ok(vec![
2874            per_layer_elems * dtype.size_in_bytes();
2875            cfg.num_hidden_layers
2876        ])
2877    }
2878
2879    fn num_layers(&self, config: &str) -> Result<usize> {
2880        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2881        Ok(cfg.text_config.num_hidden_layers)
2882    }
2883    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2884        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2885        let cfg = &cfg.text_config;
2886
2887        let cfg = ModelConfigMetadata {
2888            max_seq_len: cfg.max_position_embeddings,
2889            num_layers: cfg.num_hidden_layers,
2890            hidden_size: cfg.hidden_size,
2891            num_kv_heads: cfg.num_key_value_heads,
2892            num_attn_heads: cfg.num_attention_heads,
2893            sliding_window: None,
2894            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2895            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2896        };
2897
2898        Ok(Box::new(cfg))
2899    }
2900}
2901
2902// ======================== Phi 4MM Loader
2903
2904/// [`VisionLoader`] for a Phi 4MM Vision model.
2905///
2906/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2907pub struct Phi4MMLoader;
2908
2909pub struct Phi4MMPrefixer;
2910
2911impl MultimodalPromptPrefixer for Phi4MMPrefixer {
2912    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2913        // Image indexing starts at 0.
2914
2915        format!(
2916            "{}{prompt}",
2917            image_indexes
2918                .into_iter()
2919                .map(|image_index| format!("<|image_{}|>", image_index + 1))
2920                .join("")
2921        )
2922    }
2923    fn prefix_audio(&self, audio_indexes: Vec<usize>, prompt: &str) -> String {
2924        // Audio indexing starts at 0.
2925
2926        format!(
2927            "{}{prompt}",
2928            audio_indexes
2929                .into_iter()
2930                .map(|audio_index| format!("<|audio_{}|>", audio_index + 1))
2931                .join("")
2932        )
2933    }
2934}
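
// A minimal, illustrative check of the prefixing behavior above (it assumes only the
// 1-based `<|image_N|>` / `<|audio_N|>` tag convention shown in the implementation).
#[cfg(test)]
mod phi4mm_prefixer_example {
    use super::*;

    #[test]
    fn tags_are_one_indexed_and_prepended() {
        let prefixer = Phi4MMPrefixer;
        // Image index 0 becomes `<|image_1|>`, index 1 becomes `<|image_2|>`, etc.
        assert_eq!(
            prefixer.prefix_image(vec![0, 1], "Describe the images."),
            "<|image_1|><|image_2|>Describe the images."
        );
        assert_eq!(
            prefixer.prefix_audio(vec![0], "Transcribe."),
            "<|audio_1|>Transcribe."
        );
    }
}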
2935
2936impl VisionModelLoader for Phi4MMLoader {
2937    fn load(
2938        &self,
2939        config: &str,
2940        vb: ShardedVarBuilder,
2941        normal_loading_metadata: NormalLoadingMetadata,
2942        attention_mechanism: AttentionImplementation,
2943    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2944        let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
2945        Ok(Box::new(Phi4MMModel::new(
2946            &cfg,
2947            vb,
2948            self.is_gptx(config),
2949            normal_loading_metadata,
2950            attention_mechanism,
2951        )?))
2952    }
2953    fn is_gptx(&self, _config: &str) -> bool {
2954        true
2955    }
2956    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2957        let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
2958        Ok(Box::new(cfg))
2959    }
2960    fn get_processor(
2961        &self,
2962        _model_config: &str,
2963        processor_config: Option<ProcessorConfig>,
2964        preprocessor_config: PreProcessorConfig,
2965        _max_edge: Option<u32>,
2966    ) -> Arc<dyn Processor + Send + Sync> {
2967        Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
2968    }
2969    fn supports_paged_attention(&self, _config: &str) -> bool {
2970        true
2971    }
2972    fn supports_prefix_cacher(&self, _config: &str) -> bool {
2973        true
2974    }
2975    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2976        Arc::new(Phi4MMPrefixer)
2977    }
2978    fn modalities(&self, _config: &str) -> Result<Modalities> {
2979        Ok(Modalities {
2980            input: vec![
2981                SupportedModality::Text,
2982                SupportedModality::Vision,
2983                SupportedModality::Audio,
2984            ],
2985            output: vec![SupportedModality::Text],
2986        })
2987    }
2988}
2989
2990impl IsqModelLoader for Phi4MMLoader {
2991    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2992        Ok(vec![
2993            Regex::new(r"lm_head\.(weight|bias)$")?,
2994            // Attention
2995            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
2996            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2997            // MLP
2998            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
2999            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3000        ])
3001    }
3002    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3003        self.isq_layer_regexes(config)
3004    }
3005}
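
// Illustrative sketch of how the ISQ regexes above are intended to match fully
// qualified tensor names; the example names are hypothetical, not read from a model.
#[cfg(test)]
mod phi4mm_isq_regex_example {
    use super::*;

    #[test]
    fn attention_and_mlp_projections_match() {
        let qkv = Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$").unwrap();
        assert!(qkv.is_match("model.layers.7.self_attn.qkv_proj.weight"));

        let gate_up = Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$").unwrap();
        assert!(gate_up.is_match("model.layers.7.mlp.gate_up_proj.bias"));

        // Names that do not end in `weight`/`bias` are rejected by the `$` anchor.
        assert!(!qkv.is_match("model.layers.7.self_attn.qkv_proj.weight.absmax"));
    }
}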
3006
3007impl DeviceMappedModelLoader for Phi4MMLoader {
3008    fn mapped_max_act_size_elems(
3009        &self,
3010        config: &str,
3011        params: &AutoDeviceMapParams,
3012    ) -> Result<usize> {
3013        // NOTE: we ignore max_num_images although it can only be one...
3014        let AutoDeviceMapParams::Vision {
3015            max_seq_len,
3016            max_batch_size,
3017            max_image_shape: _,
3018            max_num_images,
3019        } = params
3020        else {
3021            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3022        };
3023
3024        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3025
3026        let vcfg = &PHI4_MM_VISION_CFG;
3027
3028        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3029        let img_seq_len = (num_patches + 1) * max_num_images;
3030
3031        let max_text_attn = {
3032            // This model injects the vision information directly into the input embeddings
3033            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3034            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3035        };
3036
3037        Ok(max_text_attn)
3038    }
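
    // Sketch of the shapes involved above (values hypothetical): `num_patches` comes from
    // the fixed PHI4_MM_VISION_CFG (e.g. (448 / 14)^2 = 1024 patches, plus one extra
    // position per image), while the head count comes from the text config, so the
    // upper bound is batch * text_heads * (img_seq_len + chunked_text_len)^2.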
3039
3040    fn non_mapped_max_act_size_elems(
3041        &self,
3042        _config: &str,
3043        params: &AutoDeviceMapParams,
3044    ) -> Result<usize> {
3045        let AutoDeviceMapParams::Vision {
3046            max_seq_len: _,
3047            max_batch_size,
3048            max_image_shape,
3049            max_num_images,
3050        } = params
3051        else {
3052            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3053        };
3054
3055        let vcfg = &PHI4_MM_VISION_CFG;
3056
3057        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3058        let img_seq_len = num_patches + 1;
3059
3060        let max_batch_size = max_batch_size
3061            * (max_image_shape
3062                .0
3063                .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3064                * max_image_shape
3065                    .1
3066                    .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3067                + 1);
3068
3069        let max_vision_attn = (max_batch_size * max_num_images)
3070            * vcfg.num_attention_heads
3071            * img_seq_len
3072            * img_seq_len;
3073        let max_qkv = 3
3074            * (max_batch_size
3075                * vcfg.num_attention_heads
3076                * img_seq_len
3077                * (vcfg.hidden_size / vcfg.num_attention_heads));
3078
3079        Ok(max_vision_attn + max_qkv)
3080    }
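
    // Worked example of the dynamic-HD batching above (shape values hypothetical): for a
    // 1344 x 1008 image and a base resolution of, say, 448, the image contributes
    // ceil(1344 / 448) * ceil(1008 / 448) + 1 = 3 * 3 + 1 = 10 crops, so the effective
    // batch size is multiplied by 10 before the per-crop attention cost is applied.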
3081
3082    fn non_mapped_size_in_bytes(
3083        &self,
3084        config: &str,
3085        dtype: DType,
3086        weight_pack_factor: usize,
3087        _matformer_config: Option<&MatformerSliceConfig>,
3088    ) -> Result<usize> {
3089        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3090        let elems = {
3091            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3092            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3093            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3094                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3095            } else {
3096                0
3097            };
3098            let norm = cfg.hidden_size;
3099
3100            let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
3101                let projection_cls = img_embed
3102                    .projection_cls
3103                    .clone()
3104                    .unwrap_or("linear".to_string());
3105                let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
3106                let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
3107                let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;
3108
3109                let proj = match (projection_cls.as_str(), use_hd_transform) {
3110                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
3111                    ("mlp", true) => {
3112                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
3113                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3114                        a + b
3115                    }
3116                    ("mlp", false) => {
3117                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
3118                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3119                        a + b
3120                    }
3121                    _ => {
3122                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
3123                    }
3124                };
3125
3126                let (glb_gn, sub_gn) = if with_learnable_separator {
3127                    let glb_gn = image_dim_out * 4;
3128                    let sub_gn = image_dim_out * 4;
3129                    (glb_gn, sub_gn)
3130                } else {
3131                    (0, 0)
3132                };
3133
3134                let vision_transformer = {
3135                    let cfg = &PHI4_MM_VISION_CFG;
3136
3137                    let post_layernorm = cfg.hidden_size;
3138
3139                    let conv_config = Conv2dConfig {
3140                        stride: cfg.patch_size,
3141                        ..Default::default()
3142                    };
3143                    let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3144                        * cfg.patch_size
3145                        * cfg.patch_size;
3146
3147                    let num_patches_per_side = cfg.image_size / cfg.patch_size;
3148                    let num_patches = num_patches_per_side.pow(2);
3149                    let position_embedding = num_patches * cfg.hidden_size;
3150
3151                    let layer_elems = {
3152                        let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3153                        let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3154
3155                        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3156                        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3157
3158                        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3159                        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3160                        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3161                        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3162
3163                        layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3164                    };
3165
3166                    post_layernorm
3167                        + patch_embedding
3168                        + position_embedding
3169                        + layer_elems * cfg.num_hidden_layers
3170                };
3171
3172                proj + glb_gn + sub_gn + vision_transformer
3173            } else {
3174                0
3175            };
3176
3177            embed_tokens + lm_head + norm + image_embed
3178        };
3179
3180        Ok(elems * dtype.size_in_bytes())
3181    }
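
    // Sketch of the tied-embedding rule used above (illustrative): with
    // tie_word_embeddings = true and weight_pack_factor == 1, `lm_head` reuses the
    // embedding matrix and contributes 0 extra elements; with packing
    // (weight_pack_factor > 1) the head is counted separately, presumably because the
    // packed copy no longer shares storage with the unquantized embedding.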
3182
3183    fn layer_sizes_in_bytes(
3184        &self,
3185        config: &str,
3186        dtype: DType,
3187        weight_pack_factor: usize,
3188        _matformer_config: Option<&MatformerSliceConfig>,
3189    ) -> Result<Vec<usize>> {
3190        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3191        let per_layer_elems = {
3192            let input_layernorm = cfg.hidden_size;
3193            let post_attention_layernorm = cfg.hidden_size;
3194
3195            let size_in = cfg.hidden_size;
3196            let head_dim = cfg.head_dim();
3197            let op_size =
3198                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
3199            let qkv_proj = size_in * op_size / weight_pack_factor;
3200            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;
3201
3202            let h_size = cfg.hidden_size;
3203            let i_size = cfg.intermediate_size;
3204            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
3205            let down_proj = h_size * i_size / weight_pack_factor;
3206
3207            input_layernorm
3208                + post_attention_layernorm
3209                + qkv_proj
3210                + o_proj
3211                + gate_up_proj
3212                + down_proj
3213        };
3214        Ok(vec![
3215            per_layer_elems * dtype.size_in_bytes();
3216            cfg.num_hidden_layers
3217        ])
3218    }
3219
3220    fn num_layers(&self, config: &str) -> Result<usize> {
3221        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3222        Ok(cfg.num_hidden_layers)
3223    }
3224
3225    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3226        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3227
3228        let cfg = ModelConfigMetadata {
3229            max_seq_len: cfg.max_position_embeddings,
3230            num_layers: cfg.num_hidden_layers,
3231            hidden_size: cfg.hidden_size,
3232            num_kv_heads: cfg.num_key_value_heads(),
3233            num_attn_heads: cfg.num_attention_heads,
3234            sliding_window: cfg.sliding_window,
3235            k_head_dim: cfg.head_dim(),
3236            v_head_dim: cfg.head_dim(),
3237        };
3238
3239        Ok(Box::new(cfg))
3240    }
3241
3242    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3243        Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
3244    }
3245}
3246
3247// ======================== Qwen2_5VL Loader
3248
3249/// [`VisionLoader`] for a Qwen2_5VL model.
3250///
3251/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
3252pub struct Qwen2_5VLLoader;
3253
3254pub struct Qwen2_5VLPrefixer;
3255
3256impl MultimodalPromptPrefixer for Qwen2_5VLPrefixer {
3257    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
3258        format!(
3259            "{}{prompt}",
3260            format!(
3261                "{}{}{}",
3262                Qwen2_5VLProcessor::VISION_START,
3263                Qwen2_5VLProcessor::IMAGE_PAD,
3264                Qwen2_5VLProcessor::VISION_END
3265            )
3266            .repeat(image_indexes.len())
3267        )
3268    }
3269}
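
// Illustrative note: `prefix_image` above emits one VISION_START / IMAGE_PAD / VISION_END
// triple per image and prepends them all to the prompt, so two images yield two
// consecutive triples followed by the unchanged prompt text.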
3270
3271impl VisionModelLoader for Qwen2_5VLLoader {
3272    fn load(
3273        &self,
3274        config: &str,
3275        vb: ShardedVarBuilder,
3276        normal_loading_metadata: NormalLoadingMetadata,
3277        attention_mechanism: AttentionImplementation,
3278    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3279        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3280        Ok(Box::new(Qwen2_5VLModel::new(
3281            &cfg,
3282            vb,
3283            self.is_gptx(config),
3284            normal_loading_metadata,
3285            attention_mechanism,
3286        )?))
3287    }
3288    fn is_gptx(&self, _config: &str) -> bool {
3289        true
3290    }
3291    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3292        let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
3293        Ok(Box::new(config))
3294    }
3295    fn get_processor(
3296        &self,
3297        _model_config: &str,
3298        _processor_config: Option<ProcessorConfig>,
3299        _preprocessor_config: PreProcessorConfig,
3300        max_edge: Option<u32>,
3301    ) -> Arc<dyn Processor + Send + Sync> {
3302        Arc::new(Qwen2_5VLProcessor::new(max_edge))
3303    }
3304    fn supports_paged_attention(&self, _config: &str) -> bool {
3305        false
3306    }
3307    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3308        Arc::new(Qwen2_5VLPrefixer)
3309    }
3310    fn modalities(&self, _config: &str) -> Result<Modalities> {
3311        Ok(Modalities {
3312            input: vec![SupportedModality::Text, SupportedModality::Vision],
3313            output: vec![SupportedModality::Text],
3314        })
3315    }
3316}
3317
3318impl IsqModelLoader for Qwen2_5VLLoader {
3319    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3320        Ok(vec![
3321            Regex::new(r"lm_head\.(weight|bias)$")?,
3322            // Attention
3323            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3324            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3325            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3326            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3327            // MLP
3328            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3329            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3330            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3331        ])
3332    }
3333    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3334        self.isq_layer_regexes(config)
3335    }
3336}
3337
3338impl DeviceMappedModelLoader for Qwen2_5VLLoader {
3339    fn mapped_max_act_size_elems(
3340        &self,
3341        config: &str,
3342        params: &AutoDeviceMapParams,
3343    ) -> Result<usize> {
3344        let AutoDeviceMapParams::Vision {
3345            max_seq_len,
3346            max_batch_size,
3347            max_image_shape,
3348            max_num_images,
3349        } = params
3350        else {
3351            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3352        };
3353
3354        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3355
3356        let img_seq_len = {
3357            let cfg = &cfg.vision_config;
3358            let grid_t = max_num_images / cfg.temporal_patch_size;
3359            let grid_h = max_image_shape.0 / cfg.patch_size;
3360            let grid_w = max_image_shape.1 / cfg.patch_size;
3361            grid_t * grid_h * grid_w
3362        };
3363        let img_seq_len = img_seq_len * max_num_images;
3364
3365        let max_text_attn = {
3366            // This model injects the vision information directly into the input embeddings
3367            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3368            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3369        };
3370
3371        Ok(max_text_attn)
3372    }
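
    // Worked example of the grid computation above (all values hypothetical): with
    // patch_size = 14, temporal_patch_size = 2, a 1036 x 1036 image and
    // max_num_images = 2:
    //   grid_t = 2 / 2     = 1
    //   grid_h = 1036 / 14 = 74
    //   grid_w = 1036 / 14 = 74
    //   img_seq_len = (1 * 74 * 74) * 2 = 10952 vision positions
    // which is then added to the chunked text length before squaring.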
3373
3374    fn non_mapped_max_act_size_elems(
3375        &self,
3376        config: &str,
3377        params: &AutoDeviceMapParams,
3378    ) -> Result<usize> {
3379        let AutoDeviceMapParams::Vision {
3380            max_seq_len: _,
3381            max_batch_size,
3382            max_image_shape,
3383            max_num_images,
3384        } = params
3385        else {
3386            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3387        };
3388
3389        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3390
3391        let img_seq_len = {
3392            let cfg = &cfg.vision_config;
3393            let grid_t = max_num_images / cfg.temporal_patch_size;
3394            let grid_h = max_image_shape.0 / cfg.patch_size;
3395            let grid_w = max_image_shape.1 / cfg.patch_size;
3396            grid_t * grid_h * grid_w
3397        };
3398
3399        let max_vision_attn = {
3400            let cfg = &cfg.vision_config;
3401            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
3402        };
3403
3404        Ok(max_vision_attn)
3405    }
3406
3407    fn non_mapped_size_in_bytes(
3408        &self,
3409        config: &str,
3410        dtype: DType,
3411        weight_pack_factor: usize,
3412        _matformer_config: Option<&MatformerSliceConfig>,
3413    ) -> Result<usize> {
3414        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3415        let text_elems = {
3416            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3417            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3418            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3419                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3420            } else {
3421                0
3422            };
3423            let norm = cfg.hidden_size;
3424            embed_tokens + lm_head + norm
3425        };
3426
3427        let patch_merger = {
3428            let cfg = &cfg.vision_config;
3429            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
3430
3431            let mlp0 = hidden_size * hidden_size + hidden_size;
3432            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
3433
3434            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3435
3436            mlp0 + mlp2 + ln_q
3437        };
3438
3439        let patch_embed = {
3440            let cfg = &cfg.vision_config;
3441            let conv_cfg = Conv3dConfig {
3442                stride: cfg.patch_size,
3443                ..Default::default()
3444            };
3445            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
3446            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
3447                * kernel_sizes[0]
3448                * kernel_sizes[1]
3449                * kernel_sizes[2]
3450        };
3451
3452        let encoder_layer = {
3453            let cfg = &cfg.vision_config;
3454            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3455            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3456
3457            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
3458            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3459            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
3460
3461            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
3462            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3463
3464            norm1 + norm2 + fc1 + fc2 + qkv + out
3465        };
3466
3467        let elems =
3468            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
3469
3470        Ok(elems * dtype.size_in_bytes())
3471    }
3472
3473    fn layer_sizes_in_bytes(
3474        &self,
3475        config: &str,
3476        dtype: DType,
3477        weight_pack_factor: usize,
3478        _matformer_config: Option<&MatformerSliceConfig>,
3479    ) -> Result<Vec<usize>> {
3480        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3481        let per_layer_elems = {
3482            let input_layernorm = cfg.hidden_size;
3483            let post_attention_layernorm = cfg.hidden_size;
3484
3485            let size_in = cfg.hidden_size;
3486            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3487            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3488            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3489            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3490            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3491            let o_proj = size_q * size_in / weight_pack_factor;
3492
3493            let h_size = cfg.hidden_size;
3494            let i_size = cfg.intermediate_size;
3495            let gate_proj = h_size * i_size / weight_pack_factor;
3496            let up_proj = h_size * i_size / weight_pack_factor;
3497            let down_proj = i_size * h_size / weight_pack_factor;
3498
3499            input_layernorm
3500                + post_attention_layernorm
3501                + q_proj
3502                + k_proj
3503                + v_proj
3504                + o_proj
3505                + gate_proj
3506                + up_proj
3507                + down_proj
3508        };
3509        Ok(vec![
3510            per_layer_elems * dtype.size_in_bytes();
3511            cfg.num_hidden_layers
3512        ])
3513    }
3514
3515    fn num_layers(&self, config: &str) -> Result<usize> {
3516        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3517        Ok(cfg.num_hidden_layers)
3518    }
3519
3520    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3521        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3522
3523        let cfg = ModelConfigMetadata {
3524            max_seq_len: cfg.max_position_embeddings,
3525            num_layers: cfg.num_hidden_layers,
3526            hidden_size: cfg.hidden_size,
3527            num_kv_heads: cfg.num_key_value_heads,
3528            num_attn_heads: cfg.num_attention_heads,
3529            sliding_window: cfg.sliding_window,
3530            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3531            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3532        };
3533
3534        Ok(Box::new(cfg))
3535    }
3536
3537    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3538        Some(vec![NonMappedSubModel::Vision])
3539    }
3540}
3541
3542// ======================== Gemma 3 Loader
3543
3544/// [`VisionLoader`] for a Gemma 3 model.
3545///
3546/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
3547pub struct Gemma3Loader;
3548
3549pub struct Gemma3Prefixer;
3550
3551impl MultimodalPromptPrefixer for Gemma3Prefixer {
3552    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
3553        prompt.to_string()
3554    }
3555}
3556
3557impl VisionModelLoader for Gemma3Loader {
3558    fn load(
3559        &self,
3560        config: &str,
3561        vb: ShardedVarBuilder,
3562        normal_loading_metadata: NormalLoadingMetadata,
3563        attention_mechanism: AttentionImplementation,
3564    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3565        let cfg: Gemma3Config = serde_json::from_str(config)?;
3566        Ok(Box::new(Gemma3Model::new(
3567            &cfg,
3568            vb,
3569            self.is_gptx(config),
3570            normal_loading_metadata,
3571            attention_mechanism,
3572        )?))
3573    }
3574    fn is_gptx(&self, _config: &str) -> bool {
3575        true
3576    }
3577    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3578        let config: Gemma3Config = serde_json::from_str(config)?;
3579        Ok(Box::new(config))
3580    }
3581    fn get_processor(
3582        &self,
3583        config: &str,
3584        processor_config: Option<ProcessorConfig>,
3585        _preprocessor_config: PreProcessorConfig,
3586        _max_edge: Option<u32>,
3587    ) -> Arc<dyn Processor + Send + Sync> {
3588        let config: Gemma3Config = serde_json::from_str(config).unwrap();
3589        // Handle the Gemma 3 1b case here
3590        Arc::new(Gemma3Processor::new(
3591            processor_config.unwrap_or_default(),
3592            matches!(config, Gemma3Config::WithVision { .. }),
3593        ))
3594    }
3595    fn supports_paged_attention(&self, _config: &str) -> bool {
3596        true
3597    }
3598    fn supports_prefix_cacher(&self, _config: &str) -> bool {
3599        true
3600    }
3601    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3602        Arc::new(Gemma3Prefixer)
3603    }
3604    fn modalities(&self, _config: &str) -> Result<Modalities> {
3605        Ok(Modalities {
3606            input: vec![SupportedModality::Text, SupportedModality::Vision],
3607            output: vec![SupportedModality::Text],
3608        })
3609    }
3610}
3611
3612impl IsqModelLoader for Gemma3Loader {
3613    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3614        Ok(vec![
3615            Regex::new(r"lm_head\.(weight|bias)$")?,
3616            // Attention
3617            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3618            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3619            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3620            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3621            // MLP
3622            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3623            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3624            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3625        ])
3626    }
3627    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
3628        Ok(vec![
3629            Regex::new(r"lm_head\.(weight|bias)$")?,
3630            // Attention
3631            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3632            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3633            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3634            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3635            // MLP
3636            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3637            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3638            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3639        ])
3640    }
3641}
3642
3643impl DeviceMappedModelLoader for Gemma3Loader {
3644    fn mapped_max_act_size_elems(
3645        &self,
3646        config: &str,
3647        params: &AutoDeviceMapParams,
3648    ) -> Result<usize> {
3649        let AutoDeviceMapParams::Vision {
3650            max_seq_len,
3651            max_batch_size,
3652            max_image_shape: _,
3653            max_num_images,
3654        } = params
3655        else {
3656            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3657        };
3658
3659        let cfg: Gemma3Config = serde_json::from_str(config)?;
3660
3661        match cfg {
3662            Gemma3Config::Text(text_config) => Ok(max_batch_size
3663                * text_config.num_attention_heads
3664                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)),
3665            Gemma3Config::WithVision {
3666                text_config,
3667                vision_config,
3668                ..
3669            } => {
3670                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3671                let img_seq_len = (num_patches + 1) * max_num_images;
3672
3673                let max_text_attn = {
3674                    // This model injects the vision information directly into the input embeddings
3675                    let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3676                    max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
3677                };
3678                Ok(max_text_attn)
3679            }
3680        }
3681    }
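
    // Illustrative summary of the match above: a text-only Gemma 3 config bounds the
    // activation by batch * heads * chunked_seq_len^2, while a vision-capable config
    // first widens the sequence by (num_patches + 1) positions per image (e.g. a
    // hypothetical image_size = 896 with patch_size = 14 gives 4096 + 1 positions) and
    // squares the combined length instead.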
3682
3683    fn non_mapped_max_act_size_elems(
3684        &self,
3685        config: &str,
3686        params: &AutoDeviceMapParams,
3687    ) -> Result<usize> {
3688        let AutoDeviceMapParams::Vision {
3689            max_seq_len: _,
3690            max_batch_size,
3691            max_image_shape: _,
3692            max_num_images,
3693        } = params
3694        else {
3695            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3696        };
3697
3698        let cfg: Gemma3Config = serde_json::from_str(config)?;
3699
3700        match cfg {
3701            Gemma3Config::WithVision { vision_config, .. } => {
3702                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3703                let img_seq_len = num_patches + 1;
3704
3705                let max_vision_attn = {
3706                    (max_batch_size * max_num_images)
3707                        * vision_config.num_attention_heads
3708                        * img_seq_len
3709                        * img_seq_len
3710                };
3711
3712                Ok(max_vision_attn)
3713            }
3714            Gemma3Config::Text(_) => Ok(0),
3715        }
3716    }
3717
3718    fn non_mapped_size_in_bytes(
3719        &self,
3720        config: &str,
3721        dtype: DType,
3722        weight_pack_factor: usize,
3723        _matformer_config: Option<&MatformerSliceConfig>,
3724    ) -> Result<usize> {
3725        let cfg: Gemma3Config = serde_json::from_str(config)?;
3726
3727        let text_elems = {
3728            let cfg = match &cfg {
3729                Gemma3Config::Text(cfg) => cfg,
3730                Gemma3Config::WithVision { text_config, .. } => text_config,
3731            };
3732            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3733            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3734            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3735                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3736            } else {
3737                0
3738            };
3739            let norm = cfg.hidden_size;
3740            embed_tokens + lm_head + norm
3741        };
3742
3743        let vision_transformer = if let Gemma3Config::WithVision {
3744            vision_config: cfg, ..
3745        } = &cfg
3746        {
3747            let post_layernorm = cfg.hidden_size;
3748
3749            let conv_config = Conv2dConfig {
3750                stride: cfg.patch_size,
3751                ..Default::default()
3752            };
3753            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3754                * cfg.patch_size
3755                * cfg.patch_size;
3756
3757            let num_patches_per_side = cfg.image_size / cfg.patch_size;
3758            let num_patches = num_patches_per_side.pow(2);
3759            let position_embedding = num_patches * cfg.hidden_size;
3760
3761            let layer_elems = {
3762                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3763                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3764
3765                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3766                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3767
3768                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3769                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3770                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3771                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3772
3773                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3774            };
3775
3776            post_layernorm
3777                + patch_embedding
3778                + position_embedding
3779                + layer_elems * cfg.num_hidden_layers
3780        } else {
3781            0
3782        };
3783
3784        let elems = text_elems + vision_transformer;
3785
3786        Ok(elems * dtype.size_in_bytes())
3787    }
3788
3789    fn layer_sizes_in_bytes(
3790        &self,
3791        config: &str,
3792        dtype: DType,
3793        weight_pack_factor: usize,
3794        _matformer_config: Option<&MatformerSliceConfig>,
3795    ) -> Result<Vec<usize>> {
3796        let cfg: Gemma3Config = serde_json::from_str(config)?;
3797
3798        let txt_cfg = match &cfg {
3799            Gemma3Config::Text(cfg) => cfg,
3800            Gemma3Config::WithVision { text_config, .. } => text_config,
3801        };
3802        let per_layer_elems = {
3803            let cfg = txt_cfg;
3804
3805            let input_layernorm = cfg.hidden_size;
3806            let post_attention_layernorm = cfg.hidden_size;
3807
3808            let size_in = cfg.hidden_size;
3809            let size_q = cfg.head_dim * cfg.num_attention_heads;
3810            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
3811            let q_proj =
3812                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
3813            let k_proj =
3814                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3815            let v_proj =
3816                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3817            let o_proj =
3818                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
3819
3820            let h_size = cfg.hidden_size;
3821            let i_size = cfg.intermediate_size;
3822            let gate_proj = h_size * i_size / weight_pack_factor;
3823            let up_proj = h_size * i_size / weight_pack_factor;
3824            let down_proj = i_size * h_size / weight_pack_factor;
3825
3826            input_layernorm
3827                + post_attention_layernorm
3828                + q_proj
3829                + k_proj
3830                + v_proj
3831                + o_proj
3832                + gate_proj
3833                + up_proj
3834                + down_proj
3835        };
3836        Ok(vec![
3837            per_layer_elems * dtype.size_in_bytes();
3838            txt_cfg.num_hidden_layers
3839        ])
3840    }
3841
3842    fn num_layers(&self, config: &str) -> Result<usize> {
3843        let cfg: Gemma3Config = serde_json::from_str(config)?;
3844
3845        let txt_cfg = match &cfg {
3846            Gemma3Config::Text(cfg) => cfg,
3847            Gemma3Config::WithVision { text_config, .. } => text_config,
3848        };
3849
3850        Ok(txt_cfg.num_hidden_layers)
3851    }
3852
3853    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3854        let cfg: Gemma3Config = serde_json::from_str(config)?;
3855
3856        let cfg = match &cfg {
3857            Gemma3Config::Text(cfg) => cfg,
3858            Gemma3Config::WithVision { text_config, .. } => text_config,
3859        };
3860
3861        let cfg = ModelConfigMetadata {
3862            max_seq_len: cfg.max_position_embeddings,
3863            num_layers: cfg.num_hidden_layers,
3864            hidden_size: cfg.hidden_size,
3865            num_kv_heads: cfg.num_key_value_heads,
3866            num_attn_heads: cfg.num_attention_heads,
3867            sliding_window: None, // None to be more forgiving; some configs do not set it
3868            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3869            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3870        };
3871
3872        Ok(Box::new(cfg))
3873    }
3874
3875    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3876        Some(vec![NonMappedSubModel::Vision])
3877    }
3878}
3879
3880// ======================== Mistral 3 Loader
3881
3882/// [`VisionLoader`] for a Mistral 3 model.
3883///
3884/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
3885pub struct Mistral3Loader;
3886
3887pub struct Mistral3Prefixer;
3888
3889impl MultimodalPromptPrefixer for Mistral3Prefixer {
3890    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
3891        prompt.to_string()
3892    }
3893}
3894
3895impl VisionModelLoader for Mistral3Loader {
3896    fn load(
3897        &self,
3898        config: &str,
3899        vb: ShardedVarBuilder,
3900        normal_loading_metadata: NormalLoadingMetadata,
3901        attention_mechanism: AttentionImplementation,
3902    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3903        let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
3904        Ok(Box::new(Mistral3Model::new(
3905            &cfg,
3906            vb,
3907            self.is_gptx(config),
3908            normal_loading_metadata,
3909            attention_mechanism,
3910        )?))
3911    }
3912    fn is_gptx(&self, _config: &str) -> bool {
3913        true
3914    }
3915    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3916        let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
3917        Ok(Box::new(cfg))
3918    }
3919    fn get_processor(
3920        &self,
3921        _model_config: &str,
3922        processor_config: Option<ProcessorConfig>,
3923        _preprocessor_config: PreProcessorConfig,
3924        _max_edge: Option<u32>,
3925    ) -> Arc<dyn Processor + Send + Sync> {
3926        Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
3927    }
3928    fn supports_paged_attention(&self, _config: &str) -> bool {
3929        true
3930    }
3931    fn supports_prefix_cacher(&self, _config: &str) -> bool {
3932        true
3933    }
3934    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3935        Arc::new(Mistral3Prefixer)
3936    }
3937    fn modalities(&self, _config: &str) -> Result<Modalities> {
3938        Ok(Modalities {
3939            input: vec![SupportedModality::Text, SupportedModality::Vision],
3940            output: vec![SupportedModality::Text],
3941        })
3942    }
3943}
3944
3945impl IsqModelLoader for Mistral3Loader {
3946    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3947        Ok(vec![
3948            Regex::new(r"lm_head\.(weight|bias)$")?,
3949            // Attention
3950            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3951            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3952            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3953            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3954            // MLP
3955            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3956            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3957            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3958        ])
3959    }
3960    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
3961        Ok(vec![
3962            Regex::new(r"lm_head\.(weight|bias)$")?,
3963            // Attention
3964            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3965            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3966            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3967            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3968            // MLP
3969            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3970            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3971            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3972        ])
3973    }
3974}
3975
3976#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
3977impl DeviceMappedModelLoader for Mistral3Loader {
3978    fn mapped_max_act_size_elems(
3979        &self,
3980        config: &str,
3981        params: &AutoDeviceMapParams,
3982    ) -> Result<usize> {
3983        let cfg: Mistral3Config = serde_json::from_str(config)?;
3984        let vcfg = &cfg.vision_config;
3985        let tcfg = &cfg.text_config;
3986
3987        let AutoDeviceMapParams::Vision {
3988            max_seq_len,
3989            max_batch_size,
3990            max_image_shape: (mut height, mut width),
3991            max_num_images,
3992        } = params
3993        else {
3994            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3995        };
3996
3997        let img_seq_len = {
3998            // Reshaping algorithm
3999
4000            // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/preprocessor_config.json#L29
4001            let (max_height, max_width) = (1540, 1540);
4002            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4003            if ratio > 1. {
4004                height = (height as f64 / ratio).floor() as usize;
4005                width = (width as f64 / ratio).floor() as usize;
4006            }
4007
4008            let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
4009            let num_width_tokens = (width - 1) / vcfg.patch_size + 1;
4010
4011            height = num_height_tokens * vcfg.patch_size;
4012            width = num_width_tokens * vcfg.patch_size;
4013
4014            let num_height_tokens = height / vcfg.patch_size;
4015            let num_width_tokens = width / vcfg.patch_size;
4016
4017            (num_width_tokens + 1) * num_height_tokens
4018        };
4019
4020        // This model injects the vision information directly into the input embeddings
4021        let max_seq_len = img_seq_len * max_num_images + *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
4022        Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
4023    }
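
    // Worked example of the reshaping above (patch size hypothetical): a 2000 x 2000
    // image with patch_size = 14 is first scaled by ratio = 2000 / 1540 down to
    // 1540 x 1540, then rounded up to whole patches: (1540 - 1) / 14 + 1 = 110 tokens
    // per side, i.e. height = width = 110 * 14 = 1540. The per-image length is
    // (num_width_tokens + 1) * num_height_tokens = 111 * 110 = 12210, which is added
    // (times max_num_images) to the chunked text length before squaring.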
4024
4025    fn non_mapped_max_act_size_elems(
4026        &self,
4027        config: &str,
4028        params: &AutoDeviceMapParams,
4029    ) -> Result<usize> {
4030        let cfg: Mistral3Config = serde_json::from_str(config)?;
4031        let cfg = &cfg.vision_config;
4032
4033        let AutoDeviceMapParams::Vision {
4034            max_seq_len: _,
4035            max_batch_size,
4036            max_image_shape: (mut height, mut width),
4037            max_num_images,
4038        } = params
4039        else {
4040            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4041        };
4042
4043        let img_seq_len = {
4044            // Reshaping algorithm
4045
4046            // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/preprocessor_config.json#L29
4047            let (max_height, max_width) = (1540, 1540);
4048            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4049            if ratio > 1. {
4050                height = (height as f64 / ratio).floor() as usize;
4051                width = (width as f64 / ratio).floor() as usize;
4052            }
4053
4054            let num_height_tokens = (height - 1) / cfg.patch_size + 1;
4055            let num_width_tokens = (width - 1) / cfg.patch_size + 1;
4056
4057            height = num_height_tokens * cfg.patch_size;
4058            width = num_width_tokens * cfg.patch_size;
4059
4060            let num_height_tokens = height / cfg.patch_size;
4061            let num_width_tokens = width / cfg.patch_size;
4062
4063            (num_width_tokens + 1) * num_height_tokens
4064        };
4065
4066        Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
4067    }
4068
4069    fn non_mapped_size_in_bytes(
4070        &self,
4071        config: &str,
4072        dtype: DType,
4073        weight_pack_factor: usize,
4074        _matformer_config: Option<&MatformerSliceConfig>,
4075    ) -> Result<usize> {
4076        let cfg: Mistral3Config = serde_json::from_str(config)?;
4077
4078        let text_elems = {
4079            let cfg = &cfg.text_config;
4080
4081            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4082            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
4083            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4084                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4085            } else {
4086                0
4087            };
4088            let norm = cfg.hidden_size;
4089            embed_tokens + lm_head + norm
4090        };
4091
4092        let vision_elems = {
4093            let cfg = &cfg.vision_config;
4094
4095            let patch_embed = {
4096                let conv_cfg = Conv2dConfig {
4097                    stride: cfg.patch_size,
4098                    ..Default::default()
4099                };
4100                cfg.num_channels * cfg.hidden_size / conv_cfg.groups
4101                    * cfg.patch_size
4102                    * cfg.patch_size
4104            };
4105            let ln_pre = cfg.hidden_size;
4106            let vision_layer = {
4107                let attn_norm = cfg.hidden_size;
4108                let ffn_norm = cfg.hidden_size;
4109
4110                let gate = cfg.hidden_size * cfg.intermediate_size;
4111                let up = cfg.hidden_size * cfg.intermediate_size;
4112                let down = cfg.hidden_size * cfg.intermediate_size;
4113
4114                let q = cfg.hidden_size * cfg.hidden_size;
4115                let k = cfg.hidden_size * cfg.hidden_size;
4116                let v = cfg.hidden_size * cfg.hidden_size;
4117                let o = cfg.hidden_size * cfg.hidden_size;
4118
4119                attn_norm + ffn_norm + gate + up + down + q + k + v + o
4120            };
4121
4122            patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
4123        };
4124
4125        let elems = text_elems + vision_elems;
4126
4127        Ok(elems * dtype.size_in_bytes())
4128    }
4129
4130    fn layer_sizes_in_bytes(
4131        &self,
4132        config: &str,
4133        dtype: DType,
4134        weight_pack_factor: usize,
4135        _matformer_config: Option<&MatformerSliceConfig>,
4136    ) -> Result<Vec<usize>> {
4137        let cfg: Mistral3Config = serde_json::from_str(config)?;
4138        let cfg = &cfg.text_config;
4139
4140        let per_layer_elems = {
4141            let input_layernorm = cfg.hidden_size;
4142            let post_attention_layernorm = cfg.hidden_size;
4143
4144            let size_in = cfg.hidden_size;
4145            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
4146            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
4147            let q_proj = size_in * size_q / weight_pack_factor;
4148            let k_proj = size_in * size_kv / weight_pack_factor;
4149            let v_proj = size_in * size_kv / weight_pack_factor;
4150            let o_proj = size_q * size_in / weight_pack_factor;
4151
4152            let h_size = cfg.hidden_size;
4153            let i_size = cfg.intermediate_size;
4154            let gate_proj = h_size * i_size / weight_pack_factor;
4155            let up_proj = h_size * i_size / weight_pack_factor;
4156            let down_proj = i_size * h_size / weight_pack_factor;
4157
4158            input_layernorm
4159                + post_attention_layernorm
4160                + q_proj
4161                + k_proj
4162                + v_proj
4163                + o_proj
4164                + gate_proj
4165                + up_proj
4166                + down_proj
4167        };
4168        Ok(vec![
4169            per_layer_elems * dtype.size_in_bytes();
4170            cfg.num_hidden_layers
4171        ])
4172    }
4173
4174    fn num_layers(&self, config: &str) -> Result<usize> {
4175        let cfg: Mistral3Config = serde_json::from_str(config)?;
4176        let cfg = &cfg.text_config;
4177        Ok(cfg.num_hidden_layers)
4178    }
4179
4180    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4181        let cfg: Mistral3Config = serde_json::from_str(config)?;
4182        let cfg = &cfg.text_config;
4183
4184        let cfg = ModelConfigMetadata {
4185            max_seq_len: cfg.max_position_embeddings,
4186            num_layers: cfg.num_hidden_layers,
4187            hidden_size: cfg.hidden_size,
4188            num_kv_heads: cfg.num_key_value_heads,
4189            num_attn_heads: cfg.num_attention_heads,
4190            sliding_window: cfg.sliding_window,
4191            k_head_dim: cfg.head_dim(),
4192            v_head_dim: cfg.head_dim(),
4193        };
4194
4195        Ok(Box::new(cfg))
4196    }
4197
4198    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4199        Some(vec![NonMappedSubModel::Vision])
4200    }
4201}
4202
4203// ======================== Llama 4 Loader
4204
4205/// [`VisionLoader`] for a Llama 4 Vision model.
4206///
4207/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
4208pub struct VLlama4Loader;
4209
4210pub struct VLlama4Prefixer;
4211
4212impl MultimodalPromptPrefixer for VLlama4Prefixer {
4213    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
4214        format!(
4215            "{}{prompt}",
4216            llama4::IMAGE_TOKEN.repeat(image_indexes.len())
4217        )
4218    }
4219}
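
// Illustrative note: the Llama 4 prefixer above simply repeats `llama4::IMAGE_TOKEN`
// once per image index and prepends the result, leaving the prompt text untouched.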
4220
4221impl VisionModelLoader for VLlama4Loader {
4222    fn load(
4223        &self,
4224        config: &str,
4225        vb: ShardedVarBuilder,
4226        normal_loading_metadata: NormalLoadingMetadata,
4227        attention_mechanism: AttentionImplementation,
4228    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
4229        let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4230        Ok(Box::new(Llama4Model::new(
4231            &cfg,
4232            vb,
4233            self.is_gptx(config),
4234            normal_loading_metadata,
4235            attention_mechanism,
4236        )?))
4237    }
4238    fn is_gptx(&self, _config: &str) -> bool {
4239        false
4240    }
4241    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4242        let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4243        Ok(Box::new(cfg))
4244    }
4245    fn get_processor(
4246        &self,
4247        _model_config: &str,
4248        processor_config: Option<ProcessorConfig>,
4249        _preprocessor_config: PreProcessorConfig,
4250        _max_edge: Option<u32>,
4251    ) -> Arc<dyn Processor + Send + Sync> {
4252        Arc::new(Llama4Processor::new(&processor_config.unwrap()))
4253    }
4254    fn supports_paged_attention(&self, _config: &str) -> bool {
4255        true
4256    }
4257    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4258        Arc::new(VLlama4Prefixer)
4259    }
4260    fn modalities(&self, _config: &str) -> Result<Modalities> {
4261        Ok(Modalities {
4262            input: vec![SupportedModality::Text, SupportedModality::Vision],
4263            output: vec![SupportedModality::Text],
4264        })
4265    }
4266}
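
// Sketch of how a pipeline might drive this loader (hypothetical caller, not part of this file;
// variable names are illustrative):
//
//     let loader = VLlama4Loader;
//     let _repr = loader.get_config_repr(&config_json)?;                // debug-printable config
//     let processor = loader.get_processor(&config_json, processor_cfg, preprocessor_cfg, None);
//     let model = loader.load(&config_json, vb, loading_metadata, attention_mechanism)?;
//
// `config_json`, `processor_cfg`, `preprocessor_cfg`, `vb`, `loading_metadata` and
// `attention_mechanism` would all be supplied by the surrounding vision pipeline.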
4267
4268impl IsqModelLoader for VLlama4Loader {
4269    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4270        Ok(vec![
4271            Regex::new(r"lm_head\.(weight|bias)$")?,
4272            // Attention
4273            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4274            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4275            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4276            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4277            // FF MoE
4278            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$")?,
4279            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$")?,
4280            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$")?,
4281            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$")?,
4282            Regex::new(r"layers\.(\d+)\.feed_forward\.router\.(weight|bias)$")?,
4283            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.gate_proj\.(weight|bias)$")?,
4284            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.up_proj\.(weight|bias)$")?,
4285            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.down_proj\.(weight|bias)$")?,
4286            // FF MLP
4287            Regex::new(r"layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$")?,
4288            Regex::new(r"layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$")?,
4289            Regex::new(r"layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$")?,
4290        ])
4291    }
4292    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4293        Ok(vec![
4294            Regex::new(r"lm_head\.(weight|bias)$")?,
4295            // Attention
4296            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4297            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4298            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4299            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4300            // FF MoE
4301            Regex::new(
4302                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_up_proj\.(weight|bias)$",
4303            )?,
4304            Regex::new(
4305                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
4306            )?,
4307            Regex::new(
4308                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.up_proj\.(weight|bias)$",
4309            )?,
4310            Regex::new(
4311                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.down_proj\.(weight|bias)$",
4312            )?,
4313            Regex::new(
4314                r"language_model\.model\.layers\.(\d+)\.feed_forward\.router\.(weight|bias)$",
4315            )?,
4316            Regex::new(
4317                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.gate_proj\.(weight|bias)$",
4318            )?,
4319            Regex::new(
4320                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.up_proj\.(weight|bias)$",
4321            )?,
4322            Regex::new(
4323                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.down_proj\.(weight|bias)$",
4324            )?,
4325            // FF MLP
4326            Regex::new(
4327                r"language_model\.model\.layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$",
4328            )?,
4329            Regex::new(
4330                r"language_model\.model\.layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$",
4331            )?,
4332            Regex::new(
4333                r"language_model\.model\.layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$",
4334            )?,
4335        ])
4336    }
4337}
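
// Note: `isq_layer_regexes` matches tensor names relative to the language-model layer stack
// (`layers.*`), while `immediate_isq_predicates` matches the fully prefixed
// `language_model.model.layers.*` paths used when ISQ is applied immediately as tensors are
// loaded. Both lists intentionally cover the same projections; only the path prefix differs.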
4338
4339impl VLlama4Loader {
4340    /// This incorporates the max batch size!
4341    /// Returns (pixels max batch size, num text image tokens)
4342    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
4343    fn run_dummy_processing(
4344        &self,
4345        cfg: &Llama4Config,
4346        height: usize,
4347        width: usize,
4348        max_num_images: usize,
4349        max_batch_size: usize,
4350    ) -> Result<(usize, usize)> {
4351        let cfg = &cfg.vision_config;
4352
4353        let img_processor =
4354            Llama4ImageProcessor::new(Some(cfg.patch_size), Some(cfg.pixel_shuffle_ratio));
4355        let image = DynamicImage::new(width as u32, height as u32, ColorType::Rgb8);
4356        let res = img_processor.preprocess(
4357            vec![image; max_num_images],
4358            vec![],
4359            &PreProcessorConfig::default(),
4360            &Device::Cpu,
4361            (max_batch_size, max_num_images),
4362        )?;
4363
4364        let pixels_batch_size = res.pixel_values.dim(0)?;
4365        let pixels_max_batch_size = pixels_batch_size * max_batch_size;
4366
4367        let (image_h, image_w) = (
4368            res.pixel_values.dim(D::Minus2).unwrap(),
4369            res.pixel_values.dim(D::Minus1).unwrap(),
4370        );
4371        let num_patches_per_chunk = (image_h / img_processor.patch_size)
4372            * (image_w / img_processor.patch_size)
4373            / img_processor.downsample_ratio;
4374
4375        Ok((
4376            pixels_max_batch_size,
4377            num_patches_per_chunk * pixels_max_batch_size,
4378        ))
4379    }
4380}
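
// `run_dummy_processing` above sidesteps re-deriving the Llama 4 tiling math by running the real
// `Llama4ImageProcessor` on blank RGB images at the requested maximum shape and then reading the
// resulting pixel-value dimensions. The patch/token counts therefore track whatever the processor
// actually emits for that shape, at the cost of a small amount of CPU work during device mapping.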
4381
4382impl DeviceMappedModelLoader for VLlama4Loader {
4383    fn mapped_max_act_size_elems(
4384        &self,
4385        config: &str,
4386        params: &AutoDeviceMapParams,
4387    ) -> Result<usize> {
4388        let AutoDeviceMapParams::Vision {
4389            max_seq_len,
4390            max_batch_size,
4391            max_image_shape: (height, width),
4392            max_num_images,
4393        } = params
4394        else {
4395            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4396        };
4397
4398        let cfg: Llama4Config = serde_json::from_str(config)?;
4399
4400        let (_pixels_batch_size, num_text_image_toks) =
4401            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4402
4403        let max_seq_len = max_seq_len.min(&ATTENTION_CHUNK_SIZE) + num_text_image_toks;
4404
4405        Ok(max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len)
4406    }
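
    // Rough illustration of the estimate above (hypothetical numbers): with max_batch_size = 1,
    // num_attention_heads = 40, a chunked sequence of 1024 tokens and ~2048 extra image tokens,
    // the dominant attention-score buffer is about
    //     1 * 40 * (1024 + 2048)^2 ≈ 3.8e8 elements,
    // which is what gets budgeted per mapped device for activations.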
4407    fn non_mapped_max_act_size_elems(
4408        &self,
4409        config: &str,
4410        params: &AutoDeviceMapParams,
4411    ) -> Result<usize> {
4412        let AutoDeviceMapParams::Vision {
4413            max_seq_len: _,
4414            max_batch_size,
4415            max_image_shape: (height, width),
4416            max_num_images,
4417        } = params
4418        else {
4419            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4420        };
4421
4422        let cfg: Llama4Config = serde_json::from_str(config)?;
4423
4424        let (pixels_batch_size, _num_text_image_toks) =
4425            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4426        let max_seq_len = cfg.vision_config.num_patches();
4427
4428        Ok((max_batch_size * pixels_batch_size)
4429            * cfg.vision_config.num_attention_heads
4430            * max_seq_len
4431            * max_seq_len)
4432    }
4433
4434    fn non_mapped_size_in_bytes(
4435        &self,
4436        config: &str,
4437        dtype: DType,
4438        weight_pack_factor: usize,
4439        _matformer_config: Option<&MatformerSliceConfig>,
4440    ) -> Result<usize> {
4441        let cfg: Llama4Config = serde_json::from_str(config)?;
4442        let tcfg = &cfg.text_config;
4443
4444        let text_elems = {
4445            let embed_tokens = tcfg.hidden_size * tcfg.vocab_size / weight_pack_factor;
4446            let lm_head = if !tcfg.tie_word_embeddings {
4447                tcfg.hidden_size * tcfg.vocab_size
4448            } else {
4449                0
4450            };
4451            let norm = tcfg.hidden_size;
4452            embed_tokens + lm_head + norm
4453        };
4454
4455        let vision_elems = {
4456            let cfg = &cfg.vision_config;
4457
4458            let num_patches = cfg.num_patches();
4459
4460            let unfold_elems =
4461                (cfg.num_channels * cfg.patch_size * cfg.patch_size) * cfg.hidden_size;
4462            let class_embedding_elems = cfg.hidden_size;
4463            let positional_embedding_vlm_elems = num_patches * cfg.hidden_size;
4464            let layernorm_pre_elems = cfg.hidden_size;
4465            let layernorm_post_elems = cfg.hidden_size;
4466
4467            let pixel_shuffle_elems = cfg.intermediate_size * cfg.projector_input_dim
4468                / weight_pack_factor
4469                + cfg.projector_input_dim * cfg.projector_output_dim / weight_pack_factor;
4470
4471            let encoder_layer = {
4472                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
4473                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
4474
4475                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
4476                let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4477                    / weight_pack_factor
4478                    + cfg.num_attention_heads * head_dim;
4479                let k_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4480                    / weight_pack_factor
4481                    + cfg.num_attention_heads * head_dim;
4482                let v_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4483                    / weight_pack_factor
4484                    + cfg.num_attention_heads * head_dim;
4485                let o_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4486                    / weight_pack_factor
4487                    + cfg.num_attention_heads * head_dim;
4488
4489                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
4490                    + cfg.intermediate_size;
4491                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
4492                    + cfg.hidden_size;
4493
4494                input_layernorm
4495                    + post_attention_layernorm
4496                    + q_proj
4497                    + k_proj
4498                    + v_proj
4499                    + o_proj
4500                    + fc1
4501                    + fc2
4502            };
4503
4504            unfold_elems
4505                + class_embedding_elems
4506                + positional_embedding_vlm_elems
4507                + layernorm_post_elems
4508                + layernorm_pre_elems
4509                + pixel_shuffle_elems
4510                + encoder_layer * cfg.num_hidden_layers
4511        };
4512
4513        let elems = text_elems + vision_elems;
4514
4515        Ok(elems * dtype.size_in_bytes())
4516    }
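
    // Order-of-magnitude note (hypothetical numbers): the non-mapped text side is dominated by the
    // embedding table; with hidden_size = 5120 and vocab_size = 200_000 that is already ~1.0e9
    // elements (~1.9 GiB at 2 bytes/element) before the vision tower is added, so this value
    // largely determines how much memory must stay on the primary (non-mapped) device.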
4517
4518    fn layer_sizes_in_bytes(
4519        &self,
4520        config: &str,
4521        dtype: DType,
4522        weight_pack_factor: usize,
4523        _matformer_config: Option<&MatformerSliceConfig>,
4524    ) -> Result<Vec<usize>> {
4525        let cfg: Llama4Config = serde_json::from_str(config)?;
4526        let tcfg = &cfg.text_config;
4527
4528        let mut per_layer_elems = Vec::new();
4529
4530        for layer_idx in 0..tcfg.num_hidden_layers {
4531            let input_layernorm = tcfg.hidden_size;
4532            let post_attention_layernorm = tcfg.hidden_size;
4533
4534            let size_in = tcfg.hidden_size;
4535            let size_q = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_attention_heads;
4536            let size_kv = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_key_value_heads;
4537            let q_proj = size_in * size_q / weight_pack_factor;
4538            let k_proj = size_in * size_kv / weight_pack_factor;
4539            let v_proj = size_in * size_kv / weight_pack_factor;
4540            let o_proj = size_q * size_in / weight_pack_factor;
4541
4542            let use_moe = tcfg.moe_layers().contains(&layer_idx);
4543            let moe_block = if use_moe {
4544                let h_size = tcfg.hidden_size;
4545                let i_size = tcfg.intermediate_size;
4546                let gate_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4547                let up_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4548                let down_proj = tcfg.num_local_experts * i_size * h_size / weight_pack_factor;
4549
4550                gate_proj + up_proj + down_proj
4551            } else {
4552                let h_size = tcfg.hidden_size;
4553                let i_size = tcfg.intermediate_size_mlp;
4554                let gate_proj = h_size * i_size / weight_pack_factor;
4555                let up_proj = h_size * i_size / weight_pack_factor;
4556                let down_proj = i_size * h_size / weight_pack_factor;
4557
4558                gate_proj + up_proj + down_proj
4559            };
4560
4561            per_layer_elems.push(
4562                input_layernorm
4563                    + post_attention_layernorm
4564                    + q_proj
4565                    + k_proj
4566                    + v_proj
4567                    + o_proj
4568                    + moe_block,
4569            );
4570        }
4571
4572        Ok(per_layer_elems
4573            .into_iter()
4574            .map(|x| x * dtype.size_in_bytes())
4575            .collect())
4576    }
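
    // Unlike the dense Mistral 3 path earlier in this file, Llama 4 interleaves MoE and dense
    // decoder layers, so per-layer byte sizes are not uniform: layers listed by
    // `tcfg.moe_layers()` carry `num_local_experts` copies of the gate/up/down projections
    // (sized by `intermediate_size`), while the remaining layers use a single MLP sized by
    // `intermediate_size_mlp`.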
4577
4578    fn num_layers(&self, config: &str) -> Result<usize> {
4579        let cfg: Llama4Config = serde_json::from_str(config)?;
4580        Ok(cfg.text_config.num_hidden_layers)
4581    }
4582
4583    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4584        let cfg: Llama4Config = serde_json::from_str(config)?;
4585        let cfg = &cfg.text_config;
4586
4587        let cfg = ModelConfigMetadata {
4588            max_seq_len: cfg.max_position_embeddings,
4589            num_layers: cfg.num_hidden_layers,
4590            hidden_size: cfg.hidden_size,
4591            num_kv_heads: cfg.num_key_value_heads,
4592            num_attn_heads: cfg.num_attention_heads,
4593            sliding_window: None,
4594            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4595            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4596        };
4597
4598        Ok(Box::new(cfg))
4599    }
4600
4601    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4602        Some(vec![NonMappedSubModel::Vision])
4603    }
4604}
4605
4606// ======================== Gemma 3n Loader
4607
4608/// [`VisionLoader`] for a Gemma 3n model.
4609///
4610/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
4611pub struct Gemma3nLoader;
4612
4613#[allow(dead_code)]
4614pub struct Gemma3nPrefixer;
4615
4616impl MultimodalPromptPrefixer for Gemma3nPrefixer {
4617    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
4618        prompt.to_string()
4619    }
4620}
4621
4622impl VisionModelLoader for Gemma3nLoader {
4623    fn load(
4624        &self,
4625        config: &str,
4626        vb: ShardedVarBuilder,
4627        normal_loading_metadata: NormalLoadingMetadata,
4628        attention_mechanism: AttentionImplementation,
4629    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
4630        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4631        Ok(Box::new(Gemma3nModel::new(
4632            &cfg,
4633            vb,
4634            self.is_gptx(config),
4635            normal_loading_metadata,
4636            attention_mechanism,
4637        )?))
4638    }
4639    fn is_gptx(&self, _config: &str) -> bool {
4640        true
4641    }
4642    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4643        let config: Gemma3nConfig = serde_json::from_str(config)?;
4644        Ok(Box::new(config))
4645    }
4646    fn get_processor(
4647        &self,
4648        _config: &str,
4649        processor_config: Option<ProcessorConfig>,
4650        _preprocessor_config: PreProcessorConfig,
4651        _max_edge: Option<u32>,
4652    ) -> Arc<dyn Processor + Send + Sync> {
4653        // Handle the Gemma 3 1b case here
4654        Arc::new(Gemma3nProcessor::new(
4655            processor_config.unwrap_or_default(),
4656            true,
4657        ))
4658    }
4659    fn supports_paged_attention(&self, _config: &str) -> bool {
4660        false
4661    }
4662    fn supports_prefix_cacher(&self, _config: &str) -> bool {
4663        true
4664    }
4665    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4666        Arc::new(Gemma3Prefixer)
4667    }
4668    fn modalities(&self, _config: &str) -> Result<Modalities> {
4669        Ok(Modalities {
4670            input: vec![
4671                SupportedModality::Text,
4672                SupportedModality::Vision,
4673                SupportedModality::Audio,
4674            ],
4675            output: vec![SupportedModality::Text],
4676        })
4677    }
4678}
4679
4680impl IsqModelLoader for Gemma3nLoader {
4681    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4682        Ok(vec![
4683            Regex::new(r"lm_head\.(weight|bias)$")?,
4684            // Language model attention
4685            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4686            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4687            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4688            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4689            // Language model MLP
4690            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4691            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4692            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4693            // Audio conformer attention layers
4694            Regex::new(r"conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$")?,
4695            Regex::new(r"conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$")?,
4696            Regex::new(r"conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$")?,
4697            Regex::new(
4698                r"conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4699            )?,
4700            Regex::new(r"conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4701            // Audio conformer FFW layers
4702            Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$")?,
4703            Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$")?,
4704            Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$")?,
4705            Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$")?,
4706            // Audio conformer conv1d layers
4707            Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$")?,
4708            Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$")?,
4709            // Audio subsample projection
4710            Regex::new(r"subsample_conv_projection\.input_proj_linear\.(weight|bias)$")?,
4711            // Multimodal embedders
4712            Regex::new(r"embed_vision\.embedding_projection\.(weight|bias)$")?,
4713            Regex::new(r"embed_audio\.embedding_projection\.(weight|bias)$")?,
4714        ])
4715    }
4716    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4717        Ok(vec![
4718            Regex::new(r"lm_head\.(weight|bias)$")?,
4719            // Language model attention
4720            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4721            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4722            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4723            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4724            // Language model MLP
4725            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4726            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4727            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4728            // Projections
4729            Regex::new(r"model\.language_model\.per_layer_model_projection\.(weight|bias)$")?,
4730            Regex::new(r"model\.language_model\.altup_projections\.(\d+)\.(weight|bias)$")?,
4731            Regex::new(r"model\.language_model\.altup_unembed_projections\.(\d+)\.(weight|bias)$")?,
4732            // Audio conformer attention layers
4733            Regex::new(
4734                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$",
4735            )?,
4736            Regex::new(
4737                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$",
4738            )?,
4739            Regex::new(
4740                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$",
4741            )?,
4742            Regex::new(
4743                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4744            )?,
4745            Regex::new(r"model\.audio_tower\.conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4746            // Audio conformer FFW layers
4747            Regex::new(
4748                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$",
4749            )?,
4750            Regex::new(
4751                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$",
4752            )?,
4753            Regex::new(
4754                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$",
4755            )?,
4756            Regex::new(
4757                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$",
4758            )?,
4759            // Audio conformer conv1d layers
4760            Regex::new(
4761                r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$",
4762            )?,
4763            Regex::new(
4764                r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$",
4765            )?,
4766            // Audio subsample projection
4767            Regex::new(
4768                r"model\.audio_tower\.subsample_conv_projection\.input_proj_linear\.(weight|bias)$",
4769            )?,
4770            // Multimodal embedders
4771            Regex::new(r"model\.embed_vision\.embedding_projection\.(weight|bias)$")?,
4772            Regex::new(r"model\.embed_audio\.embedding_projection\.(weight|bias)$")?,
4773        ])
4774    }
4775}
4776
4777impl DeviceMappedModelLoader for Gemma3nLoader {
4778    fn mapped_max_act_size_elems(
4779        &self,
4780        config: &str,
4781        params: &AutoDeviceMapParams,
4782    ) -> Result<usize> {
4783        let AutoDeviceMapParams::Vision {
4784            max_seq_len,
4785            max_batch_size,
4786            max_image_shape: _,
4787            max_num_images,
4788        } = params
4789        else {
4790            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4791        };
4792
4793        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4794        let text_cfg = &cfg.text_config;
4795
4796        // Gemma3n is an "inject into the prompt" model, similar to Gemma3
4797        // We need to account for vision and audio tokens in the sequence length
4798
4799        let mut total_seq_len = *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
4800
4801        // Add vision tokens
4802        {
4803            // Vision tokens are injected into the prompt
4804            // MSFA outputs fixed 16x16 features regardless of input size
4805            let msfa_spatial_size = 16; // Fixed from vision.rs line 1115
4806            let vision_tokens_per_image = msfa_spatial_size * msfa_spatial_size; // 256 tokens
4807            total_seq_len += vision_tokens_per_image * max_num_images;
4808        }
4809
4810        // Add audio tokens
4811        {
4812            // Audio tokens are injected into the prompt
4813            // From config field audio_soft_tokens_per_image (typically 188)
4814            let audio_tokens = cfg.audio_soft_tokens_per_image;
4815            total_seq_len += audio_tokens;
4816        }
4817
4818        // Calculate max attention size for text model with all injected tokens
4819        let max_text_attn =
4820            max_batch_size * text_cfg.num_attention_heads * total_seq_len * total_seq_len;
4821
4822        Ok(max_text_attn)
4823    }
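
    // Worked example of the budget above (hypothetical numbers): with max_batch_size = 1,
    // num_attention_heads = 8, a chunked text length of 1024, one image (16 * 16 = 256 injected
    // vision tokens) and ~188 audio tokens, total_seq_len ≈ 1468 and the attention-score
    // activation is about 1 * 8 * 1468^2 ≈ 1.7e7 elements.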
4824
4825    fn non_mapped_max_act_size_elems(
4826        &self,
4827        config: &str,
4828        params: &AutoDeviceMapParams,
4829    ) -> Result<usize> {
4830        let AutoDeviceMapParams::Vision {
4831            max_seq_len: _,
4832            max_batch_size,
4833            max_image_shape: _,
4834            max_num_images,
4835        } = params
4836        else {
4837            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4838        };
4839
4840        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4841
4842        // Calculate max activation sizes for each modality
4843        let mut max_activation = 0;
4844
4845        // Vision activation size
4846        {
4847            // Vision is Gemma3n's MobileNetV5 architecture with Multi-Query Attention
4848            // The peak activation is in the Multi-Query Attention layers
4849
4850            // From the architecture: stages 3 and 4 have MMQA blocks
4851            // Input images are 768x768 (from inputs_processor.rs)
4852            // Stage 3: 640 channels at 48x48 (768/16 downsampling), MMQA with num_heads=12, kv_dim=64
4853            // Stage 4: 1280 channels at 24x24 (768/32 downsampling), MMQA with num_heads=16, kv_dim=96
4854            // MSFA output: 2048 channels at fixed 16x16
4855
4856            let vision_tower_act = {
4857                // Peak is during MMQA attention computation in stage 4
4858                // Stage 4 has higher memory usage than Stage 3 due to more heads (16 vs 12)
4859                // From vision.rs: Stage 4 has num_heads=16, kv_dim=96, kv_stride=1
4860                let num_heads = 16; // Stage 4 configuration
4861                let spatial_size = 24; // 768 / 32 = 24 (input 768x768, stage 4 has 32x downsampling)
4862                let seq_len = spatial_size * spatial_size;
4863
4864                // Attention scores: [B * num_images, num_heads, seq_len, seq_len]
4865                max_batch_size * max_num_images * num_heads * seq_len * seq_len
4866            };
4867
4868            // Vision embedder activations
4869            let vision_embed_act = {
4870                // MSFA output: 2048 channels at fixed 16x16 spatial (from vision.rs line 1115)
4871                let msfa_channels = 2048; // MSFA_OUT_CHANNELS from vision.rs
4872                let spatial_size = 16; // Fixed output resolution from MSFA
4873                let vision_features =
4874                    max_batch_size * max_num_images * msfa_channels * spatial_size * spatial_size;
4875
4876                // After embedding projection to text hidden size
4877                let projected = max_batch_size
4878                    * max_num_images
4879                    * spatial_size
4880                    * spatial_size
4881                    * cfg.text_config.hidden_size;
4882
4883                vision_features.max(projected)
4884            };
4885
4886            max_activation = max_activation.max(vision_tower_act).max(vision_embed_act);
4887        }
4888
4889        // Audio activation size
4890        {
4891            let audio_cfg = &cfg.audio_config;
4892
4893            // Calculate max audio sequence length based on config
4894            // Audio uses conformer with subsampling and reduction
4895
4896            // A rough estimate of max_audio_frames
4897            let max_audio_frames = 1280;
4898
4899            let subsample_factor: usize = audio_cfg
4900                .sscp_conv_stride_size
4901                .iter()
4902                .map(|stride| stride[0]) // Time dimension stride
4903                .product();
4904            let audio_seq_after_subsample = max_audio_frames / subsample_factor;
4905
4906            // Audio encoder activations
4907            let audio_encoder_act = {
4908                // Conformer FFW layers expand the hidden size (a 4x factor is assumed here)
4909                let intermediate_size = audio_cfg.hidden_size * 4; // FFW expansion factor
4910
4911                // Peak is in the FFW layers before reduction
4912                max_batch_size * audio_seq_after_subsample * intermediate_size
4913            };
4914
4915            // Audio attention activations
4916            let audio_attn_act = {
4917                // Attention uses chunked processing with specific context sizes
4918                let chunk_size = audio_cfg.conf_attention_chunk_size;
4919                let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
4920                    + audio_cfg.conf_attention_context_right;
4921
4922                // Peak is attention scores: [B, num_heads, num_chunks, chunk_size, context_size]
4923                let num_chunks = audio_seq_after_subsample.div_ceil(chunk_size);
4924
4925                max_batch_size
4926                    * audio_cfg.conf_num_attention_heads
4927                    * num_chunks
4928                    * chunk_size
4929                    * context_size
4930            };
4931
4932            max_activation = max_activation.max(audio_encoder_act).max(audio_attn_act);
4933        }
4934
4935        Ok(max_activation)
4936    }
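
    // For scale (using the stage-4 numbers assumed above): one 768x768 image gives a 24x24
    // spatial grid, so the vision-tower attention scores are 16 heads * 576^2 ≈ 5.3e6 elements
    // per image, which is typically small next to the text-side attention budget computed in
    // `mapped_max_act_size_elems`.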
4937
4938    fn non_mapped_size_in_bytes(
4939        &self,
4940        config: &str,
4941        dtype: DType,
4942        weight_pack_factor: usize,
4943        matformer_config: Option<&MatformerSliceConfig>,
4944    ) -> Result<usize> {
4945        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4946
4947        // Apply matformer slicing if configured
4948        let text_cfg = if let Some(matformer_cfg) = matformer_config {
4949            use crate::device_map::DummyDeviceMapper;
4950            use crate::vision_models::gemma3n::text::handle_matformer_slicing;
4951
4952            let dummy_mapper = DummyDeviceMapper {
4953                nm_device: Device::Cpu,
4954            };
4955            let (adjusted_cfg, _, _, _, _) = handle_matformer_slicing(
4956                &cfg.text_config,
4957                &Some(matformer_cfg.clone()),
4958                &dummy_mapper,
4959            )?;
4960            adjusted_cfg
4961        } else {
4962            cfg.text_config.clone()
4963        };
4964
4965        let text_cfg = &text_cfg;
4966
4967        // Text components that are not device-mapped
4968        let text_elems = {
4969            // Embeddings
4970            let embed_tokens = text_cfg.hidden_size * text_cfg.vocab_size;
4971            let embed_tokens_per_layer = text_cfg.num_hidden_layers
4972                * text_cfg.hidden_size_per_layer_input
4973                * text_cfg.vocab_size_per_layer_input;
4974
4975            // LM head (if not tied)
4976            let lm_head = if !text_cfg.tie_word_embeddings || weight_pack_factor != 1 {
4977                text_cfg.hidden_size * text_cfg.vocab_size / weight_pack_factor
4978            } else {
4979                0
4980            };
4981
4982            // Final layer norm
4983            let norm = text_cfg.hidden_size;
4984
4985            // AltUp projections (not device-mapped)
4986            let altup_projections =
4987                (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
4988                    / weight_pack_factor;
4989            let altup_unembed_projections =
4990                (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
4991                    / weight_pack_factor;
4992
4993            // Per-layer model projection
4994            let per_layer_model_projection = text_cfg.num_hidden_layers
4995                * text_cfg.hidden_size
4996                * text_cfg.hidden_size_per_layer_input
4997                / weight_pack_factor;
4998            let per_layer_projection_norm = text_cfg.hidden_size;
4999
5000            embed_tokens
5001                + embed_tokens_per_layer
5002                + lm_head
5003                + norm
5004                + altup_projections
5005                + altup_unembed_projections
5006                + per_layer_model_projection
5007                + per_layer_projection_norm
5008        };
5009
5010        // Vision components
5011        let vision_elems = {
5012            let vision_cfg = &cfg.vision_config;
5013            // Vision tower - calculated from actual Gemma3n architecture
5014            // NOTE: Vision tower uses only Conv2d layers, NOT Arc<dyn QuantMethod>,
5015            // so NONE of these should be divided by weight_pack_factor
5016            let vision_tower_elems = {
5017                use crate::vision_models::gemma3n::vision::{
5018                    gemma3n_mobilenet_def, make_divisible, BlockType, INPUT_CHANNELS,
5019                    MSFA_EXPANSION_RATIO, MSFA_IN_CHANNELS, MSFA_OUT_CHANNELS, STEM_KERNEL_SIZE,
5020                    STEM_OUT_CHANNELS,
5021                };
5022
5023                // Stem: ConvNormAct (Conv2d + RMSNorm)
5024                let stem_conv =
5025                    INPUT_CHANNELS * STEM_OUT_CHANNELS * STEM_KERNEL_SIZE * STEM_KERNEL_SIZE;
5026                let stem_norm = STEM_OUT_CHANNELS; // RMSNorm weight
5027
5028                // Track input channels through the network
5029                let mut in_chs = STEM_OUT_CHANNELS;
5030                let mut total_elems = stem_conv + stem_norm;
5031
5032                // Process all stages from gemma3n_mobilenet_def
5033                let block_defs = gemma3n_mobilenet_def();
5034
5035                for stage_blocks in block_defs.iter() {
5036                    for block_type in stage_blocks.iter() {
5037                        match block_type {
5038                            BlockType::EdgeResidual {
5039                                out_channels,
5040                                kernel_size,
5041                                stride: _,
5042                                expand_ratio,
5043                                ..
5044                            } => {
5045                                #[allow(clippy::cast_precision_loss)]
5046                                let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
5047                                // EdgeResidual: all Conv2d layers, not quantizable
5048                                total_elems += in_chs * mid_chs * kernel_size * kernel_size; // conv_exp (Conv2d)
5049                                total_elems += mid_chs; // bn1 weight
5050                                total_elems += mid_chs * out_channels; // conv_pwl (Conv2d)
5051                                total_elems += out_channels; // bn2 weight
5052                                in_chs = *out_channels;
5053                            }
5054                            BlockType::UniversalInvertedResidual {
5055                                out_channels,
5056                                start_kernel_size,
5057                                mid_kernel_size,
5058                                stride: _,
5059                                expand_ratio,
5060                                ..
5061                            } => {
5062                                #[allow(clippy::cast_precision_loss)]
5063                                let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
5064                                // UniversalInvertedResidual: all Conv2d layers, not quantizable
5065                                if *expand_ratio != 1.0 {
5066                                    total_elems += in_chs * mid_chs; // expand conv (Conv2d)
5067                                    total_elems += mid_chs; // expand norm
5068                                }
5069                                if *start_kernel_size > 0 {
5070                                    total_elems += mid_chs * start_kernel_size * start_kernel_size; // depthwise start (Conv2d)
5071                                    total_elems += mid_chs; // norm
5072                                }
5073                                if *mid_kernel_size > 0 {
5074                                    total_elems += mid_chs * mid_kernel_size * mid_kernel_size; // depthwise mid (Conv2d)
5075                                    total_elems += mid_chs; // norm
5076                                }
5077                                total_elems += mid_chs * out_channels; // project conv (Conv2d)
5078                                total_elems += out_channels; // project norm
5079                                total_elems += out_channels; // layer scale gamma
5080                                in_chs = *out_channels;
5081                            }
5082                            BlockType::MultiQueryAttention {
5083                                num_heads,
5084                                kv_dim,
5085                                kv_stride: _,
5086                                ..
5087                            } => {
5088                                // MMQA: all Conv2d layers, not quantizable
5089                                let dw_kernel_size = 3; // Default dw_kernel_size for MMQA
5090                                total_elems += in_chs; // norm weight
5091                                total_elems += in_chs * num_heads * kv_dim; // query_proj (Conv2d)
5092                                total_elems += in_chs * kv_dim; // key_proj (Conv2d)
5093                                total_elems += in_chs * dw_kernel_size * dw_kernel_size; // key_dw_conv (Conv2d)
5094                                total_elems += *kv_dim; // value_down_conv (Conv2d)
5095                                total_elems += 1; // value_norm weight
5096                                total_elems += *kv_dim; // value_proj (Conv2d)
5097                                total_elems += num_heads * kv_dim * in_chs; // output_proj (Conv2d)
5098                                total_elems += in_chs; // layer scale
5099                            }
5100                        }
5101                    }
5102                }
5103
5104                // Multi-scale fusion adapter (msfa) - also uses Conv2d layers
5105                let msfa_in = MSFA_IN_CHANNELS.iter().sum::<usize>();
5106                let msfa_out = MSFA_OUT_CHANNELS;
5107                #[allow(clippy::cast_precision_loss)]
5108                let msfa_mid = make_divisible(msfa_in as f64 * MSFA_EXPANSION_RATIO, 8);
5109
5110                // MSFA FFN (UIR with expansion_ratio) - Conv2d layers, not quantizable
5111                total_elems += msfa_in * msfa_mid; // expand (Conv2d)
5112                total_elems += msfa_mid; // expand norm
5113                total_elems += msfa_mid * msfa_out; // project (Conv2d)
5114                total_elems += msfa_out; // project norm
5115                total_elems += msfa_out; // final norm
5116
5117                total_elems
5118            };
5119
5120            // Vision multimodal embedder components
5121            let embed_vision_elems = {
5122                // Embedding layer (not quantizable)
5123                let embedding = vision_cfg.vocab_size * vision_cfg.hidden_size;
5124
5125                // Normalization layers (not quantizable)
5126                let hard_norm = vision_cfg.hidden_size;
5127                let soft_norm = vision_cfg.hidden_size;
5128
5129                // Projection from vision to text hidden size (IS Arc<dyn QuantMethod>, so quantizable)
5130                let projection = vision_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5131
5132                // Post-projection norm (not quantizable)
5133                let post_norm = text_cfg.hidden_size;
5134
5135                embedding + hard_norm + soft_norm + projection + post_norm
5136            };
5137
5138            vision_tower_elems + embed_vision_elems
5139        };
5140
5141        // Audio components - based on actual audio.rs structure
5142        let audio_elems = {
5143            let audio_cfg = &cfg.audio_config;
5144
5145            // SubSampleConvProjection components
5146            let subsample_conv_projection_elems = {
5147                // Conv blocks (Conv2d layers - NOT quantizable)
5148                let mut conv_elems = 0;
5149
5150                // conv_0: Conv2d from 1 channel to first channel size
5151                let in_ch_0 = 1;
5152                let out_ch_0 = audio_cfg.sscp_conv_channel_size[0];
5153                let kernel_0 = &audio_cfg.sscp_conv_kernel_size[0];
5154                conv_elems += in_ch_0 * out_ch_0 * kernel_0[0] * kernel_0[1];
5155
5156                // conv_1: Conv2d from first to second channel size
5157                let in_ch_1 = out_ch_0;
5158                let out_ch_1 = audio_cfg.sscp_conv_channel_size[1];
5159                let kernel_1 = &audio_cfg.sscp_conv_kernel_size[1];
5160                conv_elems += in_ch_1 * out_ch_1 * kernel_1[0] * kernel_1[1];
5161
5162                // CumulativeGroupNorm for each conv block (weight only, no bias by default)
5163                let norm_0 = out_ch_0; // norm weight for conv_0
5164                let norm_1 = out_ch_1; // norm weight for conv_1
5165
5166                // input_proj_linear (Arc<dyn QuantMethod> - IS quantizable)
5167                let mut f_out = audio_cfg.input_feat_size;
5168                for i in 0..2 {
5169                    let kernel_w = audio_cfg.sscp_conv_kernel_size[i][1];
5170                    let stride_w = audio_cfg.sscp_conv_stride_size[i][1];
5171                    let pad_left = 1;
5172                    let pad_right = 1;
5173                    f_out = (f_out + pad_left + pad_right + stride_w - kernel_w) / stride_w;
5174                }
5175                let input_proj_in_features = out_ch_1 * f_out;
5176                let input_proj_linear =
5177                    input_proj_in_features * audio_cfg.hidden_size / weight_pack_factor;
5178
5179                conv_elems + norm_0 + norm_1 + input_proj_linear
5180            };
5181
5182            // Conformer blocks
5183            let conformer_elems = {
5184                let mut total = 0;
5185
5186                for _ in 0..audio_cfg.conf_num_hidden_layers {
5187                    // ConformerAttention
5188                    let attention_elems = {
5189                        // Norms (NOT quantizable)
5190                        let pre_attn_norm = audio_cfg.hidden_size;
5191                        let post_norm = audio_cfg.hidden_size;
5192
5193                        // Attention projections (Arc<dyn QuantMethod> - IS quantizable)
5194                        let q_proj =
5195                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5196                        let k_proj =
5197                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5198                        let v_proj =
5199                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5200                        let post =
5201                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5202
5203                        // RelativePositionEmbedding
5204                        let pos_proj =
5205                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5206                        let per_dim_scale =
5207                            audio_cfg.hidden_size / audio_cfg.conf_num_attention_heads; // head_dim
5208                        let inv_timescales = audio_cfg.hidden_size / 2; // num_timescales
5209                        let pos_indices = audio_cfg.conf_attention_context_left
5210                            + audio_cfg.conf_attention_context_right
5211                            + 1;
5212
5213                        // Local causal masks (precomputed tensors)
5214                        let chunk_size = audio_cfg.conf_attention_chunk_size;
5215                        let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
5216                            + audio_cfg.conf_attention_context_right;
5217                        let local_causal_valid_mask = chunk_size * context_size; // U8 tensor
5218                        let invalid_logits_tensor = 1; // single f32 value
5219
5220                        pre_attn_norm
5221                            + post_norm
5222                            + q_proj
5223                            + k_proj
5224                            + v_proj
5225                            + post
5226                            + pos_proj
5227                            + per_dim_scale
5228                            + inv_timescales
5229                            + pos_indices
5230                            + local_causal_valid_mask
5231                            + invalid_logits_tensor
5232                    };
5233
5234                    // ConformerFeedForward (start and end)
5235                    let ffw_elems = {
5236                        // Each FFW has:
5237                        // - pre_layer_norm (NOT quantizable)
5238                        // - ffw_layer_1 (Arc<dyn QuantMethod> - IS quantizable)
5239                        // - ffw_layer_2 (Arc<dyn QuantMethod> - IS quantizable)
5240                        // - post_layer_norm (NOT quantizable)
5241                        let intermediate_size = audio_cfg.hidden_size * 4;
5242
5243                        let ffw_start = {
5244                            let pre_norm = audio_cfg.hidden_size;
5245                            let layer_1 =
5246                                audio_cfg.hidden_size * intermediate_size / weight_pack_factor;
5247                            let layer_2 =
5248                                intermediate_size * audio_cfg.hidden_size / weight_pack_factor;
5249                            let post_norm = audio_cfg.hidden_size;
5250                            pre_norm + layer_1 + layer_2 + post_norm
5251                        };
5252
5253                        let ffw_end = ffw_start; // Same structure
5254
5255                        ffw_start + ffw_end
5256                    };
5257
5258                    // ConformerLightConv1d
5259                    let lconv1d_elems = {
5260                        // Norms (NOT quantizable)
5261                        let pre_layer_norm = audio_cfg.hidden_size;
5262                        let conv_norm = audio_cfg.hidden_size;
5263
5264                        // Linear layers (Arc<dyn QuantMethod> - IS quantizable)
5265                        let linear_start = audio_cfg.hidden_size * (audio_cfg.hidden_size * 2)
5266                            / weight_pack_factor;
5267                        let linear_end =
5268                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5269
5270                        // depthwise_conv1d (Conv1d - NOT quantizable)
5271                        let depthwise = audio_cfg.hidden_size * audio_cfg.conf_conv_kernel_size;
5272
5273                        pre_layer_norm + conv_norm + linear_start + linear_end + depthwise
5274                    };
5275
5276                    // Final norm for conformer block (NOT quantizable)
5277                    let block_norm = audio_cfg.hidden_size;
5278
5279                    total += attention_elems + ffw_elems + lconv1d_elems + block_norm;
5280                }
5281
5282                total
5283            };
5284
5285            // Audio multimodal embedder (embed_audio)
5286            let embed_audio_elems = {
5287                // Embedding layer (ScaledEmbedding - NOT quantizable)
5288                let embedding = audio_cfg.vocab_size * audio_cfg.hidden_size;
5289
5290                // RMS norms (NOT quantizable)
5291                let hard_embedding_norm = audio_cfg.hidden_size; // with scale
5292                let soft_embedding_norm = audio_cfg.hidden_size; // with scale
5293                let embedding_post_projection_norm = text_cfg.hidden_size; // without scale
5294
5295                // Projection (Arc<dyn QuantMethod> - IS quantizable)
5296                let embedding_projection =
5297                    audio_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5298
5299                embedding
5300                    + hard_embedding_norm
5301                    + soft_embedding_norm
5302                    + embedding_post_projection_norm
5303                    + embedding_projection
5304            };
5305
5306            subsample_conv_projection_elems + conformer_elems + embed_audio_elems
5307        };
5308
5309        let vision_dtype = if dtype == DType::F16 {
5310            // f16 -> f32 for vision model in particular.
5311            DType::F32
5312        } else {
5313            dtype
5314        };
5315
5316        let total_bytes = text_elems * dtype.size_in_bytes()
5317            + vision_elems * vision_dtype.size_in_bytes()
5318            + audio_elems * dtype.size_in_bytes();
5319
5320        Ok(total_bytes)
5321    }
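
    // Two details worth noting about the accounting above: (1) the vision tower is walked
    // block-by-block from `gemma3n_mobilenet_def()` rather than approximated, because its Conv2d
    // layers never participate in ISQ and so must not be divided by `weight_pack_factor`;
    // (2) when the requested dtype is f16, the vision weights are counted at f32 size, matching
    // the in-code note that the vision model is kept in f32 in that case.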
5322
5323    fn layer_sizes_in_bytes(
5324        &self,
5325        config: &str,
5326        dtype: DType,
5327        weight_pack_factor: usize,
5328        matformer_config: Option<&MatformerSliceConfig>,
5329    ) -> Result<Vec<usize>> {
5330        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5331
5332        // Apply matformer slicing if configured
5333        let (text_cfg, _layer_rename_map, _layers_skipped) = if let Some(matformer_cfg) =
5334            matformer_config
5335        {
5336            use crate::device_map::DummyDeviceMapper;
5337            use crate::vision_models::gemma3n::text::handle_matformer_slicing;
5338
5339            let dummy_mapper = DummyDeviceMapper {
5340                nm_device: Device::Cpu,
5341            };
5342            let (adjusted_cfg, _, _, layer_rename_map, layers_skipped) = handle_matformer_slicing(
5343                &cfg.text_config,
5344                &Some(matformer_cfg.clone()),
5345                &dummy_mapper,
5346            )?;
5347            (adjusted_cfg, layer_rename_map, layers_skipped)
5348        } else {
5349            (cfg.text_config.clone(), None, None)
5350        };
5351
5352        let text_cfg = &text_cfg;
5353
5354        // When matformer slicing is applied, we only include the layers that are kept
5355        let mut layer_sizes = Vec::new();
5356
5357        // Note: We don't need orig_intermediate_sizes anymore since the adjusted config
5358        // already has the correct intermediate sizes after matformer slicing
5359
5360        for layer_idx in 0..text_cfg.num_hidden_layers {
5361            let per_layer_elems = {
5362                // Layer norms
5363                let input_layernorm = text_cfg.hidden_size;
5364                let post_attention_layernorm = text_cfg.hidden_size;
5365                let pre_feedforward_layernorm = text_cfg.hidden_size;
5366                let post_feedforward_layernorm = text_cfg.hidden_size;
5367                let post_per_layer_input_norm = text_cfg.hidden_size;
5368
5369                // Attention components
5370                let size_in = text_cfg.hidden_size;
5371                let size_q = text_cfg.num_attention_heads * text_cfg.head_dim;
5372                let size_kv = text_cfg.num_key_value_heads * text_cfg.head_dim;
5373
5374                let q_proj = size_in * size_q / weight_pack_factor;
5375                let k_proj = size_in * size_kv / weight_pack_factor;
5376                let v_proj = size_in * size_kv / weight_pack_factor;
5377                let o_proj = size_q * size_in / weight_pack_factor;
5378
5379                // Q, K, V norms
5380                let q_norm = text_cfg.head_dim;
5381                let k_norm = text_cfg.head_dim;
5382                let v_norm = text_cfg.head_dim; // No bias for v_norm
5383
5384                // MLP components - use the adjusted intermediate sizes from matformer
5385                let intermediate_size = match &text_cfg.intermediate_size {
5386                    IntermediateSize::Single(size) => *size,
5387                    IntermediateSize::PerLayer(sizes) => sizes[layer_idx],
5388                    IntermediateSize::Matformer(sizes, _) => sizes[layer_idx],
5389                };
5390                let gate_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
5391                let up_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
5392                let down_proj = intermediate_size * text_cfg.hidden_size / weight_pack_factor;
5393
5394                // AltUp components (per layer)
5395                let altup_elems = {
5396                    let correct_output_scale = text_cfg.hidden_size;
5397                    let correction_coefs = text_cfg.altup_num_inputs * text_cfg.altup_num_inputs;
5398                    let prediction_coefs =
5399                        text_cfg.altup_num_inputs * text_cfg.altup_num_inputs.pow(2);
5400                    let modality_router = text_cfg.hidden_size * text_cfg.altup_num_inputs;
5401                    let router_norm = text_cfg.hidden_size;
5402
5403                    correct_output_scale
5404                        + correction_coefs
5405                        + prediction_coefs
5406                        + modality_router
5407                        + router_norm
5408                };
5409
5410                // Laurel block components
5411                let laurel_elems = {
5412                    let left = text_cfg.hidden_size * text_cfg.laurel_rank;
5413                    let right = text_cfg.laurel_rank * text_cfg.hidden_size;
5414                    let post_norm = text_cfg.hidden_size;
5415
5416                    left + right + post_norm
5417                };
5418
5419                // Per-layer input components
5420                let per_layer_input_gate =
5421                    text_cfg.hidden_size * text_cfg.hidden_size_per_layer_input;
5422                let per_layer_projection =
5423                    text_cfg.hidden_size_per_layer_input * text_cfg.hidden_size;
5424
5425                input_layernorm
5426                    + post_attention_layernorm
5427                    + pre_feedforward_layernorm
5428                    + post_feedforward_layernorm
5429                    + post_per_layer_input_norm
5430                    + q_proj
5431                    + k_proj
5432                    + v_proj
5433                    + o_proj
5434                    + q_norm
5435                    + k_norm
5436                    + v_norm
5437                    + gate_proj
5438                    + up_proj
5439                    + down_proj
5440                    + altup_elems
5441                    + laurel_elems
5442                    + per_layer_input_gate
5443                    + per_layer_projection
5444            };
5445
5446            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
5447        }
5448
5449        Ok(layer_sizes)
5450    }
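
    // Each mapped Gemma 3n layer bundles the usual attention/MLP projections with the AltUp,
    // laurel and per-layer-input tensors, and its MLP width comes from the (possibly
    // matformer-sliced) `IntermediateSize` entry for that layer index, so layer sizes can vary
    // across the stack.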
5451
5452    fn num_layers(&self, config: &str) -> Result<usize> {
5453        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5454        Ok(cfg.text_config.num_hidden_layers)
5455    }
5456
5457    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
5458        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5459        let cfg = cfg.text_config;
5460
5461        let cfg = ModelConfigMetadata {
5462            max_seq_len: cfg.max_position_embeddings,
5463            num_layers: cfg.num_hidden_layers,
5464            hidden_size: cfg.hidden_size,
5465            num_kv_heads: cfg.num_key_value_heads,
5466            num_attn_heads: cfg.num_attention_heads,
5467            sliding_window: None, // None to be more forgiving; not all layers use a sliding window
5468            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5469            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5470        };
5471
5472        Ok(Box::new(cfg))
5473    }
5474
5475    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
5476        Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
5477    }
5478}