mistralrs_core/pipeline/loaders/vision_loaders.rs

use std::any::Any;
use std::sync::Arc;
use std::{fmt::Debug, str::FromStr};

use anyhow::Result;
use candle_core::{DType, Device, Tensor, D};
use candle_nn::Conv2dConfig;
use image::{ColorType, DynamicImage};
use itertools::Itertools;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::ShardedVarBuilder;

#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};

use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
use crate::amoe::AnyMoeBaseModelMixin;
use crate::device_map::DeviceMapper;
use crate::layers::Conv3dConfig;
use crate::matformer::MatformerSliceConfig;
use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
use crate::pipeline::isq::IsqModelLoader;
use crate::pipeline::loaders::AutoDeviceMapParams;
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::{
    EitherCache, IsqModel, Modalities, MultimodalPromptPrefixer, Processor, ProcessorCreator,
    SupportedModality,
};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::vision_models::clip::ClipConfig;
use crate::vision_models::gemma3::config::Gemma3Config;
use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
use crate::vision_models::gemma3n::config::Gemma3nConfig;
use crate::vision_models::gemma3n::{Gemma3nModel, Gemma3nProcessor};
use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
use crate::vision_models::idefics2_input_processor::Idefics2Processor;
use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
use crate::vision_models::image_processor::ImagePreProcessor;
use crate::vision_models::inputs_processor::Phi4MMProcessor;
use crate::vision_models::llama4::{
    self, Llama4Config, Llama4ImageProcessor, Llama4Model, Llama4Processor,
};
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::llava15::Model as LLaVA;
use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
use crate::vision_models::llava_next::Model as LLaVANext;
use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
use crate::vision_models::phi3_inputs_processor::Phi3Processor;
use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::qwen2_5_vl::{
    Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
};
use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
use crate::vision_models::{minicpmo, phi4};

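/// Trait implemented by all vision (multimodal) models. In addition to the usual text
/// inputs, `forward` optionally receives `pixel_values` for prompt sequences and a
/// `model_specific_args` box carrying per-model extras such as a pixel attention mask
/// or image sizes.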
pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin {
    // pixel_values and pixel_attention_mask only specified for prompt seqs
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>, // pixel attention mask, or image sizes, or anything else
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
    /// For a prompt without images. Requires batch size of 1!
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
}

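/// Loader trait for vision models: builds the model from its serialized `config.json`
/// string, reports supported modalities and paged-attention/prefix-cacher support, and
/// supplies the matching processor and multimodal prompt prefixer.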
pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> bool;
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_processor(
        &self,
        model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync>;
    fn supports_paged_attention(&self, config: &str) -> bool;
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        // Default is false, specific model must override.
        false
    }
    fn modalities(&self, config: &str) -> Result<Modalities>;
    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer>;
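    /// Choose the device each tensor is loaded onto. The default implementation sends
    /// everything to the base device while loading ISQ; otherwise tensors whose names
    /// contain `.layers.<n>.` are routed to the device mapped for layer `n` (clamped to
    /// the layer count) and all remaining tensors go to the base device.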
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}

#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
/// The architecture to load the vision model as.
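///
/// Each variant corresponds to the architecture string accepted by [`FromStr`] (and by
/// `serde`). A small usage sketch based on the `FromStr` impl below:
///
/// ```ignore
/// use std::str::FromStr;
///
/// let tp = VisionLoaderType::from_str("qwen2_5vl").unwrap();
/// assert_eq!(tp, VisionLoaderType::Qwen2_5VL);
/// ```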
pub enum VisionLoaderType {
    #[serde(rename = "phi3v")]
    Phi3V,
    #[serde(rename = "idefics2")]
    Idefics2,
    #[serde(rename = "llava_next")]
    LLaVANext,
    #[serde(rename = "llava")]
    LLaVA,
    #[serde(rename = "vllama")]
    VLlama,
    #[serde(rename = "qwen2vl")]
    Qwen2VL,
    #[serde(rename = "idefics3")]
    Idefics3,
    #[serde(rename = "minicpmo")]
    MiniCpmO,
    #[serde(rename = "phi4mm")]
    Phi4MM,
    #[serde(rename = "qwen2_5vl")]
    Qwen2_5VL,
    #[serde(rename = "gemma3")]
    Gemma3,
    #[serde(rename = "mistral3")]
    Mistral3,
    #[serde(rename = "llama4")]
    Llama4,
    #[serde(rename = "gemma3n")]
    Gemma3n,
}

// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
impl VisionLoaderType {
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "Phi3VForCausalLM" => Ok(Self::Phi3V),
            "Idefics2ForConditionalGeneration" => Ok(Self::Idefics2),
            "LlavaNextForConditionalGeneration" => Ok(Self::LLaVANext),
            "LlavaForConditionalGeneration" => Ok(Self::LLaVA),
            "MllamaForConditionalGeneration" => Ok(Self::VLlama),
            "Qwen2VLForConditionalGeneration" => Ok(Self::Qwen2VL),
            "Idefics3ForConditionalGeneration" => Ok(Self::Idefics3),
            "MiniCPMO" => Ok(Self::MiniCpmO),
            "Phi4MMForCausalLM" => Ok(Self::Phi4MM),
            "Qwen2_5_VLForConditionalGeneration" => Ok(Self::Qwen2_5VL),
            "Gemma3ForConditionalGeneration" | "Gemma3ForCausalLM" => Ok(Self::Gemma3),
            "Mistral3ForConditionalGeneration" => Ok(Self::Mistral3),
            "Llama4ForConditionalGeneration" => Ok(Self::Llama4),
            "Gemma3nForConditionalGeneration" => Ok(Self::Gemma3n),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for VisionLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "phi3v" => Ok(Self::Phi3V),
            "idefics2" => Ok(Self::Idefics2),
            "llava_next" => Ok(Self::LLaVANext),
            "llava" => Ok(Self::LLaVA),
            "vllama" => Ok(Self::VLlama),
            "qwen2vl" => Ok(Self::Qwen2VL),
            "idefics3" => Ok(Self::Idefics3),
            "minicpmo" => Ok(Self::MiniCpmO),
            "phi4mm" => Ok(Self::Phi4MM),
            "qwen2_5vl" => Ok(Self::Qwen2_5VL),
            "gemma3" => Ok(Self::Gemma3),
            "mistral3" => Ok(Self::Mistral3),
            "llama4" => Ok(Self::Llama4),
            "gemma3n" => Ok(Self::Gemma3n),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`, `llama4`, `gemma3n`.")),
        }
    }
}

impl std::fmt::Display for VisionLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            VisionLoaderType::Phi3V => "phi3v",
            VisionLoaderType::Idefics2 => "idefics2",
            VisionLoaderType::LLaVANext => "llava_next",
            VisionLoaderType::LLaVA => "llava",
            VisionLoaderType::VLlama => "vllama",
            VisionLoaderType::Qwen2VL => "qwen2vl",
            VisionLoaderType::Idefics3 => "idefics3",
            VisionLoaderType::MiniCpmO => "minicpmo",
            VisionLoaderType::Phi4MM => "phi4mm",
            VisionLoaderType::Qwen2_5VL => "qwen2_5vl",
            VisionLoaderType::Gemma3 => "gemma3",
            VisionLoaderType::Mistral3 => "mistral3",
            VisionLoaderType::Llama4 => "llama4",
            VisionLoaderType::Gemma3n => "gemma3n",
        };
        write!(f, "{name}")
    }
}

#[derive(Deserialize)]
struct AutoVisionLoaderConfig {
    architectures: Vec<String>,
}

/// Automatically selects a VisionModelLoader implementation based on the JSON `architectures` field.
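///
/// Selection reads the single entry of the `architectures` array in the model's
/// `config.json` (e.g. `"Qwen2_5_VLForConditionalGeneration"` selects the Qwen2.5-VL
/// loader via [`VisionLoaderType::from_causal_lm_name`]) and then delegates every
/// trait method to the chosen concrete loader.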
pub struct AutoVisionLoader;

impl AutoVisionLoader {
    fn get_loader(config: &str) -> Result<Box<dyn VisionModelLoader>> {
        let auto_cfg: AutoVisionLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one architecture in config");
        }

        let name = &auto_cfg.architectures[0];
        let tp = VisionLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        // Delegate to the concrete loader
        Ok(match tp {
            VisionLoaderType::Phi3V => Box::new(Phi3VLoader),
            VisionLoaderType::Idefics2 => Box::new(Idefics2Loader),
            VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader),
            VisionLoaderType::LLaVA => Box::new(LLaVALoader),
            VisionLoaderType::VLlama => Box::new(VLlamaLoader),
            VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
            VisionLoaderType::Idefics3 => Box::new(Idefics3Loader),
            VisionLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
            VisionLoaderType::Phi4MM => Box::new(Phi4MMLoader),
            VisionLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
            VisionLoaderType::Gemma3 => Box::new(Gemma3Loader),
            VisionLoaderType::Mistral3 => Box::new(Mistral3Loader),
            VisionLoaderType::Llama4 => Box::new(VLlama4Loader),
            VisionLoaderType::Gemma3n => Box::new(Gemma3nLoader),
        })
    }
}

impl VisionModelLoader for AutoVisionLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }

    fn is_gptx(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader get_loader")
            .is_gptx(config)
    }

    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }

    fn get_processor(
        &self,
        model_config: &str,
        proc_cfg: Option<ProcessorConfig>,
        preproc_cfg: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Self::get_loader(model_config)
            .expect("AutoVisionLoader get_loader")
            .get_processor(model_config, proc_cfg, preproc_cfg, max_edge)
    }

    fn supports_paged_attention(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_paged_attention(config)
    }

    fn modalities(&self, config: &str) -> Result<Modalities> {
        Self::get_loader(config)?.modalities(config)
    }

    fn supports_prefix_cacher(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_prefix_cacher(config)
    }

    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .prefixer(config)
    }

    fn get_device_for_tensor(
        &self,
        config: &str,
        mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        Self::get_loader(config)?.get_device_for_tensor(config, mapper, loading_isq)
    }
}

impl IsqModelLoader for AutoVisionLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
}

impl DeviceMappedModelLoader for AutoVisionLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params, prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

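/// Evaluates to `$size` when `$cond` is true (i.e. the layer has a bias tensor) and to
/// 0 otherwise; used by the parameter-count estimates below.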
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}

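/// Counts the parameter elements of a CLIP ViT vision tower for the given config
/// (embeddings, positional tables, and `num_hidden_layers` encoder blocks). Used by the
/// `DeviceMappedModelLoader` size estimates.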
fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
    let pre_layer_norm = cfg.hidden_size;
    let final_layer_norm = cfg.hidden_size;

    let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
    let num_positions = num_patches + 1;

    let class_embedding = cfg.hidden_size;

    let position_ids = num_positions;
    let position_embedding = num_positions * cfg.hidden_size;

    let conv2dconfig = Conv2dConfig {
        stride: cfg.patch_size,
        ..Default::default()
    };
    let patch_embedding =
        cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;

    let encoder_layer_elems = {
        let layer_norm1 = cfg.hidden_size;
        let layer_norm2 = cfg.hidden_size;

        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

        layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
    };

    pre_layer_norm
        + final_layer_norm
        + class_embedding
        + position_ids
        + position_embedding
        + patch_embedding
        + cfg.num_hidden_layers * encoder_layer_elems
}

// ======================== Phi 3 loader

/// [`VisionLoader`] for a Phi 3 Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Phi3VLoader;

pub struct Phi3VPrefixer;

impl MultimodalPromptPrefixer for Phi3VPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        // Image indexing starts at 0.
        format!(
            "{}{prompt}",
            image_indexes
                .into_iter()
                .map(|image_index| format!("<|image_{}|>", image_index + 1))
                .join("")
        )
    }
}

impl VisionModelLoader for Phi3VLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(Phi3::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi3Processor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Phi3VPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Phi3VLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi3VLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        // NOTE: we ignore max_num_images although it can only be one...
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        // NOTE: we ignore max_num_images although it can only be one...
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = {
                let projection_cls = cfg
                    .embd_layer
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator =
                    cfg.embd_layer.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
                let image_dim_out = cfg.img_processor.image_dim_out;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);

                proj + glb_gn + sub_gn + clip_vit
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Idefics 2 loader

/// [`VisionLoader`] for an Idefics 2 Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Idefics2Loader;

pub struct Idefics2Prefixer;

impl MultimodalPromptPrefixer for Idefics2Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        // Chat template does it
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(Idefics2::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics2Processor::new(
            processor_config.unwrap(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Idefics2Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Idefics2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Idefics2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            // do_image_splitting = true
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let text_elems = {
            let tie_word_embeddings = cfg.tie_word_embeddings;
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let tcfg = &cfg.text_config;
            let vcfg = &cfg.vision_config;
            let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let down_proj = tcfg.intermediate_size * tcfg.hidden_size;

            let perceiver_elems = {
                let tcfg = &cfg.text_config;
                let pcfg = &cfg.perceiver_config;

                let n_latents = pcfg.resampler_n_latents;
                let hidden_size = tcfg.hidden_size;
                let depth = pcfg.resampler_depth;

                let norm = tcfg.hidden_size;
                let latents = n_latents * hidden_size;

                let layer_elems = {
                    let input_latents_norm = hidden_size;
                    let input_context_norm = hidden_size;
                    let post_attn_norm = hidden_size;

                    let num_heads = pcfg.resampler_n_heads;
                    let head_dim = pcfg.resampler_head_dim;
                    let num_key_value_heads = pcfg.num_key_value_heads;

                    let q_proj = hidden_size * num_heads * head_dim;
                    let k_proj = hidden_size * num_key_value_heads * head_dim;
                    let v_proj = hidden_size * num_key_value_heads * head_dim;
                    let o_proj = num_heads * head_dim * hidden_size;

                    let gate_proj = hidden_size * hidden_size * 4;
                    let up_proj = hidden_size * hidden_size * 4;
                    let down_proj = hidden_size * 4 * hidden_size;

                    input_latents_norm
                        + input_context_norm
                        + post_attn_norm
                        + q_proj
                        + k_proj
                        + v_proj
                        + o_proj
                        + gate_proj
                        + up_proj
                        + down_proj
                };

                norm + latents + layer_elems * depth
            };

            gate_proj + up_proj + down_proj + perceiver_elems
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm + patch_embedding + position_embedding + layer_elems
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== LLaVANext Loader

/// [`VisionLoader`] for an LLaVANext Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct LLaVANextLoader;

pub struct LLaVANextPrefixer;

impl MultimodalPromptPrefixer for LLaVANextPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVANextLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(LLaVANext::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVANextProcessor::new(model_config))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(LLaVANextPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for LLaVANextLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVANextLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== LLaVA Loader

/// [`VisionLoader`] for an LLaVA Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct LLaVALoader;

pub struct LLaVAPrefixer;

impl MultimodalPromptPrefixer for LLaVAPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVALoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(LLaVA::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVAProcessor::new(model_config))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(LLaVAPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}
1432
1433impl IsqModelLoader for LLaVALoader {
1434    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1435        Ok(vec![
1436            Regex::new(r"lm_head\.(weight|bias)$")?,
1437            // Attention
1438            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1439            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1440            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1441            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1442            // MLP
1443            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1444            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1445            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1446        ])
1447    }
1448    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1449        Ok(vec![
1450            Regex::new(r"lm_head\.(weight|bias)$")?,
1451            // Attention
1452            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1453            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1454            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1455            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1456            // MLP
1457            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1458            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1459            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1460        ])
1461    }
1462}
1463
1464impl DeviceMappedModelLoader for LLaVALoader {
1465    fn mapped_max_act_size_elems(
1466        &self,
1467        config: &str,
1468        params: &AutoDeviceMapParams,
1469        _prompt_chunksize: usize,
1470    ) -> Result<usize> {
1471        let AutoDeviceMapParams::Vision {
1472            max_seq_len,
1473            max_batch_size,
1474            max_image_shape: _,
1475            max_num_images,
1476        } = params
1477        else {
1478            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1479        };
1480
1481        let config: LLaVAConfig = serde_json::from_str(config)?;
1482
1483        let img_seq_len =
1484            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1485        let img_seq_len = img_seq_len * max_num_images;
1486
1487        let max_text_attn = {
1488            let cfg = &config.text_config;
1489            // This model injects the vision information directly into the input embeddings
1490            let max_seq_len = img_seq_len + max_seq_len;
1491
1492            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1493        };
1494
1495        Ok(max_text_attn)
1496    }
1497
1498    fn non_mapped_max_act_size_elems(
1499        &self,
1500        config: &str,
1501        params: &AutoDeviceMapParams,
1502    ) -> Result<usize> {
1503        let AutoDeviceMapParams::Vision {
1504            max_seq_len: _,
1505            max_batch_size,
1506            max_image_shape: _,
1507            max_num_images,
1508        } = params
1509        else {
1510            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1511        };
1512
1513        let config: LLaVAConfig = serde_json::from_str(config)?;
1514
1515        let img_seq_len =
1516            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1517
1518        let max_vision_attn = {
1519            (max_batch_size * max_num_images)
1520                * config.vision_config.num_attention_heads
1521                * img_seq_len
1522                * img_seq_len
1523        };
1524
1525        Ok(max_vision_attn)
1526    }
1527
1528    fn non_mapped_size_in_bytes(
1529        &self,
1530        config: &str,
1531        dtype: DType,
1532        weight_pack_factor: usize,
1533        _matformer_config: Option<&MatformerSliceConfig>,
1534    ) -> Result<usize> {
1535        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1536        let text_elems = {
1537            let cfg = &cfg.text_config;
1538            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1539            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1540            let norm = cfg.hidden_size;
1541            embed_tokens + lm_head + norm
1542        };
1543
1544        let image_newline = cfg.text_config.hidden_size;
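        // Multimodal projector: two biased linear layers mapping the vision hidden size
        // to the text hidden size.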
1545        let mmproj = {
1546            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
1547                + cfg.text_config.hidden_size;
1548            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
1549                + cfg.text_config.hidden_size;
1550
1551            linear_1 + linear_2
1552        };
1553        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());
1554
1555        let elems = text_elems + image_newline + mmproj + vision_tower;
1556        Ok(elems * dtype.size_in_bytes())
1557    }
1558
1559    fn layer_sizes_in_bytes(
1560        &self,
1561        config: &str,
1562        dtype: DType,
1563        weight_pack_factor: usize,
1564        _matformer_config: Option<&MatformerSliceConfig>,
1565    ) -> Result<Vec<usize>> {
1566        let cfg: LLaVAConfig = serde_json::from_str(config)?;
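        // Per-decoder-layer element count; quantizable projection weights are divided by
        // `weight_pack_factor` to reflect ISQ packing.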
1567        let per_layer_elems = {
1568            let cfg = &cfg.text_config;
1569            let input_layernorm = cfg.hidden_size;
1570            let post_attention_layernorm = cfg.hidden_size;
1571
1572            let size_in = cfg.hidden_size;
1573            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1574            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1575            let q_proj = size_in * size_q / weight_pack_factor;
1576            let k_proj = size_in * size_kv / weight_pack_factor;
1577            let v_proj = size_in * size_kv / weight_pack_factor;
1578            let o_proj = size_q * size_in / weight_pack_factor;
1579
1580            let h_size = cfg.hidden_size;
1581            let i_size = cfg.intermediate_size;
1582            let gate_proj = h_size * i_size / weight_pack_factor;
1583            let up_proj = h_size * i_size / weight_pack_factor;
1584            let down_proj = i_size * h_size / weight_pack_factor;
1585
1586            input_layernorm
1587                + post_attention_layernorm
1588                + q_proj
1589                + k_proj
1590                + v_proj
1591                + o_proj
1592                + gate_proj
1593                + up_proj
1594                + down_proj
1595        };
1596        Ok(vec![
1597            per_layer_elems * dtype.size_in_bytes();
1598            cfg.text_config.num_hidden_layers
1599        ])
1600    }
1601
1602    fn num_layers(&self, config: &str) -> Result<usize> {
1603        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1604        Ok(cfg.text_config.num_hidden_layers)
1605    }
1606
1607    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1608        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1609        let cfg = &cfg.text_config;
1610
1611        let cfg = ModelConfigMetadata {
1612            max_seq_len: cfg.max_position_embeddings,
1613            num_layers: cfg.num_hidden_layers,
1614            hidden_size: cfg.hidden_size,
1615            num_kv_heads: cfg.num_key_value_heads,
1616            num_attn_heads: cfg.num_attention_heads,
1617            sliding_window: cfg.sliding_window,
1618            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1619            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1620        };
1621
1622        Ok(Box::new(cfg))
1623    }
1624
1625    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1626        Some(vec![NonMappedSubModel::Vision])
1627    }
1628}
1629
1630// ======================== MLlama Loader
1631
1632/// [`VisionLoader`] for a Llama Vision model.
1633///
1634/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
1635pub struct VLlamaLoader;
1636
1637pub struct VLlamaPrefixer;
1638
1639impl MultimodalPromptPrefixer for VLlamaPrefixer {
1640    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1641        format!("{}{prompt}", "<|image|>".repeat(image_indexes.len()))
1642    }
1643}
1644
1645impl VisionModelLoader for VLlamaLoader {
1646    fn load(
1647        &self,
1648        config: &str,
1649        vb: ShardedVarBuilder,
1650        normal_loading_metadata: NormalLoadingMetadata,
1651        attention_mechanism: AttentionImplementation,
1652    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1653        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1654        Ok(Box::new(MLlamaModel::new(
1655            &cfg,
1656            vb,
1657            self.is_gptx(config),
1658            normal_loading_metadata,
1659            attention_mechanism,
1660        )?))
1661    }
1662    fn is_gptx(&self, _config: &str) -> bool {
1663        true
1664    }
1665    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1666        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1667        Ok(Box::new(cfg))
1668    }
1669    fn get_processor(
1670        &self,
1671        _model_config: &str,
1672        _processor_config: Option<ProcessorConfig>,
1673        _preprocessor_config: PreProcessorConfig,
1674        _max_edge: Option<u32>,
1675    ) -> Arc<dyn Processor + Send + Sync> {
1676        Arc::new(MLlamaProcessor::new())
1677    }
1678    fn supports_paged_attention(&self, _config: &str) -> bool {
1679        false
1680    }
1681    fn supports_prefix_cacher(&self, _config: &str) -> bool {
1682        true
1683    }
1684    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1685        Arc::new(VLlamaPrefixer)
1686    }
1687    fn modalities(&self, _config: &str) -> Result<Modalities> {
1688        Ok(Modalities {
1689            input: vec![SupportedModality::Text, SupportedModality::Vision],
1690            output: vec![SupportedModality::Text],
1691        })
1692    }
1693}
1694
1695impl IsqModelLoader for VLlamaLoader {
1696    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
1697        let config: MLlamaConfig = serde_json::from_str(config)?;
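        // Only the standard self-attention decoder layers are quantized; the interleaved
        // cross-attention layers are excluded from the per-layer ISQ regexes built below.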
1698        let cross_attn_layers = &config.text_config.cross_attention_layers;
1699        let transformer_layers =
1700            (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
1701        let mut text_regexes = Vec::new();
1702        for layer in transformer_layers {
1703            text_regexes.extend(vec![
1704                // Attention text
1705                Regex::new(&format!(
1706                    r"language_model.model.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
1707                ))?,
1708                Regex::new(&format!(
1709                    r"language_model.model.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
1710                ))?,
1711                Regex::new(&format!(
1712                    r"language_model.model.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
1713                ))?,
1714                Regex::new(&format!(
1715                    r"language_model.model.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
1716                ))?,
1717                // MLP text
1718                Regex::new(&format!(
1719                    r"language_model.model.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
1720                ))?,
1721                Regex::new(&format!(
1722                    r"language_model.model.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
1723                ))?,
1724                Regex::new(&format!(
1725                    r"language_model.model.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
1726                ))?,
1727            ]);
1728        }
1729        let vision_regexes = vec![
1730            // Vision attention (transformer)
1731            Regex::new(
1732                r"vision_model.transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1733            )?,
1734            Regex::new(
1735                r"vision_model.transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1736            )?,
1737            Regex::new(
1738                r"vision_model.transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1739            )?,
1740            Regex::new(
1741                r"vision_model.transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1742            )?,
1743            // Vision attention (global transformer)
1744            Regex::new(
1745                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1746            )?,
1747            Regex::new(
1748                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1749            )?,
1750            Regex::new(
1751                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1752            )?,
1753            Regex::new(
1754                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1755            )?,
1756            // MLP vision
1757            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1758            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1759        ];
1760
1761        Ok([text_regexes, vision_regexes].concat())
1762    }
1763    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1764        self.isq_layer_regexes(config)
1765    }
1766}
1767
1768impl DeviceMappedModelLoader for VLlamaLoader {
1769    fn mapped_max_act_size_elems(
1770        &self,
1771        config: &str,
1772        params: &AutoDeviceMapParams,
1773        _prompt_chunksize: usize,
1774    ) -> Result<usize> {
1775        let AutoDeviceMapParams::Vision {
1776            max_seq_len,
1777            max_batch_size,
1778            max_image_shape: _,
1779            max_num_images,
1780        } = params
1781        else {
1782            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1783        };
1784
1785        let config: MLlamaConfig = serde_json::from_str(config)?;
1786
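        // Per-image vision sequence length: (image_size / patch_size)^2 patches plus one for
        // the class token, padded up to a multiple of 8 and replicated across max_num_tiles.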
1787        let img_seq_len = {
1788            let cfg = &config.vision_config;
1789            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1790            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1791            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1792        };
1793        let img_seq_len = img_seq_len * max_num_images;
1794
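        // The decoder interleaves self-attention over text with cross-attention over image
        // tokens, so size the activation for whichever of the two is larger.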
1795        let max_cross_text_attn = {
1796            let cfg = &config.text_config;
1797            max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
1798        };
1799
1800        let max_self_text_attn = {
1801            let cfg = &config.text_config;
1802            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1803        };
1804
1805        Ok(max_self_text_attn.max(max_cross_text_attn))
1806    }
1807
1808    fn non_mapped_max_act_size_elems(
1809        &self,
1810        config: &str,
1811        params: &AutoDeviceMapParams,
1812    ) -> Result<usize> {
1813        let AutoDeviceMapParams::Vision {
1814            max_seq_len: _,
1815            max_batch_size,
1816            max_image_shape: _,
1817            max_num_images,
1818        } = params
1819        else {
1820            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1821        };
1822
1823        let config: MLlamaConfig = serde_json::from_str(config)?;
1824
1825        let img_seq_len = {
1826            let cfg = &config.vision_config;
1827            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1828            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1829            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1830        };
1831        let max_vision_attn = {
1832            let cfg = &config.vision_config;
1833            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
1834        };
1835
1836        Ok(max_vision_attn)
1837    }
1838
1839    fn non_mapped_size_in_bytes(
1840        &self,
1841        config: &str,
1842        dtype: DType,
1843        weight_pack_factor: usize,
1844        _matformer_config: Option<&MatformerSliceConfig>,
1845    ) -> Result<usize> {
1846        let config: MLlamaConfig = serde_json::from_str(config)?;
1847        let text_elems = {
1848            let cfg = &config.text_config;
1849            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1850            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1851            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1852                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1853            } else {
1854                0
1855            };
1856            let norm = cfg.hidden_size;
1857            embed_tokens + lm_head + norm
1858        };
1859
1860        let vision_elems = {
1861            let cfg = &config.vision_config;
1862
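            // Conv2d patch embedding weights: channels * hidden_size * patch_size^2
            // (no bias is counted here).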
1863            let conv_cfg = Conv2dConfig {
1864                stride: cfg.patch_size,
1865                ..Default::default()
1866            };
1867            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
1868                * cfg.patch_size
1869                * cfg.patch_size;
1870
1871            let class_embedding = cfg.hidden_size;
1872
1873            let gated_positional_embedding = {
1874                let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1875                let embedding = num_patches * cfg.hidden_size;
1876                let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
1877                    * (cfg.max_num_tiles * num_patches * cfg.hidden_size);
1878
1879                embedding + tile_embedding
1880            };
1881
1882            let pre_tile_positional_embedding =
1883                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1884            let post_tile_positional_embedding =
1885                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1886
1887            let layernorm_pre = cfg.hidden_size;
1888            let layernorm_post = cfg.hidden_size;
1889
1890            let encoder_layer = {
1891                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1892                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
1893
1894                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1895                let q_proj =
1896                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1897                let k_proj =
1898                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1899                let v_proj =
1900                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1901                let o_proj =
1902                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1903
1904                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
1905                    + cfg.intermediate_size;
1906                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
1907                    + cfg.hidden_size;
1908
1909                input_layernorm
1910                    + post_attention_layernorm
1911                    + q_proj
1912                    + k_proj
1913                    + v_proj
1914                    + o_proj
1915                    + fc1
1916                    + fc2
1917            };
1918
1919            patch_embedding
1920                + class_embedding
1921                + gated_positional_embedding
1922                + pre_tile_positional_embedding
1923                + post_tile_positional_embedding
1924                + layernorm_pre
1925                + layernorm_post
1926                + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
1927        };
1928
1929        let elems = text_elems + vision_elems;
1930        Ok(elems * dtype.size_in_bytes())
1931    }
1932
1933    fn layer_sizes_in_bytes(
1934        &self,
1935        config: &str,
1936        dtype: DType,
1937        weight_pack_factor: usize,
1938        _matformer_config: Option<&MatformerSliceConfig>,
1939    ) -> Result<Vec<usize>> {
1940        let config: MLlamaConfig = serde_json::from_str(config)?;
1941        let cfg = &config.text_config;
1942
1943        let mut layer_sizes = Vec::new();
1944
1945        for i in 0..cfg.num_hidden_layers {
1946            let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
1947                // No isq for cross attention
1948                1
1949            } else {
1950                weight_pack_factor
1951            };
1952
1953            let per_layer_elems = {
1954                let input_layernorm = cfg.hidden_size;
1955                let post_attention_layernorm = cfg.hidden_size;
1956
1957                let size_in = cfg.hidden_size;
1958                let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1959                let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1960                let q_proj = size_in * size_q / weight_pack_factor;
1961                let k_proj = size_in * size_kv / weight_pack_factor;
1962                let v_proj = size_in * size_kv / weight_pack_factor;
1963                let o_proj = size_q * size_in / weight_pack_factor;
1964
1965                let h_size = cfg.hidden_size;
1966                let i_size = cfg.intermediate_size;
1967                let gate_proj = h_size * i_size / weight_pack_factor;
1968                let up_proj = h_size * i_size / weight_pack_factor;
1969                let down_proj = i_size * h_size / weight_pack_factor;
1970
1971                input_layernorm
1972                    + post_attention_layernorm
1973                    + q_proj
1974                    + k_proj
1975                    + v_proj
1976                    + o_proj
1977                    + gate_proj
1978                    + up_proj
1979                    + down_proj
1980            };
1981
1982            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
1983        }
1984
1985        Ok(layer_sizes)
1986    }
1987
1988    fn num_layers(&self, config: &str) -> Result<usize> {
1989        let config: MLlamaConfig = serde_json::from_str(config)?;
1990        Ok(config.text_config.num_hidden_layers)
1991    }
1992
1993    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1994        let cfg: MLlamaConfig = serde_json::from_str(config)?;
1995        let cfg = &cfg.text_config;
1996
1997        let cfg = ModelConfigMetadata {
1998            max_seq_len: cfg.max_position_embeddings,
1999            num_layers: cfg.num_hidden_layers,
2000            hidden_size: cfg.hidden_size,
2001            num_kv_heads: cfg.num_key_value_heads,
2002            num_attn_heads: cfg.num_attention_heads,
2003            sliding_window: None,
2004            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2005            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2006        };
2007
2008        Ok(Box::new(cfg))
2009    }
2010
2011    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2012        Some(vec![NonMappedSubModel::Vision])
2013    }
2014}
2015
2016// ======================== Qwen2VL Loader
2017
2018/// [`VisionLoader`] for a Qwen2-VL model.
2019///
2020/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2021pub struct Qwen2VLLoader;
2022
2023pub struct Qwen2VLPrefixer;
2024
2025impl MultimodalPromptPrefixer for Qwen2VLPrefixer {
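    // One vision-start / image-pad / vision-end placeholder group is prepended per image;
    // the input processor is expected to expand the pad token into the actual image tokens.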
2026    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2027        format!(
2028            "{}{prompt}",
2029            format!(
2030                "{}{}{}",
2031                Qwen2VLProcessor::VISION_START,
2032                Qwen2VLProcessor::IMAGE_PAD,
2033                Qwen2VLProcessor::VISION_END
2034            )
2035            .repeat(image_indexes.len())
2036        )
2037    }
2038}
2039
2040impl VisionModelLoader for Qwen2VLLoader {
2041    fn load(
2042        &self,
2043        config: &str,
2044        vb: ShardedVarBuilder,
2045        normal_loading_metadata: NormalLoadingMetadata,
2046        attention_mechanism: AttentionImplementation,
2047    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2048        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2049        Ok(Box::new(Qwen2VLModel::new(
2050            &cfg,
2051            vb,
2052            self.is_gptx(config),
2053            normal_loading_metadata,
2054            attention_mechanism,
2055        )?))
2056    }
2057    fn is_gptx(&self, _config: &str) -> bool {
2058        true
2059    }
2060    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2061        let config: Qwen2VLConfig = serde_json::from_str(config)?;
2062        Ok(Box::new(config))
2063    }
2064    fn get_processor(
2065        &self,
2066        _model_config: &str,
2067        _processor_config: Option<ProcessorConfig>,
2068        _preprocessor_config: PreProcessorConfig,
2069        max_edge: Option<u32>,
2070    ) -> Arc<dyn Processor + Send + Sync> {
2071        Arc::new(Qwen2VLProcessor::new(max_edge))
2072    }
2073    fn supports_paged_attention(&self, _config: &str) -> bool {
2074        false
2075    }
2076    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2077        Arc::new(Qwen2VLPrefixer)
2078    }
2079    fn modalities(&self, _config: &str) -> Result<Modalities> {
2080        Ok(Modalities {
2081            input: vec![SupportedModality::Text, SupportedModality::Vision],
2082            output: vec![SupportedModality::Text],
2083        })
2084    }
2085}
2086
2087impl IsqModelLoader for Qwen2VLLoader {
2088    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2089        Ok(vec![
2090            Regex::new(r"lm_head\.(weight|bias)$")?,
2091            // Attention
2092            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2093            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2094            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2095            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2096            // MLP
2097            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2098            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2099            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2100        ])
2101    }
2102    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2103        self.isq_layer_regexes(config)
2104    }
2105}
2106
2107impl DeviceMappedModelLoader for Qwen2VLLoader {
2108    fn mapped_max_act_size_elems(
2109        &self,
2110        config: &str,
2111        params: &AutoDeviceMapParams,
2112        _prompt_chunksize: usize,
2113    ) -> Result<usize> {
2114        let AutoDeviceMapParams::Vision {
2115            max_seq_len,
2116            max_batch_size,
2117            max_image_shape,
2118            max_num_images,
2119        } = params
2120        else {
2121            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2122        };
2123
2124        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2125
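        // Vision tokens per image: temporal grid x spatial grid, derived from the temporal
        // and spatial patch sizes.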
2126        let img_seq_len = {
2127            let cfg = &cfg.vision_config;
2128            let grid_t = max_num_images / cfg.temporal_patch_size;
2129            let grid_h = max_image_shape.0 / cfg.patch_size;
2130            let grid_w = max_image_shape.1 / cfg.patch_size;
2131            grid_t * grid_h * grid_w
2132        };
2133        let img_seq_len = img_seq_len * max_num_images;
2134
2135        let max_text_attn = {
2136            // This model injects the vision information directly into the input embeddings
2137            let max_seq_len = img_seq_len + max_seq_len;
2138            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
2139        };
2140
2141        Ok(max_text_attn)
2142    }
2143
2144    fn non_mapped_max_act_size_elems(
2145        &self,
2146        config: &str,
2147        params: &AutoDeviceMapParams,
2148    ) -> Result<usize> {
2149        let AutoDeviceMapParams::Vision {
2150            max_seq_len: _,
2151            max_batch_size,
2152            max_image_shape,
2153            max_num_images,
2154        } = params
2155        else {
2156            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2157        };
2158
2159        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2160
2161        let img_seq_len = {
2162            let cfg = &cfg.vision_config;
2163            let grid_t = max_num_images / cfg.temporal_patch_size;
2164            let grid_h = max_image_shape.0 / cfg.patch_size;
2165            let grid_w = max_image_shape.1 / cfg.patch_size;
2166            grid_t * grid_h * grid_w
2167        };
2168
2169        let max_vision_attn = {
2170            let cfg = &cfg.vision_config;
2171            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
2172        };
2173
2174        Ok(max_vision_attn)
2175    }
2176
2177    fn non_mapped_size_in_bytes(
2178        &self,
2179        config: &str,
2180        dtype: DType,
2181        weight_pack_factor: usize,
2182        _matformer_config: Option<&MatformerSliceConfig>,
2183    ) -> Result<usize> {
2184        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2185        let text_elems = {
2186            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2187            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2188            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2189                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2190            } else {
2191                0
2192            };
2193            let norm = cfg.hidden_size;
2194            embed_tokens + lm_head + norm
2195        };
2196
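        // Patch merger: a layer norm plus a two-layer MLP over spatial_merge_size^2 merged
        // patch embeddings.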
2197        let patch_merger = {
2198            let cfg = &cfg.vision_config;
2199            let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);
2200
2201            let mlp0 = hidden_size * hidden_size + hidden_size;
2202            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
2203
2204            let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2205
2206            mlp0 + mlp2 + ln_q
2207        };
2208
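        // Conv3d patch embedding over (temporal, height, width) patches; bias is not counted.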
2209        let patch_embed = {
2210            let cfg = &cfg.vision_config;
2211            let conv_cfg = Conv3dConfig {
2212                stride: cfg.patch_size,
2213                ..Default::default()
2214            };
2215            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
2216            cfg.in_channels * cfg.embed_dim / conv_cfg.groups
2217                * kernel_sizes[0]
2218                * kernel_sizes[1]
2219                * kernel_sizes[2]
2220        };
2221
2222        let encoder_layer = {
2223            let cfg = &cfg.vision_config;
2224            let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2225            let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2226
2227            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2228            let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
2229            let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
2230            let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;
2231
2232            let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
2233            let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;
2234
2235            norm1 + norm2 + fc1 + fc2 + qkv + out
2236        };
2237
2238        let elems =
2239            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
2240
2241        Ok(elems * dtype.size_in_bytes())
2242    }
2243
2244    fn layer_sizes_in_bytes(
2245        &self,
2246        config: &str,
2247        dtype: DType,
2248        weight_pack_factor: usize,
2249        _matformer_config: Option<&MatformerSliceConfig>,
2250    ) -> Result<Vec<usize>> {
2251        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2252        let per_layer_elems = {
2253            let input_layernorm = cfg.hidden_size;
2254            let post_attention_layernorm = cfg.hidden_size;
2255
2256            let size_in = cfg.hidden_size;
2257            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2258            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
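            // Qwen2 attention carries biases on the q/k/v projections (but not o_proj),
            // hence the extra bias terms below.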
2259            let q_proj = size_in * size_q / weight_pack_factor + size_q;
2260            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
2261            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
2262            let o_proj = size_q * size_in / weight_pack_factor;
2263
2264            let h_size = cfg.hidden_size;
2265            let i_size = cfg.intermediate_size;
2266            let gate_proj = h_size * i_size / weight_pack_factor;
2267            let up_proj = h_size * i_size / weight_pack_factor;
2268            let down_proj = i_size * h_size / weight_pack_factor;
2269
2270            input_layernorm
2271                + post_attention_layernorm
2272                + q_proj
2273                + k_proj
2274                + v_proj
2275                + o_proj
2276                + gate_proj
2277                + up_proj
2278                + down_proj
2279        };
2280        Ok(vec![
2281            per_layer_elems * dtype.size_in_bytes();
2282            cfg.num_hidden_layers
2283        ])
2284    }
2285
2286    fn num_layers(&self, config: &str) -> Result<usize> {
2287        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2288        Ok(cfg.num_hidden_layers)
2289    }
2290
2291    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2292        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2293
2294        let cfg = ModelConfigMetadata {
2295            max_seq_len: cfg.max_position_embeddings,
2296            num_layers: cfg.num_hidden_layers,
2297            hidden_size: cfg.hidden_size,
2298            num_kv_heads: cfg.num_key_value_heads,
2299            num_attn_heads: cfg.num_attention_heads,
2300            sliding_window: cfg.sliding_window,
2301            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2302            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2303        };
2304
2305        Ok(Box::new(cfg))
2306    }
2307
2308    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2309        Some(vec![NonMappedSubModel::Vision])
2310    }
2311}
2312
2313// ======================== Idefics 3 loader
2314
2315/// [`VisionLoader`] for an Idefics 3 Vision model.
2316///
2317/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2318pub struct Idefics3Loader;
2319
2320pub struct Idefics3Prefixer;
2321
2322impl MultimodalPromptPrefixer for Idefics3Prefixer {
2323    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
2324        // Chat template does it
2325        prompt.to_string()
2326    }
2327}
2328
2329impl VisionModelLoader for Idefics3Loader {
2330    fn load(
2331        &self,
2332        config: &str,
2333        vb: ShardedVarBuilder,
2334        normal_loading_metadata: NormalLoadingMetadata,
2335        attention_mechanism: AttentionImplementation,
2336    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2337        let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2338        Ok(Box::new(Idefics3Model::new(
2339            &cfg,
2340            vb,
2341            self.is_gptx(config),
2342            normal_loading_metadata,
2343            attention_mechanism,
2344        )?))
2345    }
2346    fn is_gptx(&self, _config: &str) -> bool {
2347        true
2348    }
2349    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2350        let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2351        Ok(Box::new(cfg))
2352    }
2353    fn get_processor(
2354        &self,
2355        _model_config: &str,
2356        processor_config: Option<ProcessorConfig>,
2357        preprocessor_config: PreProcessorConfig,
2358        max_edge: Option<u32>,
2359    ) -> Arc<dyn Processor + Send + Sync> {
2360        Arc::new(Idefics3Processor::new(
2361            processor_config.unwrap_or_default(),
2362            preprocessor_config,
2363            max_edge,
2364        ))
2365    }
2366    fn supports_paged_attention(&self, _config: &str) -> bool {
2367        true
2368    }
2369    fn supports_prefix_cacher(&self, _config: &str) -> bool {
2370        true
2371    }
2372    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2373        Arc::new(Idefics3Prefixer)
2374    }
2375    fn modalities(&self, _config: &str) -> Result<Modalities> {
2376        Ok(Modalities {
2377            input: vec![SupportedModality::Text, SupportedModality::Vision],
2378            output: vec![SupportedModality::Text],
2379        })
2380    }
2381}
2382
2383impl IsqModelLoader for Idefics3Loader {
2384    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2385        Ok(vec![
2386            Regex::new(r"lm_head\.(weight|bias)$")?,
2387            // Attention
2388            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2389            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2390            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2391            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2392            // MLP
2393            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2394            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2395            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2396        ])
2397    }
2398    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
2399        Ok(vec![
2400            Regex::new(r"lm_head\.(weight|bias)$")?,
2401            // Attention
2402            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2403            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2404            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2405            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2406            // MLP
2407            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2408            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2409            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2410            // // Attention (vision)
2411            // Regex::new(
2412            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2413            // )?,
2414            // Regex::new(
2415            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
2416            // )?,
2417            // Regex::new(
2418            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
2419            // )?,
2420            // Regex::new(
2421            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.(weight|bias)$",
2422            // )?,
2423            // // MLP (vision)
2424            // Regex::new(r"model\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2425            // Regex::new(r"model\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
2426        ])
2427    }
2428}
2429
2430impl DeviceMappedModelLoader for Idefics3Loader {
2431    fn mapped_max_act_size_elems(
2432        &self,
2433        config: &str,
2434        params: &AutoDeviceMapParams,
2435        _prompt_chunksize: usize,
2436    ) -> Result<usize> {
2437        let AutoDeviceMapParams::Vision {
2438            max_seq_len,
2439            max_batch_size,
2440            max_image_shape: _,
2441            max_num_images,
2442        } = params
2443        else {
2444            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2445        };
2446
2447        let cfg: Idefics3Config = serde_json::from_str(config)?;
2448
2449        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2450        let img_seq_len = (num_patches + 1) * max_num_images;
2451
2452        let max_text_attn = {
2453            // This model injects the vision information directly into the input embeddings
2454            let max_seq_len = img_seq_len + max_seq_len;
2455            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2456        };
2457
2458        Ok(max_text_attn)
2459    }
2460
2461    fn non_mapped_max_act_size_elems(
2462        &self,
2463        config: &str,
2464        params: &AutoDeviceMapParams,
2465    ) -> Result<usize> {
2466        let AutoDeviceMapParams::Vision {
2467            max_seq_len: _,
2468            max_batch_size,
2469            max_image_shape: _,
2470            max_num_images,
2471        } = params
2472        else {
2473            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2474        };
2475
2476        let cfg: Idefics3Config = serde_json::from_str(config)?;
2477
2478        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2479        let img_seq_len = num_patches + 1;
2480
2481        let max_vision_attn = {
2482            // do_image_splitting = true
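            // (each image yields several crops plus the full-resolution global view, hence 5)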
2483            let images_factor = 5;
2484
2485            (max_batch_size * images_factor * max_num_images)
2486                * cfg.vision_config.num_attention_heads
2487                * img_seq_len
2488                * img_seq_len
2489        };
2490
2491        Ok(max_vision_attn)
2492    }
2493
2494    fn non_mapped_size_in_bytes(
2495        &self,
2496        config: &str,
2497        dtype: DType,
2498        weight_pack_factor: usize,
2499        _matformer_config: Option<&MatformerSliceConfig>,
2500    ) -> Result<usize> {
2501        let cfg: Idefics3Config = serde_json::from_str(config)?;
2502        let text_elems = {
2503            let cfg = &cfg.text_config;
2504
2505            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2506            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2507            let norm = cfg.hidden_size;
2508            embed_tokens + lm_head + norm
2509        };
2510
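        // Connector: a single bias-free linear projection from the pixel-shuffled vision
        // features (hidden_size * scale_factor^2) to the text hidden size.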
2511        let connector_elems = {
2512            let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
2513            let out_dim = cfg.text_config.hidden_size;
2514
2515            in_dim * out_dim
2516        };
2517
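        // Vision encoder: Conv2d patch embedding, learned position embedding, post layer norm,
        // and num_hidden_layers transformer blocks; these weights are not divided by
        // weight_pack_factor since they are not ISQ-packed.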
2518        let vision_transformer = {
2519            let cfg = &cfg.vision_config;
2520
2521            let post_layernorm = cfg.hidden_size;
2522
2523            let conv_config = Conv2dConfig {
2524                stride: cfg.patch_size,
2525                ..Default::default()
2526            };
2527            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2528                * cfg.patch_size
2529                * cfg.patch_size;
2530
2531            let num_patches_per_side = cfg.image_size / cfg.patch_size;
2532            let num_patches = num_patches_per_side.pow(2);
2533            let position_embedding = num_patches * cfg.hidden_size;
2534
2535            let layer_elems = {
2536                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2537                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2538
2539                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2540                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2541
2542                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2543                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2544                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2545                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2546
2547                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2548            };
2549
2550            post_layernorm
2551                + patch_embedding
2552                + position_embedding
2553                + layer_elems * cfg.num_hidden_layers
2554        };
2555
2556        let elems = text_elems + connector_elems + vision_transformer;
2557
2558        Ok(elems * dtype.size_in_bytes())
2559    }
2560
2561    fn layer_sizes_in_bytes(
2562        &self,
2563        config: &str,
2564        dtype: DType,
2565        weight_pack_factor: usize,
2566        _matformer_config: Option<&MatformerSliceConfig>,
2567    ) -> Result<Vec<usize>> {
2568        let cfg: Idefics3Config = serde_json::from_str(config)?;
2569        let cfg = cfg.text_config;
2570        let per_layer_elems = {
2571            let input_layernorm = cfg.hidden_size;
2572            let post_attention_layernorm = cfg.hidden_size;
2573
2574            let size_in = cfg.hidden_size;
2575            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2576            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2577            let q_proj = size_in * size_q / weight_pack_factor;
2578            let k_proj = size_in * size_kv / weight_pack_factor;
2579            let v_proj = size_in * size_kv / weight_pack_factor;
2580            let o_proj = size_q * size_in / weight_pack_factor;
2581
2582            let h_size = cfg.hidden_size;
2583            let i_size = cfg.intermediate_size;
2584            let gate_proj = h_size * i_size / weight_pack_factor;
2585            let up_proj = h_size * i_size / weight_pack_factor;
2586            let down_proj = i_size * h_size / weight_pack_factor;
2587
2588            input_layernorm
2589                + post_attention_layernorm
2590                + q_proj
2591                + k_proj
2592                + v_proj
2593                + o_proj
2594                + gate_proj
2595                + up_proj
2596                + down_proj
2597        };
2598        Ok(vec![
2599            per_layer_elems * dtype.size_in_bytes();
2600            cfg.num_hidden_layers
2601        ])
2602    }
2603
2604    fn num_layers(&self, config: &str) -> Result<usize> {
2605        let cfg: Idefics3Config = serde_json::from_str(config)?;
2606        Ok(cfg.text_config.num_hidden_layers)
2607    }
2608    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2609        let cfg: Idefics3Config = serde_json::from_str(config)?;
2610        let cfg = &cfg.text_config;
2611
2612        let cfg = ModelConfigMetadata {
2613            max_seq_len: cfg.max_position_embeddings,
2614            num_layers: cfg.num_hidden_layers,
2615            hidden_size: cfg.hidden_size,
2616            num_kv_heads: cfg.num_key_value_heads,
2617            num_attn_heads: cfg.num_attention_heads,
2618            sliding_window: None,
2619            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2620            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2621        };
2622
2623        Ok(Box::new(cfg))
2624    }
2625
2626    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2627        Some(vec![NonMappedSubModel::Vision])
2628    }
2629}
2630
2631// ======================== MiniCpm-O loader
2632
2633/// [`VisionLoader`] for a MiniCpm-O model.
2634///
2635/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2636pub struct MiniCpmOLoader;
2637
2638pub struct MiniCpmOPrefixer;
2639
2640impl MultimodalPromptPrefixer for MiniCpmOPrefixer {
2641    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2642        format!(
2643            "{}{prompt}",
2644            "(<image>./</image>)".repeat(image_indexes.len())
2645        )
2646    }
2647}
2648
2649impl VisionModelLoader for MiniCpmOLoader {
2650    fn load(
2651        &self,
2652        config: &str,
2653        vb: ShardedVarBuilder,
2654        normal_loading_metadata: NormalLoadingMetadata,
2655        attention_mechanism: AttentionImplementation,
2656    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2657        let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2658        Ok(Box::new(MiniCpmOModel::new(
2659            &cfg,
2660            vb,
2661            self.is_gptx(config),
2662            normal_loading_metadata,
2663            attention_mechanism,
2664        )?))
2665    }
2666    fn is_gptx(&self, _config: &str) -> bool {
2667        true
2668    }
2669    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2670        let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2671        Ok(Box::new(cfg))
2672    }
2673    fn get_processor(
2674        &self,
2675        _model_config: &str,
2676        processor_config: Option<ProcessorConfig>,
2677        preprocessor_config: PreProcessorConfig,
2678        max_edge: Option<u32>,
2679    ) -> Arc<dyn Processor + Send + Sync> {
2680        Arc::new(MiniCpmOProcessor::new(
2681            processor_config.unwrap_or_default(),
2682            preprocessor_config,
2683            max_edge,
2684        ))
2685    }
2686    fn supports_paged_attention(&self, _config: &str) -> bool {
2687        true
2688    }
2689    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2690        Arc::new(MiniCpmOPrefixer)
2691    }
2692    fn modalities(&self, _config: &str) -> Result<Modalities> {
2693        Ok(Modalities {
2694            input: vec![SupportedModality::Text, SupportedModality::Vision],
2695            output: vec![SupportedModality::Text],
2696        })
2697    }
2698}
2699
2700impl IsqModelLoader for MiniCpmOLoader {
2701    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
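        // MiniCPM-O nests its language model under the `llm.` prefix, so the ISQ regexes
        // target `llm.layers.*` rather than bare `layers.*`.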
2702        Ok(vec![
2703            Regex::new(r"llm.lm_head\.(weight|bias)$")?,
2704            // Attention
2705            Regex::new(r"llm.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2706            Regex::new(r"llm.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2707            Regex::new(r"llm.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2708            Regex::new(r"llm.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2709            // MLP
2710            Regex::new(r"llm.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2711            Regex::new(r"llm.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2712            Regex::new(r"llm.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2713        ])
2714    }
2715    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2716        self.isq_layer_regexes(config)
2717    }
2718}
2719
2720impl DeviceMappedModelLoader for MiniCpmOLoader {
2721    fn mapped_max_act_size_elems(
2722        &self,
2723        config: &str,
2724        params: &AutoDeviceMapParams,
2725        _prompt_chunksize: usize,
2726    ) -> Result<usize> {
2727        let AutoDeviceMapParams::Vision {
2728            max_seq_len,
2729            max_batch_size,
2730            max_image_shape: _,
2731            max_num_images,
2732        } = params
2733        else {
2734            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2735        };
2736
2737        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2738
2739        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2740        let img_seq_len = (num_patches + 1) * max_num_images;
2741
2742        let max_text_attn = {
2743            // This model injects the vision information directly into the input embeddings
2744            let max_seq_len = img_seq_len + max_seq_len;
2745            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2746        };
2747
2748        Ok(max_text_attn)
2749    }
2750
2751    fn non_mapped_max_act_size_elems(
2752        &self,
2753        config: &str,
2754        params: &AutoDeviceMapParams,
2755    ) -> Result<usize> {
2756        let AutoDeviceMapParams::Vision {
2757            max_seq_len: _,
2758            max_batch_size,
2759            max_image_shape: _,
2760            max_num_images,
2761        } = params
2762        else {
2763            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2764        };
2765
2766        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2767
2768        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2769        let img_seq_len = num_patches + 1;
2770
2771        let max_vision_attn = {
2772            // do_image_splitting = true
2773            let images_factor = 5;
2774
2775            (max_batch_size * images_factor * max_num_images)
2776                * cfg.vision_config.num_attention_heads
2777                * img_seq_len
2778                * img_seq_len
2779        };
2780
2781        Ok(max_vision_attn)
2782    }
2783
2784    fn non_mapped_size_in_bytes(
2785        &self,
2786        config: &str,
2787        dtype: DType,
2788        weight_pack_factor: usize,
2789        _matformer_config: Option<&MatformerSliceConfig>,
2790    ) -> Result<usize> {
2791        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2792        let text_elems = {
2793            let cfg = &cfg.text_config;
2794
2795            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2796            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2797            let norm = cfg.hidden_size;
2798            embed_tokens + lm_head + norm
2799        };
2800
2801        let vision_transformer = {
2802            let cfg = &cfg.vision_config;
2803
2804            let post_layernorm = cfg.hidden_size;
2805
2806            let conv_config = Conv2dConfig {
2807                stride: cfg.patch_size,
2808                ..Default::default()
2809            };
2810            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2811                * cfg.patch_size
2812                * cfg.patch_size;
2813
2814            let num_patches_per_side = cfg.image_size / cfg.patch_size;
2815            let num_patches = num_patches_per_side.pow(2);
2816            let position_embedding = num_patches * cfg.hidden_size;
2817
2818            let layer_elems = {
2819                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2820                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2821
2822                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2823                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2824
2825                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2826                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2827                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2828                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2829
2830                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2831            };
2832
2833            post_layernorm
2834                + patch_embedding
2835                + position_embedding
2836                + layer_elems * cfg.num_hidden_layers
2837        };
2838
2839        let elems = text_elems + vision_transformer;
2840
2841        Ok(elems * dtype.size_in_bytes())
2842    }
2843
2844    fn layer_sizes_in_bytes(
2845        &self,
2846        config: &str,
2847        dtype: DType,
2848        weight_pack_factor: usize,
2849        _matformer_config: Option<&MatformerSliceConfig>,
2850    ) -> Result<Vec<usize>> {
2851        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2852        let cfg = cfg.text_config;
2853        let per_layer_elems = {
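            // Per decoder layer: two norms, Q/K/V/O projections, and the gated MLP.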
2854            let input_layernorm = cfg.hidden_size;
2855            let post_attention_layernorm = cfg.hidden_size;
2856
2857            let size_in = cfg.hidden_size;
2858            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2859            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2860            let q_proj = size_in * size_q / weight_pack_factor;
2861            let k_proj = size_in * size_kv / weight_pack_factor;
2862            let v_proj = size_in * size_kv / weight_pack_factor;
2863            let o_proj = size_q * size_in / weight_pack_factor;
2864
2865            let h_size = cfg.hidden_size;
2866            let i_size = cfg.intermediate_size;
2867            let gate_proj = h_size * i_size / weight_pack_factor;
2868            let up_proj = h_size * i_size / weight_pack_factor;
2869            let down_proj = i_size * h_size / weight_pack_factor;
2870
2871            input_layernorm
2872                + post_attention_layernorm
2873                + q_proj
2874                + k_proj
2875                + v_proj
2876                + o_proj
2877                + gate_proj
2878                + up_proj
2879                + down_proj
2880        };
2881        Ok(vec![
2882            per_layer_elems * dtype.size_in_bytes();
2883            cfg.num_hidden_layers
2884        ])
2885    }
2886
2887    fn num_layers(&self, config: &str) -> Result<usize> {
2888        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2889        Ok(cfg.text_config.num_hidden_layers)
2890    }
2891    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2892        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2893        let cfg = &cfg.text_config;
2894
2895        let cfg = ModelConfigMetadata {
2896            max_seq_len: cfg.max_position_embeddings,
2897            num_layers: cfg.num_hidden_layers,
2898            hidden_size: cfg.hidden_size,
2899            num_kv_heads: cfg.num_key_value_heads,
2900            num_attn_heads: cfg.num_attention_heads,
2901            sliding_window: None,
2902            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2903            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2904        };
2905
2906        Ok(Box::new(cfg))
2907    }
2908}
2909
2910// ======================== Phi 4MM Loader
2911
2912/// [`VisionLoader`] for a Phi 4MM Vision model.
2913///
2914/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2915pub struct Phi4MMLoader;
2916
2917pub struct Phi4MMPrefixer;
2918
2919impl MultimodalPromptPrefixer for Phi4MMPrefixer {
2920    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2921        // Image indexing starts at 0.
2922
2923        format!(
2924            "{}{prompt}",
2925            image_indexes
2926                .into_iter()
2927                .map(|image_index| format!("<|image_{}|>", image_index + 1))
2928                .join("")
2929        )
2930    }
2931    fn prefix_audio(&self, audio_indexes: Vec<usize>, prompt: &str) -> String {
2932        // Audio indexing starts at 0.
2933
2934        format!(
2935            "{}{prompt}",
2936            audio_indexes
2937                .into_iter()
2938                .map(|audio_index| format!("<|audio_{}|>", audio_index + 1))
2939                .join("")
2940        )
2941    }
2942}
2943
2944impl VisionModelLoader for Phi4MMLoader {
2945    fn load(
2946        &self,
2947        config: &str,
2948        vb: ShardedVarBuilder,
2949        normal_loading_metadata: NormalLoadingMetadata,
2950        attention_mechanism: AttentionImplementation,
2951    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2952        let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
2953        Ok(Box::new(Phi4MMModel::new(
2954            &cfg,
2955            vb,
2956            self.is_gptx(config),
2957            normal_loading_metadata,
2958            attention_mechanism,
2959        )?))
2960    }
2961    fn is_gptx(&self, _config: &str) -> bool {
2962        true
2963    }
2964    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2965        let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
2966        Ok(Box::new(cfg))
2967    }
2968    fn get_processor(
2969        &self,
2970        _model_config: &str,
2971        processor_config: Option<ProcessorConfig>,
2972        preprocessor_config: PreProcessorConfig,
2973        _max_edge: Option<u32>,
2974    ) -> Arc<dyn Processor + Send + Sync> {
2975        Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
2976    }
2977    fn supports_paged_attention(&self, _config: &str) -> bool {
2978        true
2979    }
2980    fn supports_prefix_cacher(&self, _config: &str) -> bool {
2981        true
2982    }
2983    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2984        Arc::new(Phi4MMPrefixer)
2985    }
2986    fn modalities(&self, _config: &str) -> Result<Modalities> {
2987        Ok(Modalities {
2988            input: vec![
2989                SupportedModality::Text,
2990                SupportedModality::Vision,
2991                SupportedModality::Audio,
2992            ],
2993            output: vec![SupportedModality::Text],
2994        })
2995    }
2996}
2997
2998impl IsqModelLoader for Phi4MMLoader {
2999    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3000        Ok(vec![
3001            Regex::new(r"lm_head\.(weight|bias)$")?,
3002            // Attention
3003            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
3004            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3005            // MLP
3006            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
3007            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3008        ])
3009    }
3010    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3011        self.isq_layer_regexes(config)
3012    }
3013}
3014
3015impl DeviceMappedModelLoader for Phi4MMLoader {
3016    fn mapped_max_act_size_elems(
3017        &self,
3018        config: &str,
3019        params: &AutoDeviceMapParams,
3020        _prompt_chunksize: usize,
3021    ) -> Result<usize> {
3022        // NOTE: max_num_images is factored in below, although it can only be one...
3023        let AutoDeviceMapParams::Vision {
3024            max_seq_len,
3025            max_batch_size,
3026            max_image_shape: _,
3027            max_num_images,
3028        } = params
3029        else {
3030            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3031        };
3032
3033        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3034
3035        let vcfg = &PHI4_MM_VISION_CFG;
3036
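        // One token per patch plus one extra token per image, summed over all images.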
3037        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3038        let img_seq_len = (num_patches + 1) * max_num_images;
3039
3040        let max_text_attn = {
3041            // This model injects the vision information directly into the input embeddings
3042            let max_seq_len = img_seq_len + max_seq_len;
3043            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3044        };
3045
3046        Ok(max_text_attn)
3047    }
3048
3049    fn non_mapped_max_act_size_elems(
3050        &self,
3051        _config: &str,
3052        params: &AutoDeviceMapParams,
3053    ) -> Result<usize> {
3054        let AutoDeviceMapParams::Vision {
3055            max_seq_len: _,
3056            max_batch_size,
3057            max_image_shape,
3058            max_num_images,
3059        } = params
3060        else {
3061            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3062        };
3063
3064        let vcfg = &PHI4_MM_VISION_CFG;
3065
3066        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3067        let img_seq_len = num_patches + 1;
3068
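        // Dynamic HD cropping: each image contributes ceil(H/base) * ceil(W/base) crops
        // plus one global view, where base is DYHD_BASE_RESOLUTION.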
3069        let max_batch_size = max_batch_size
3070            * (max_image_shape
3071                .0
3072                .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3073                * max_image_shape
3074                    .1
3075                    .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3076                + 1);
3077
3078        let max_vision_attn = (max_batch_size * max_num_images)
3079            * vcfg.num_attention_heads
3080            * img_seq_len
3081            * img_seq_len;
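        // Q, K, and V activations: three tensors of batch * heads * seq_len * head_dim elements each.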
3082        let max_qkv = 3
3083            * (max_batch_size
3084                * vcfg.num_attention_heads
3085                * img_seq_len
3086                * (vcfg.hidden_size / vcfg.num_attention_heads));
3087
3088        Ok(max_vision_attn + max_qkv)
3089    }
3090
3091    fn non_mapped_size_in_bytes(
3092        &self,
3093        config: &str,
3094        dtype: DType,
3095        weight_pack_factor: usize,
3096        _matformer_config: Option<&MatformerSliceConfig>,
3097    ) -> Result<usize> {
3098        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3099        let elems = {
3100            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3101            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3102            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3103                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3104            } else {
3105                0
3106            };
3107            let norm = cfg.hidden_size;
3108
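            // Image embedding projector: parameter count depends on projection_cls and whether the HD transform is used.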
3109            let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
3110                let projection_cls = img_embed
3111                    .projection_cls
3112                    .clone()
3113                    .unwrap_or("linear".to_string());
3114                let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
3115                let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
3116                let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;
3117
3118                let proj = match (projection_cls.as_str(), use_hd_transform) {
3119                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
3120                    ("mlp", true) => {
3121                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
3122                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3123                        a + b
3124                    }
3125                    ("mlp", false) => {
3126                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
3127                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3128                        a + b
3129                    }
3130                    _ => {
3131                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
3132                    }
3133                };
3134
3135                let (glb_gn, sub_gn) = if with_learnable_separator {
3136                    let glb_gn = image_dim_out * 4;
3137                    let sub_gn = image_dim_out * 4;
3138                    (glb_gn, sub_gn)
3139                } else {
3140                    (0, 0)
3141                };
3142
3143                let vision_transformer = {
3144                    let cfg = &PHI4_MM_VISION_CFG;
3145
3146                    let post_layernorm = cfg.hidden_size;
3147
3148                    let conv_config = Conv2dConfig {
3149                        stride: cfg.patch_size,
3150                        ..Default::default()
3151                    };
3152                    let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3153                        * cfg.patch_size
3154                        * cfg.patch_size;
3155
3156                    let num_patches_per_side = cfg.image_size / cfg.patch_size;
3157                    let num_patches = num_patches_per_side.pow(2);
3158                    let position_embedding = num_patches * cfg.hidden_size;
3159
3160                    let layer_elems = {
3161                        let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3162                        let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3163
3164                        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3165                        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3166
3167                        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3168                        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3169                        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3170                        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3171
3172                        layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3173                    };
3174
3175                    post_layernorm
3176                        + patch_embedding
3177                        + position_embedding
3178                        + layer_elems * cfg.num_hidden_layers
3179                };
3180
3181                proj + glb_gn + sub_gn + vision_transformer
3182            } else {
3183                0
3184            };
3185
3186            embed_tokens + lm_head + norm + image_embed
3187        };
3188
3189        Ok(elems * dtype.size_in_bytes())
3190    }
3191
3192    fn layer_sizes_in_bytes(
3193        &self,
3194        config: &str,
3195        dtype: DType,
3196        weight_pack_factor: usize,
3197        _matformer_config: Option<&MatformerSliceConfig>,
3198    ) -> Result<Vec<usize>> {
3199        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3200        let per_layer_elems = {
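            // Per decoder layer: two layernorms, fused QKV and gate_up projections, and the O/down projections.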
3201            let input_layernorm = cfg.hidden_size;
3202            let post_attention_layernorm = cfg.hidden_size;
3203
3204            let size_in = cfg.hidden_size;
3205            let head_dim = cfg.head_dim();
3206            let op_size =
3207                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
3208            let qkv_proj = size_in * op_size / weight_pack_factor;
3209            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;
3210
3211            let h_size = cfg.hidden_size;
3212            let i_size = cfg.intermediate_size;
3213            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
3214            let down_proj = h_size * i_size / weight_pack_factor;
3215
3216            input_layernorm
3217                + post_attention_layernorm
3218                + qkv_proj
3219                + o_proj
3220                + gate_up_proj
3221                + down_proj
3222        };
3223        Ok(vec![
3224            per_layer_elems * dtype.size_in_bytes();
3225            cfg.num_hidden_layers
3226        ])
3227    }
3228
3229    fn num_layers(&self, config: &str) -> Result<usize> {
3230        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3231        Ok(cfg.num_hidden_layers)
3232    }
3233
3234    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3235        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3236
3237        let cfg = ModelConfigMetadata {
3238            max_seq_len: cfg.max_position_embeddings,
3239            num_layers: cfg.num_hidden_layers,
3240            hidden_size: cfg.hidden_size,
3241            num_kv_heads: cfg.num_key_value_heads(),
3242            num_attn_heads: cfg.num_attention_heads,
3243            sliding_window: cfg.sliding_window,
3244            k_head_dim: cfg.head_dim(),
3245            v_head_dim: cfg.head_dim(),
3246        };
3247
3248        Ok(Box::new(cfg))
3249    }
3250
3251    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3252        Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
3253    }
3254}
3255
3256// ======================== Qwen2_5VL Loader
3257
3258/// [`VisionLoader`] for a Qwen2_5VL model.
3259///
3260/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
3261pub struct Qwen2_5VLLoader;
3262
3263pub struct Qwen2_5VLPrefixer;
3264
3265impl MultimodalPromptPrefixer for Qwen2_5VLPrefixer {
3266    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
3267        format!(
3268            "{}{prompt}",
3269            format!(
3270                "{}{}{}",
3271                Qwen2_5VLProcessor::VISION_START,
3272                Qwen2_5VLProcessor::IMAGE_PAD,
3273                Qwen2_5VLProcessor::VISION_END
3274            )
3275            .repeat(image_indexes.len())
3276        )
3277    }
3278}
3279
3280impl VisionModelLoader for Qwen2_5VLLoader {
3281    fn load(
3282        &self,
3283        config: &str,
3284        vb: ShardedVarBuilder,
3285        normal_loading_metadata: NormalLoadingMetadata,
3286        attention_mechanism: AttentionImplementation,
3287    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3288        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3289        Ok(Box::new(Qwen2_5VLModel::new(
3290            &cfg,
3291            vb,
3292            self.is_gptx(config),
3293            normal_loading_metadata,
3294            attention_mechanism,
3295        )?))
3296    }
3297    fn is_gptx(&self, _config: &str) -> bool {
3298        true
3299    }
3300    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3301        let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
3302        Ok(Box::new(config))
3303    }
3304    fn get_processor(
3305        &self,
3306        _model_config: &str,
3307        _processor_config: Option<ProcessorConfig>,
3308        _preprocessor_config: PreProcessorConfig,
3309        max_edge: Option<u32>,
3310    ) -> Arc<dyn Processor + Send + Sync> {
3311        Arc::new(Qwen2_5VLProcessor::new(max_edge))
3312    }
3313    fn supports_paged_attention(&self, _config: &str) -> bool {
3314        false
3315    }
3316    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3317        Arc::new(Qwen2_5VLPrefixer)
3318    }
3319    fn modalities(&self, _config: &str) -> Result<Modalities> {
3320        Ok(Modalities {
3321            input: vec![SupportedModality::Text, SupportedModality::Vision],
3322            output: vec![SupportedModality::Text],
3323        })
3324    }
3325}
3326
3327impl IsqModelLoader for Qwen2_5VLLoader {
3328    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3329        Ok(vec![
3330            Regex::new(r"lm_head\.(weight|bias)$")?,
3331            // Attention
3332            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3333            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3334            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3335            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3336            // MLP
3337            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3338            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3339            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3340        ])
3341    }
3342    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3343        self.isq_layer_regexes(config)
3344    }
3345}
3346
3347impl DeviceMappedModelLoader for Qwen2_5VLLoader {
3348    fn mapped_max_act_size_elems(
3349        &self,
3350        config: &str,
3351        params: &AutoDeviceMapParams,
3352        _prompt_chunksize: usize,
3353    ) -> Result<usize> {
3354        let AutoDeviceMapParams::Vision {
3355            max_seq_len,
3356            max_batch_size,
3357            max_image_shape,
3358            max_num_images,
3359        } = params
3360        else {
3361            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3362        };
3363
3364        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3365
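        // Vision token count follows the temporal/height/width patch grid of each image.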
3366        let img_seq_len = {
3367            let cfg = &cfg.vision_config;
3368            let grid_t = max_num_images / cfg.temporal_patch_size;
3369            let grid_h = max_image_shape.0 / cfg.patch_size;
3370            let grid_w = max_image_shape.1 / cfg.patch_size;
3371            grid_t * grid_h * grid_w
3372        };
3373        let img_seq_len = img_seq_len * max_num_images;
3374
3375        let max_text_attn = {
3376            // This model injects the vision information directly into the input embeddings
3377            let max_seq_len = img_seq_len + max_seq_len;
3378            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3379        };
3380
3381        Ok(max_text_attn)
3382    }
3383
3384    fn non_mapped_max_act_size_elems(
3385        &self,
3386        config: &str,
3387        params: &AutoDeviceMapParams,
3388    ) -> Result<usize> {
3389        let AutoDeviceMapParams::Vision {
3390            max_seq_len: _,
3391            max_batch_size,
3392            max_image_shape,
3393            max_num_images,
3394        } = params
3395        else {
3396            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3397        };
3398
3399        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3400
3401        let img_seq_len = {
3402            let cfg = &cfg.vision_config;
3403            let grid_t = max_num_images / cfg.temporal_patch_size;
3404            let grid_h = max_image_shape.0 / cfg.patch_size;
3405            let grid_w = max_image_shape.1 / cfg.patch_size;
3406            grid_t * grid_h * grid_w
3407        };
3408
3409        let max_vision_attn = {
3410            let cfg = &cfg.vision_config;
3411            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
3412        };
3413
3414        Ok(max_vision_attn)
3415    }
3416
3417    fn non_mapped_size_in_bytes(
3418        &self,
3419        config: &str,
3420        dtype: DType,
3421        weight_pack_factor: usize,
3422        _matformer_config: Option<&MatformerSliceConfig>,
3423    ) -> Result<usize> {
3424        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3425        let text_elems = {
3426            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3427            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3428            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3429                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3430            } else {
3431                0
3432            };
3433            let norm = cfg.hidden_size;
3434            embed_tokens + lm_head + norm
3435        };
3436
3437        let patch_merger = {
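            // Patch merger: pre-norm plus a two-layer MLP over spatial_merge_size^2 concatenated patch embeddings.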
3438            let cfg = &cfg.vision_config;
3439            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
3440
3441            let mlp0 = hidden_size * hidden_size + hidden_size;
3442            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
3443
3444            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3445
3446            mlp0 + mlp2 + ln_q
3447        };
3448
3449        let patch_embed = {
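            // Conv3d patch embedding with a (temporal_patch_size, patch_size, patch_size) kernel.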
3450            let cfg = &cfg.vision_config;
3451            let conv_cfg = Conv3dConfig {
3452                stride: cfg.patch_size,
3453                ..Default::default()
3454            };
3455            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
3456            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
3457                * kernel_sizes[0]
3458                * kernel_sizes[1]
3459                * kernel_sizes[2]
3460        };
3461
3462        let encoder_layer = {
3463            let cfg = &cfg.vision_config;
3464            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3465            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3466
3467            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
3468            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3469            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
3470
3471            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
3472            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3473
3474            norm1 + norm2 + fc1 + fc2 + qkv + out
3475        };
3476
3477        let elems =
3478            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
3479
3480        Ok(elems * dtype.size_in_bytes())
3481    }
3482
3483    fn layer_sizes_in_bytes(
3484        &self,
3485        config: &str,
3486        dtype: DType,
3487        weight_pack_factor: usize,
3488        _matformer_config: Option<&MatformerSliceConfig>,
3489    ) -> Result<Vec<usize>> {
3490        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3491        let per_layer_elems = {
3492            let input_layernorm = cfg.hidden_size;
3493            let post_attention_layernorm = cfg.hidden_size;
3494
3495            let size_in = cfg.hidden_size;
3496            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3497            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3498            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3499            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3500            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3501            let o_proj = size_q * size_in / weight_pack_factor;
3502
3503            let h_size = cfg.hidden_size;
3504            let i_size = cfg.intermediate_size;
3505            let gate_proj = h_size * i_size / weight_pack_factor;
3506            let up_proj = h_size * i_size / weight_pack_factor;
3507            let down_proj = i_size * h_size / weight_pack_factor;
3508
3509            input_layernorm
3510                + post_attention_layernorm
3511                + q_proj
3512                + k_proj
3513                + v_proj
3514                + o_proj
3515                + gate_proj
3516                + up_proj
3517                + down_proj
3518        };
3519        Ok(vec![
3520            per_layer_elems * dtype.size_in_bytes();
3521            cfg.num_hidden_layers
3522        ])
3523    }
3524
3525    fn num_layers(&self, config: &str) -> Result<usize> {
3526        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3527        Ok(cfg.num_hidden_layers)
3528    }
3529
3530    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3531        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3532
3533        let cfg = ModelConfigMetadata {
3534            max_seq_len: cfg.max_position_embeddings,
3535            num_layers: cfg.num_hidden_layers,
3536            hidden_size: cfg.hidden_size,
3537            num_kv_heads: cfg.num_key_value_heads,
3538            num_attn_heads: cfg.num_attention_heads,
3539            sliding_window: cfg.sliding_window,
3540            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3541            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3542        };
3543
3544        Ok(Box::new(cfg))
3545    }
3546
3547    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3548        Some(vec![NonMappedSubModel::Vision])
3549    }
3550}
3551
3552// ======================== Gemma 3 Loader
3553
3554/// [`VisionLoader`] for a Gemma 3 model.
3555///
3556/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
3557pub struct Gemma3Loader;
3558
3559pub struct Gemma3Prefixer;
3560
3561impl MultimodalPromptPrefixer for Gemma3Prefixer {
3562    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
3563        prompt.to_string()
3564    }
3565}
3566
3567impl VisionModelLoader for Gemma3Loader {
3568    fn load(
3569        &self,
3570        config: &str,
3571        vb: ShardedVarBuilder,
3572        normal_loading_metadata: NormalLoadingMetadata,
3573        attention_mechanism: AttentionImplementation,
3574    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3575        let cfg: Gemma3Config = serde_json::from_str(config)?;
3576        Ok(Box::new(Gemma3Model::new(
3577            &cfg,
3578            vb,
3579            self.is_gptx(config),
3580            normal_loading_metadata,
3581            attention_mechanism,
3582        )?))
3583    }
3584    fn is_gptx(&self, _config: &str) -> bool {
3585        true
3586    }
3587    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3588        let config: Gemma3Config = serde_json::from_str(config)?;
3589        Ok(Box::new(config))
3590    }
3591    fn get_processor(
3592        &self,
3593        config: &str,
3594        processor_config: Option<ProcessorConfig>,
3595        _preprocessor_config: PreProcessorConfig,
3596        _max_edge: Option<u32>,
3597    ) -> Arc<dyn Processor + Send + Sync> {
3598        let config: Gemma3Config = serde_json::from_str(config).unwrap();
3599        // Handle the Gemma 3 1b case here
3600        Arc::new(Gemma3Processor::new(
3601            processor_config.unwrap_or_default(),
3602            matches!(config, Gemma3Config::WithVision { .. }),
3603        ))
3604    }
3605    fn supports_paged_attention(&self, _config: &str) -> bool {
3606        true
3607    }
3608    fn supports_prefix_cacher(&self, _config: &str) -> bool {
3609        true
3610    }
3611    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3612        Arc::new(Gemma3Prefixer)
3613    }
3614    fn modalities(&self, _config: &str) -> Result<Modalities> {
3615        Ok(Modalities {
3616            input: vec![SupportedModality::Text, SupportedModality::Vision],
3617            output: vec![SupportedModality::Text],
3618        })
3619    }
3620}
3621
3622impl IsqModelLoader for Gemma3Loader {
3623    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3624        Ok(vec![
3625            Regex::new(r"lm_head\.(weight|bias)$")?,
3626            // Attention
3627            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3628            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3629            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3630            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3631            // MLP
3632            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3633            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3634            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3635        ])
3636    }
3637    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
3638        Ok(vec![
3639            Regex::new(r"lm_head\.(weight|bias)$")?,
3640            // Attention
3641            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3642            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3643            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3644            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3645            // MLP
3646            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3647            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3648            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3649        ])
3650    }
3651}
3652
3653impl DeviceMappedModelLoader for Gemma3Loader {
3654    fn mapped_max_act_size_elems(
3655        &self,
3656        config: &str,
3657        params: &AutoDeviceMapParams,
3658        prompt_chunksize: usize,
3659    ) -> Result<usize> {
3660        let AutoDeviceMapParams::Vision {
3661            max_seq_len,
3662            max_batch_size,
3663            max_image_shape: _,
3664            max_num_images,
3665        } = params
3666        else {
3667            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3668        };
3669
3670        let cfg: Gemma3Config = serde_json::from_str(config)?;
3671
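        // Text-only configs size the attention by prompt chunk; vision configs add the image tokens to the sequence length.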
3672        match cfg {
3673            Gemma3Config::Text(text_config) => Ok(max_batch_size
3674                * text_config.num_attention_heads
3675                * prompt_chunksize
3676                * prompt_chunksize),
3677            Gemma3Config::WithVision {
3678                text_config,
3679                vision_config,
3680                ..
3681            } => {
3682                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3683                let img_seq_len = (num_patches + 1) * max_num_images;
3684
3685                let max_text_attn = {
3686                    // This model injects the vision information directly into the input embeddings
3687                    let max_seq_len = img_seq_len + *max_seq_len;
3688                    max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
3689                };
3690                Ok(max_text_attn)
3691            }
3692        }
3693    }
3694
3695    fn non_mapped_max_act_size_elems(
3696        &self,
3697        config: &str,
3698        params: &AutoDeviceMapParams,
3699    ) -> Result<usize> {
3700        let AutoDeviceMapParams::Vision {
3701            max_seq_len: _,
3702            max_batch_size,
3703            max_image_shape: _,
3704            max_num_images,
3705        } = params
3706        else {
3707            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3708        };
3709
3710        let cfg: Gemma3Config = serde_json::from_str(config)?;
3711
3712        match cfg {
3713            Gemma3Config::WithVision { vision_config, .. } => {
3714                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3715                let img_seq_len = num_patches + 1;
3716
3717                let max_vision_attn = {
3718                    (max_batch_size * max_num_images)
3719                        * vision_config.num_attention_heads
3720                        * img_seq_len
3721                        * img_seq_len
3722                };
3723
3724                Ok(max_vision_attn)
3725            }
3726            Gemma3Config::Text(_) => Ok(0),
3727        }
3728    }
3729
3730    fn non_mapped_size_in_bytes(
3731        &self,
3732        config: &str,
3733        dtype: DType,
3734        weight_pack_factor: usize,
3735        _matformer_config: Option<&MatformerSliceConfig>,
3736    ) -> Result<usize> {
3737        let cfg: Gemma3Config = serde_json::from_str(config)?;
3738
3739        let text_elems = {
3740            let cfg = match &cfg {
3741                Gemma3Config::Text(cfg) => cfg,
3742                Gemma3Config::WithVision { text_config, .. } => text_config,
3743            };
3744            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3745            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3746            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3747                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3748            } else {
3749                0
3750            };
3751            let norm = cfg.hidden_size;
3752            embed_tokens + lm_head + norm
3753        };
3754
3755        let vision_transformer = if let Gemma3Config::WithVision {
3756            vision_config: cfg, ..
3757        } = &cfg
3758        {
3759            let post_layernorm = cfg.hidden_size;
3760
3761            let conv_config = Conv2dConfig {
3762                stride: cfg.patch_size,
3763                ..Default::default()
3764            };
3765            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3766                * cfg.patch_size
3767                * cfg.patch_size;
3768
3769            let num_patches_per_side = cfg.image_size / cfg.patch_size;
3770            let num_patches = num_patches_per_side.pow(2);
3771            let position_embedding = num_patches * cfg.hidden_size;
3772
3773            let layer_elems = {
3774                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3775                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3776
3777                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3778                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3779
3780                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3781                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3782                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3783                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3784
3785                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3786            };
3787
3788            post_layernorm
3789                + patch_embedding
3790                + position_embedding
3791                + layer_elems * cfg.num_hidden_layers
3792        } else {
3793            0
3794        };
3795
3796        let elems = text_elems + vision_transformer;
3797
3798        Ok(elems * dtype.size_in_bytes())
3799    }
3800
3801    fn layer_sizes_in_bytes(
3802        &self,
3803        config: &str,
3804        dtype: DType,
3805        weight_pack_factor: usize,
3806        _matformer_config: Option<&MatformerSliceConfig>,
3807    ) -> Result<Vec<usize>> {
3808        let cfg: Gemma3Config = serde_json::from_str(config)?;
3809
3810        let txt_cfg = match &cfg {
3811            Gemma3Config::Text(cfg) => cfg,
3812            Gemma3Config::WithVision { text_config, .. } => text_config,
3813        };
3814        let per_layer_elems = {
3815            let cfg = txt_cfg;
3816
3817            let input_layernorm = cfg.hidden_size;
3818            let post_attention_layernorm = cfg.hidden_size;
3819
3820            let size_in = cfg.hidden_size;
3821            let size_q = cfg.head_dim * cfg.num_attention_heads;
3822            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
3823            let q_proj =
3824                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
3825            let k_proj =
3826                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3827            let v_proj =
3828                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3829            let o_proj =
3830                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
3831
3832            let h_size = cfg.hidden_size;
3833            let i_size = cfg.intermediate_size;
3834            let gate_proj = h_size * i_size / weight_pack_factor;
3835            let up_proj = h_size * i_size / weight_pack_factor;
3836            let down_proj = i_size * h_size / weight_pack_factor;
3837
3838            input_layernorm
3839                + post_attention_layernorm
3840                + q_proj
3841                + k_proj
3842                + v_proj
3843                + o_proj
3844                + gate_proj
3845                + up_proj
3846                + down_proj
3847        };
3848        Ok(vec![
3849            per_layer_elems * dtype.size_in_bytes();
3850            txt_cfg.num_hidden_layers
3851        ])
3852    }
3853
3854    fn num_layers(&self, config: &str) -> Result<usize> {
3855        let cfg: Gemma3Config = serde_json::from_str(config)?;
3856
3857        let txt_cfg = match &cfg {
3858            Gemma3Config::Text(cfg) => cfg,
3859            Gemma3Config::WithVision { text_config, .. } => text_config,
3860        };
3861
3862        Ok(txt_cfg.num_hidden_layers)
3863    }
3864
3865    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3866        let cfg: Gemma3Config = serde_json::from_str(config)?;
3867
3868        let cfg = match &cfg {
3869            Gemma3Config::Text(cfg) => cfg,
3870            Gemma3Config::WithVision { text_config, .. } => text_config,
3871        };
3872
3873        let cfg = ModelConfigMetadata {
3874            max_seq_len: cfg.max_position_embeddings,
3875            num_layers: cfg.num_hidden_layers,
3876            hidden_size: cfg.hidden_size,
3877            num_kv_heads: cfg.num_key_value_heads,
3878            num_attn_heads: cfg.num_attention_heads,
3879            sliding_window: None, // None to be more forgiving; some configs do not set it
3880            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3881            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3882        };
3883
3884        Ok(Box::new(cfg))
3885    }
3886
3887    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3888        Some(vec![NonMappedSubModel::Vision])
3889    }
3890}
3891
3892// ======================== Mistral 3 Loader
3893
3894/// [`VisionLoader`] for a Mistral 3 model.
3895///
3896/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
3897pub struct Mistral3Loader;
3898
3899pub struct Mistral3Prefixer;
3900
3901impl MultimodalPromptPrefixer for Mistral3Prefixer {
3902    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
3903        prompt.to_string()
3904    }
3905}
3906
3907impl VisionModelLoader for Mistral3Loader {
3908    fn load(
3909        &self,
3910        config: &str,
3911        vb: ShardedVarBuilder,
3912        normal_loading_metadata: NormalLoadingMetadata,
3913        attention_mechanism: AttentionImplementation,
3914    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3915        let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
3916        Ok(Box::new(Mistral3Model::new(
3917            &cfg,
3918            vb,
3919            self.is_gptx(config),
3920            normal_loading_metadata,
3921            attention_mechanism,
3922        )?))
3923    }
3924    fn is_gptx(&self, _config: &str) -> bool {
3925        true
3926    }
3927    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3928        let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
3929        Ok(Box::new(cfg))
3930    }
3931    fn get_processor(
3932        &self,
3933        _model_config: &str,
3934        processor_config: Option<ProcessorConfig>,
3935        _preprocessor_config: PreProcessorConfig,
3936        _max_edge: Option<u32>,
3937    ) -> Arc<dyn Processor + Send + Sync> {
3938        Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
3939    }
3940    fn supports_paged_attention(&self, _config: &str) -> bool {
3941        true
3942    }
3943    fn supports_prefix_cacher(&self, _config: &str) -> bool {
3944        true
3945    }
3946    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3947        Arc::new(Mistral3Prefixer)
3948    }
3949    fn modalities(&self, _config: &str) -> Result<Modalities> {
3950        Ok(Modalities {
3951            input: vec![SupportedModality::Text, SupportedModality::Vision],
3952            output: vec![SupportedModality::Text],
3953        })
3954    }
3955}
3956
3957impl IsqModelLoader for Mistral3Loader {
3958    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3959        Ok(vec![
3960            Regex::new(r"lm_head\.(weight|bias)$")?,
3961            // Attention
3962            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3963            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3964            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3965            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3966            // MLP
3967            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3968            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3969            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3970        ])
3971    }
3972    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
3973        Ok(vec![
3974            Regex::new(r"lm_head\.(weight|bias)$")?,
3975            // Attention
3976            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3977            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3978            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3979            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3980            // MLP
3981            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3982            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3983            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3984        ])
3985    }
3986}
3987
3988#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
3989impl DeviceMappedModelLoader for Mistral3Loader {
3990    fn mapped_max_act_size_elems(
3991        &self,
3992        config: &str,
3993        params: &AutoDeviceMapParams,
3994        _prompt_chunksize: usize,
3995    ) -> Result<usize> {
3996        let cfg: Mistral3Config = serde_json::from_str(config)?;
3997        let vcfg = &cfg.vision_config;
3998        let tcfg = &cfg.text_config;
3999
4000        let AutoDeviceMapParams::Vision {
4001            max_seq_len,
4002            max_batch_size,
4003            max_image_shape: (mut height, mut width),
4004            max_num_images,
4005        } = params
4006        else {
4007            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4008        };
4009
4010        let img_seq_len = {
4011            // Reshaping algorithm
4012
4013            // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/preprocessor_config.json#L29
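            // Scale so the longest side fits within 1540 px, snap the dimensions up to the
            // patch grid, then count one token per patch plus one extra token per patch row.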
4014            let (max_height, max_width) = (1540, 1540);
4015            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4016            if ratio > 1. {
4017                height = (height as f64 / ratio).floor() as usize;
4018                width = (width as f64 / ratio).floor() as usize;
4019            }
4020
4021            let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
4022            let num_width_tokens = (width - 1) / vcfg.patch_size + 1;
4023
4024            height = num_height_tokens * vcfg.patch_size;
4025            width = num_width_tokens * vcfg.patch_size;
4026
4027            let num_height_tokens = height / vcfg.patch_size;
4028            let num_width_tokens = width / vcfg.patch_size;
4029
4030            (num_width_tokens + 1) * num_height_tokens
4031        };
4032
4033        // This model injects the vision information directly into the input embeddings
4034        let max_seq_len = img_seq_len * max_num_images + *max_seq_len;
4035        Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
4036    }
4037
4038    fn non_mapped_max_act_size_elems(
4039        &self,
4040        config: &str,
4041        params: &AutoDeviceMapParams,
4042    ) -> Result<usize> {
4043        let cfg: Mistral3Config = serde_json::from_str(config)?;
4044        let cfg = &cfg.vision_config;
4045
4046        let AutoDeviceMapParams::Vision {
4047            max_seq_len: _,
4048            max_batch_size,
4049            max_image_shape: (mut height, mut width),
4050            max_num_images,
4051        } = params
4052        else {
4053            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4054        };
4055
4056        let img_seq_len = {
4057            // Reshaping algorithm
4058
4059            // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/preprocessor_config.json#L29
4060            let (max_height, max_width) = (1540, 1540);
4061            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4062            if ratio > 1. {
4063                height = (height as f64 / ratio).floor() as usize;
4064                width = (width as f64 / ratio).floor() as usize;
4065            }
4066
4067            let num_height_tokens = (height - 1) / cfg.patch_size + 1;
4068            let num_width_tokens = (width - 1) / cfg.patch_size + 1;
4069
4070            height = num_height_tokens * cfg.patch_size;
4071            width = num_width_tokens * cfg.patch_size;
4072
4073            let num_height_tokens = height / cfg.patch_size;
4074            let num_width_tokens = width / cfg.patch_size;
4075
4076            (num_width_tokens + 1) * num_height_tokens
4077        };
4078
4079        Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
4080    }
4081
4082    fn non_mapped_size_in_bytes(
4083        &self,
4084        config: &str,
4085        dtype: DType,
4086        weight_pack_factor: usize,
4087        _matformer_config: Option<&MatformerSliceConfig>,
4088    ) -> Result<usize> {
4089        let cfg: Mistral3Config = serde_json::from_str(config)?;
4090
4091        let text_elems = {
4092            let cfg = &cfg.text_config;
4093
4094            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4095            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
4096            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4097                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4098            } else {
4099                0
4100            };
4101            let norm = cfg.hidden_size;
4102            embed_tokens + lm_head + norm
4103        };
4104
4105        let vision_elems = {
4106            let cfg = &cfg.vision_config;
4107
4108            let patch_embed = {
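                // Conv2d patch embedding; weight elements assumed to be num_channels * hidden_size * patch_size^2.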
4109                let conv_cfg = Conv2dConfig {
4110                    stride: cfg.patch_size,
4111                    ..Default::default()
4112                };
4113                cfg.num_channels * cfg.hidden_size / conv_cfg.groups
4114                    * cfg.patch_size
4115                    * cfg.patch_size
4117            };
4118            let ln_pre = cfg.hidden_size;
4119            let vision_layer = {
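                // Per vision layer: attention and FFN norms, gated MLP, and Q/K/V/O projections (no biases).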
4120                let attn_norm = cfg.hidden_size;
4121                let ffn_norm = cfg.hidden_size;
4122
4123                let gate = cfg.hidden_size * cfg.intermediate_size;
4124                let up = cfg.hidden_size * cfg.intermediate_size;
4125                let down = cfg.hidden_size * cfg.intermediate_size;
4126
4127                let q = cfg.hidden_size * cfg.hidden_size;
4128                let k = cfg.hidden_size * cfg.hidden_size;
4129                let v = cfg.hidden_size * cfg.hidden_size;
4130                let o = cfg.hidden_size * cfg.hidden_size;
4131
4132                attn_norm + ffn_norm + gate + up + down + q + k + v + o
4133            };
4134
4135            patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
4136        };
4137
4138        let elems = text_elems + vision_elems;
4139
4140        Ok(elems * dtype.size_in_bytes())
4141    }
4142
4143    fn layer_sizes_in_bytes(
4144        &self,
4145        config: &str,
4146        dtype: DType,
4147        weight_pack_factor: usize,
4148        _matformer_config: Option<&MatformerSliceConfig>,
4149    ) -> Result<Vec<usize>> {
4150        let cfg: Mistral3Config = serde_json::from_str(config)?;
4151        let cfg = &cfg.text_config;
4152
4153        let per_layer_elems = {
4154            let input_layernorm = cfg.hidden_size;
4155            let post_attention_layernorm = cfg.hidden_size;
4156
4157            let size_in = cfg.hidden_size;
4158            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
4159            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
4160            let q_proj = size_in * size_q / weight_pack_factor;
4161            let k_proj = size_in * size_kv / weight_pack_factor;
4162            let v_proj = size_in * size_kv / weight_pack_factor;
4163            let o_proj = size_q * size_in / weight_pack_factor;
4164
4165            let h_size = cfg.hidden_size;
4166            let i_size = cfg.intermediate_size;
4167            let gate_proj = h_size * i_size / weight_pack_factor;
4168            let up_proj = h_size * i_size / weight_pack_factor;
4169            let down_proj = i_size * h_size / weight_pack_factor;
4170
4171            input_layernorm
4172                + post_attention_layernorm
4173                + q_proj
4174                + k_proj
4175                + v_proj
4176                + o_proj
4177                + gate_proj
4178                + up_proj
4179                + down_proj
4180        };
4181        Ok(vec![
4182            per_layer_elems * dtype.size_in_bytes();
4183            cfg.num_hidden_layers
4184        ])
4185    }
4186
4187    fn num_layers(&self, config: &str) -> Result<usize> {
4188        let cfg: Mistral3Config = serde_json::from_str(config)?;
4189        let cfg = &cfg.text_config;
4190        Ok(cfg.num_hidden_layers)
4191    }
4192
4193    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4194        let cfg: Mistral3Config = serde_json::from_str(config)?;
4195        let cfg = &cfg.text_config;
4196
4197        let cfg = ModelConfigMetadata {
4198            max_seq_len: cfg.max_position_embeddings,
4199            num_layers: cfg.num_hidden_layers,
4200            hidden_size: cfg.hidden_size,
4201            num_kv_heads: cfg.num_key_value_heads,
4202            num_attn_heads: cfg.num_attention_heads,
4203            sliding_window: cfg.sliding_window,
4204            k_head_dim: cfg.head_dim(),
4205            v_head_dim: cfg.head_dim(),
4206        };
4207
4208        Ok(Box::new(cfg))
4209    }
4210
4211    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4212        Some(vec![NonMappedSubModel::Vision])
4213    }
4214}
4215
4216// ======================== Llama 4 Loader
4217
4218/// [`VisionLoader`] for a Llama 4 Vision model.
4219///
4220/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
4221pub struct VLlama4Loader;
4222
4223pub struct VLlama4Prefixer;
4224
4225impl MultimodalPromptPrefixer for VLlama4Prefixer {
4226    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
4227        format!(
4228            "{}{prompt}",
4229            llama4::IMAGE_TOKEN.repeat(image_indexes.len())
4230        )
4231    }
4232}
4233
4234impl VisionModelLoader for VLlama4Loader {
4235    fn load(
4236        &self,
4237        config: &str,
4238        vb: ShardedVarBuilder,
4239        normal_loading_metadata: NormalLoadingMetadata,
4240        attention_mechanism: AttentionImplementation,
4241    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
4242        let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4243        Ok(Box::new(Llama4Model::new(
4244            &cfg,
4245            vb,
4246            self.is_gptx(config),
4247            normal_loading_metadata,
4248            attention_mechanism,
4249        )?))
4250    }
4251    fn is_gptx(&self, _config: &str) -> bool {
4252        false
4253    }
4254    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4255        let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4256        Ok(Box::new(cfg))
4257    }
4258    fn get_processor(
4259        &self,
4260        _model_config: &str,
4261        processor_config: Option<ProcessorConfig>,
4262        _preprocessor_config: PreProcessorConfig,
4263        _max_edge: Option<u32>,
4264    ) -> Arc<dyn Processor + Send + Sync> {
4265        Arc::new(Llama4Processor::new(&processor_config.unwrap()))
4266    }
4267    fn supports_paged_attention(&self, _config: &str) -> bool {
4268        true
4269    }
4270    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4271        Arc::new(VLlama4Prefixer)
4272    }
4273    fn modalities(&self, _config: &str) -> Result<Modalities> {
4274        Ok(Modalities {
4275            input: vec![SupportedModality::Text, SupportedModality::Vision],
4276            output: vec![SupportedModality::Text],
4277        })
4278    }
4279}
4280
4281impl IsqModelLoader for VLlama4Loader {
4282    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4283        Ok(vec![
4284            Regex::new(r"lm_head\.(weight|bias)$")?,
4285            // Attention
4286            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4287            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4288            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4289            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4290            // FF MoE
4291            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$")?,
4292            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$")?,
4293            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$")?,
4294            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$")?,
4295            Regex::new(r"layers\.(\d+)\.feed_forward\.router\.(weight|bias)$")?,
4296            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4297            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4298            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4299            // FF MLP
4300            Regex::new(r"layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$")?,
4301            Regex::new(r"layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$")?,
4302            Regex::new(r"layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$")?,
4303        ])
4304    }
4305    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4306        Ok(vec![
4307            Regex::new(r"lm_head\.(weight|bias)$")?,
4308            // Attention
4309            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4310            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4311            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4312            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4313            // FF MoE
4314            Regex::new(
4315                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_up_proj\.(weight|bias)$",
4316            )?,
4317            Regex::new(
4318                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
4319            )?,
4320            Regex::new(
4321                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.up_proj\.(weight|bias)$",
4322            )?,
4323            Regex::new(
4324                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.down_proj\.(weight|bias)$",
4325            )?,
4326            Regex::new(
4327                r"language_model\.model\.layers\.(\d+)\.feed_forward\.router\.(weight|bias)$",
4328            )?,
4329            Regex::new(
4330                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.gate_proj\.(weight|bias)$",
4331            )?,
4332            Regex::new(
4333                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.up_proj\.(weight|bias)$",
4334            )?,
4335            Regex::new(
4336                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.down_proj\.(weight|bias)$",
4337            )?,
4338            // FF MLP
4339            Regex::new(
4340                r"language_model\.model\.layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$",
4341            )?,
4342            Regex::new(
4343                r"language_model\.model\.layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$",
4344            )?,
4345            Regex::new(
4346                r"language_model\.model\.layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$",
4347            )?,
4348        ])
4349    }
4350}
4351
4352impl VLlama4Loader {
4353    /// This incorporates the max batch size!
4354    /// Returns (pixels max batch size, num text image tokens)
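    /// Illustrative example (hypothetical numbers): if preprocessing yields pixel chunks of
    /// 336x336 with `patch_size = 14` and `downsample_ratio = 4`, then
    /// `num_patches_per_chunk = (336 / 14) * (336 / 14) / 4 = 144` text image tokens per chunk.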
4355    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
4356    fn run_dummy_processing(
4357        &self,
4358        cfg: &Llama4Config,
4359        height: usize,
4360        width: usize,
4361        max_num_images: usize,
4362        max_batch_size: usize,
4363    ) -> Result<(usize, usize)> {
4364        let cfg = &cfg.vision_config;
4365
4366        let img_processor =
4367            Llama4ImageProcessor::new(Some(cfg.patch_size), Some(cfg.pixel_shuffle_ratio));
4368        let image = DynamicImage::new(width as u32, height as u32, ColorType::Rgb8);
4369        let res = img_processor.preprocess(
4370            vec![image; max_num_images],
4371            vec![],
4372            &PreProcessorConfig::default(),
4373            &Device::Cpu,
4374            (max_batch_size, max_num_images),
4375        )?;
4376
4377        let pixels_batch_size = res.pixel_values.dim(0)?;
4378        let pixels_max_batch_size = pixels_batch_size * max_batch_size;
4379
4380        let (image_h, image_w) = (
4381            res.pixel_values.dim(D::Minus2).unwrap(),
4382            res.pixel_values.dim(D::Minus1).unwrap(),
4383        );
4384        let num_patches_per_chunk = (image_h / img_processor.patch_size)
4385            * (image_w / img_processor.patch_size)
4386            / img_processor.downsample_ratio;
4387
4388        Ok((
4389            pixels_max_batch_size,
4390            num_patches_per_chunk * pixels_max_batch_size,
4391        ))
4392    }
4393}
4394
4395impl DeviceMappedModelLoader for VLlama4Loader {
4396    fn mapped_max_act_size_elems(
4397        &self,
4398        config: &str,
4399        params: &AutoDeviceMapParams,
4400        _prompt_chunksize: usize,
4401    ) -> Result<usize> {
4402        let AutoDeviceMapParams::Vision {
4403            max_seq_len,
4404            max_batch_size,
4405            max_image_shape: (height, width),
4406            max_num_images,
4407        } = params
4408        else {
4409            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4410        };
4411
4412        let cfg: Llama4Config = serde_json::from_str(config)?;
4413
4414        let (_pixels_batch_size, num_text_image_toks) =
4415            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4416
4417        let max_seq_len = max_seq_len + num_text_image_toks;
4418
4419        Ok(max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len)
4420    }
4421    fn non_mapped_max_act_size_elems(
4422        &self,
4423        config: &str,
4424        params: &AutoDeviceMapParams,
4425    ) -> Result<usize> {
4426        let AutoDeviceMapParams::Vision {
4427            max_seq_len: _,
4428            max_batch_size,
4429            max_image_shape: (height, width),
4430            max_num_images,
4431        } = params
4432        else {
4433            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4434        };
4435
4436        let cfg: Llama4Config = serde_json::from_str(config)?;
4437
4438        let (pixels_batch_size, _num_text_image_toks) =
4439            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4440        let max_seq_len = cfg.vision_config.num_patches();
4441
4442        Ok((max_batch_size * pixels_batch_size)
4443            * cfg.vision_config.num_attention_heads
4444            * max_seq_len
4445            * max_seq_len)
4446    }
4447
4448    fn non_mapped_size_in_bytes(
4449        &self,
4450        config: &str,
4451        dtype: DType,
4452        weight_pack_factor: usize,
4453        _matformer_config: Option<&MatformerSliceConfig>,
4454    ) -> Result<usize> {
4455        let cfg: Llama4Config = serde_json::from_str(config)?;
4456        let tcfg = &cfg.text_config;
4457
4458        let text_elems = {
4459            let embed_tokens = tcfg.hidden_size * tcfg.vocab_size / weight_pack_factor;
4460            let lm_head = if !tcfg.tie_word_embeddings {
4461                tcfg.hidden_size * tcfg.vocab_size
4462            } else {
4463                0
4464            };
4465            let norm = tcfg.hidden_size;
4466            embed_tokens + lm_head + norm
4467        };
4468
4469        let vision_elems = {
4470            let cfg = &cfg.vision_config;
4471
4472            let num_patches = cfg.num_patches();
4473
4474            let unfold_elems =
4475                (cfg.num_channels * cfg.patch_size * cfg.patch_size) * cfg.hidden_size;
4476            let class_embedding_elems = cfg.hidden_size;
4477            let positional_embedding_vlm_elems = num_patches * cfg.hidden_size;
4478            let layernorm_pre_elems = cfg.hidden_size;
4479            let layernorm_post_elems = cfg.hidden_size;
4480
4481            let pixel_shuffle_elems = cfg.intermediate_size * cfg.projector_input_dim
4482                / weight_pack_factor
4483                + cfg.projector_input_dim * cfg.projector_output_dim / weight_pack_factor;
4484
4485            let encoder_layer = {
4486                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
4487                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
4488
4489                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
4490                let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4491                    / weight_pack_factor
4492                    + cfg.num_attention_heads * head_dim;
4493                let k_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4494                    / weight_pack_factor
4495                    + cfg.num_attention_heads * head_dim;
4496                let v_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4497                    / weight_pack_factor
4498                    + cfg.num_attention_heads * head_dim;
4499                let o_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4500                    / weight_pack_factor
4501                    + cfg.num_attention_heads * head_dim;
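                // (Each q/k/v/o projection above counts its packed weight matrix plus an
                // unpacked bias of num_attention_heads * head_dim elements.)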
4502
4503                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
4504                    + cfg.intermediate_size;
4505                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
4506                    + cfg.hidden_size;
4507
4508                input_layernorm
4509                    + post_attention_layernorm
4510                    + q_proj
4511                    + k_proj
4512                    + v_proj
4513                    + o_proj
4514                    + fc1
4515                    + fc2
4516            };
4517
4518            unfold_elems
4519                + class_embedding_elems
4520                + positional_embedding_vlm_elems
4521                + layernorm_post_elems
4522                + layernorm_pre_elems
4523                + pixel_shuffle_elems
4524                + encoder_layer * cfg.num_hidden_layers
4525        };
4526
4527        let elems = text_elems + vision_elems;
4528
4529        Ok(elems * dtype.size_in_bytes())
4530    }
4531
4532    fn layer_sizes_in_bytes(
4533        &self,
4534        config: &str,
4535        dtype: DType,
4536        weight_pack_factor: usize,
4537        _matformer_config: Option<&MatformerSliceConfig>,
4538    ) -> Result<Vec<usize>> {
4539        let cfg: Llama4Config = serde_json::from_str(config)?;
4540        let tcfg = &cfg.text_config;
4541
4542        let mut per_layer_elems = Vec::new();
4543
4544        for layer_idx in 0..tcfg.num_hidden_layers {
4545            let input_layernorm = tcfg.hidden_size;
4546            let post_attention_layernorm = tcfg.hidden_size;
4547
4548            let size_in = tcfg.hidden_size;
4549            let size_q = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_attention_heads;
4550            let size_kv = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_key_value_heads;
4551            let q_proj = size_in * size_q / weight_pack_factor;
4552            let k_proj = size_in * size_kv / weight_pack_factor;
4553            let v_proj = size_in * size_kv / weight_pack_factor;
4554            let o_proj = size_q * size_in / weight_pack_factor;
4555
4556            let use_moe = tcfg.moe_layers().contains(&layer_idx);
4557            let moe_block = if use_moe {
4558                let h_size = tcfg.hidden_size;
4559                let i_size = tcfg.intermediate_size;
4560                let gate_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4561                let up_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4562                let down_proj = tcfg.num_local_experts * i_size * h_size / weight_pack_factor;
4563
4564                gate_proj + up_proj + down_proj
4565            } else {
4566                let h_size = tcfg.hidden_size;
4567                let i_size = tcfg.intermediate_size_mlp;
4568                let gate_proj = h_size * i_size / weight_pack_factor;
4569                let up_proj = h_size * i_size / weight_pack_factor;
4570                let down_proj = i_size * h_size / weight_pack_factor;
4571
4572                gate_proj + up_proj + down_proj
4573            };
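            // e.g. (hypothetical) with num_local_experts = 16, the MoE branch counts 16 copies of
            // the gate/up/down weights (built from intermediate_size); dense layers use
            // intermediate_size_mlp instead.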
4574
4575            per_layer_elems.push(
4576                input_layernorm
4577                    + post_attention_layernorm
4578                    + q_proj
4579                    + k_proj
4580                    + v_proj
4581                    + o_proj
4582                    + moe_block,
4583            );
4584        }
4585
4586        Ok(per_layer_elems
4587            .into_iter()
4588            .map(|x| x * dtype.size_in_bytes())
4589            .collect())
4590    }
4591
4592    fn num_layers(&self, config: &str) -> Result<usize> {
4593        let cfg: Llama4Config = serde_json::from_str(config)?;
4594        Ok(cfg.text_config.num_hidden_layers)
4595    }
4596
4597    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4598        let cfg: Llama4Config = serde_json::from_str(config)?;
4599        let cfg = &cfg.text_config;
4600
4601        let cfg = ModelConfigMetadata {
4602            max_seq_len: cfg.max_position_embeddings,
4603            num_layers: cfg.num_hidden_layers,
4604            hidden_size: cfg.hidden_size,
4605            num_kv_heads: cfg.num_key_value_heads,
4606            num_attn_heads: cfg.num_attention_heads,
4607            sliding_window: None,
4608            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4609            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4610        };
4611
4612        Ok(Box::new(cfg))
4613    }
4614
4615    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4616        Some(vec![NonMappedSubModel::Vision])
4617    }
4618}
4619
4620// ======================== Gemma 3n Loader
4621
4622/// [`VisionLoader`] for a Gemma 3n model.
4623///
4624/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
4625pub struct Gemma3nLoader;
4626
4627pub struct Gemma3nPrefixer;
4628
4629impl MultimodalPromptPrefixer for Gemma3nPrefixer {
4630    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
4631        prompt.to_string()
4632    }
4633}
4634
4635impl VisionModelLoader for Gemma3nLoader {
4636    fn load(
4637        &self,
4638        config: &str,
4639        vb: ShardedVarBuilder,
4640        normal_loading_metadata: NormalLoadingMetadata,
4641        attention_mechanism: AttentionImplementation,
4642    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
4643        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4644        Ok(Box::new(Gemma3nModel::new(
4645            &cfg,
4646            vb,
4647            self.is_gptx(config),
4648            normal_loading_metadata,
4649            attention_mechanism,
4650        )?))
4651    }
4652    fn is_gptx(&self, _config: &str) -> bool {
4653        true
4654    }
4655    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4656        let config: Gemma3nConfig = serde_json::from_str(config)?;
4657        Ok(Box::new(config))
4658    }
4659    fn get_processor(
4660        &self,
4661        _config: &str,
4662        processor_config: Option<ProcessorConfig>,
4663        _preprocessor_config: PreProcessorConfig,
4664        _max_edge: Option<u32>,
4665    ) -> Arc<dyn Processor + Send + Sync> {
4666        // Handle the Gemma 3 1b case here
4667        Arc::new(Gemma3nProcessor::new(
4668            processor_config.unwrap_or_default(),
4669            true,
4670        ))
4671    }
4672    fn supports_paged_attention(&self, _config: &str) -> bool {
4673        false
4674    }
4675    fn supports_prefix_cacher(&self, _config: &str) -> bool {
4676        true
4677    }
4678    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4679        Arc::new(Gemma3nPrefixer)
4680    }
4681    fn modalities(&self, _config: &str) -> Result<Modalities> {
4682        Ok(Modalities {
4683            input: vec![
4684                SupportedModality::Text,
4685                SupportedModality::Vision,
4686                SupportedModality::Audio,
4687            ],
4688            output: vec![SupportedModality::Text],
4689        })
4690    }
4691}
4692
4693impl IsqModelLoader for Gemma3nLoader {
4694    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4695        Ok(vec![
4696            Regex::new(r"lm_head\.(weight|bias)$")?,
4697            // Language model attention
4698            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4699            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4700            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4701            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4702            // Language model MLP
4703            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4704            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4705            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4706            // Audio conformer attention layers
4707            Regex::new(r"conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$")?,
4708            Regex::new(r"conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$")?,
4709            Regex::new(r"conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$")?,
4710            Regex::new(
4711                r"conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4712            )?,
4713            Regex::new(r"conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4714            // Audio conformer FFW layers
4715            Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$")?,
4716            Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$")?,
4717            Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$")?,
4718            Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$")?,
4719            // Audio conformer conv1d layers
4720            Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$")?,
4721            Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$")?,
4722            // Audio subsample projection
4723            Regex::new(r"subsample_conv_projection\.input_proj_linear\.(weight|bias)$")?,
4724            // Multimodal embedders
4725            Regex::new(r"embed_vision\.embedding_projection\.(weight|bias)$")?,
4726            Regex::new(r"embed_audio\.embedding_projection\.(weight|bias)$")?,
4727        ])
4728    }
4729    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4730        Ok(vec![
4731            Regex::new(r"lm_head\.(weight|bias)$")?,
4732            // Language model attention
4733            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4734            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4735            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4736            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4737            // Language model MLP
4738            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4739            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4740            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4741            // Projections
4742            Regex::new(r"model\.language_model\.per_layer_model_projection\.(weight|bias)$")?,
4743            Regex::new(r"model\.language_model\.altup_projections\.(\d+)\.(weight|bias)$")?,
4744            Regex::new(r"model\.language_model\.altup_unembed_projections\.(\d+)\.(weight|bias)$")?,
4745            // Audio conformer attention layers
4746            Regex::new(
4747                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$",
4748            )?,
4749            Regex::new(
4750                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$",
4751            )?,
4752            Regex::new(
4753                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$",
4754            )?,
4755            Regex::new(
4756                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4757            )?,
4758            Regex::new(r"model\.audio_tower\.conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4759            // Audio conformer FFW layers
4760            Regex::new(
4761                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$",
4762            )?,
4763            Regex::new(
4764                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$",
4765            )?,
4766            Regex::new(
4767                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$",
4768            )?,
4769            Regex::new(
4770                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$",
4771            )?,
4772            // Audio conformer conv1d layers
4773            Regex::new(
4774                r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$",
4775            )?,
4776            Regex::new(
4777                r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$",
4778            )?,
4779            // Audio subsample projection
4780            Regex::new(
4781                r"model\.audio_tower\.subsample_conv_projection\.input_proj_linear\.(weight|bias)$",
4782            )?,
4783            // Multimodal embedders
4784            Regex::new(r"model\.embed_vision\.embedding_projection\.(weight|bias)$")?,
4785            Regex::new(r"model\.embed_audio\.embedding_projection\.(weight|bias)$")?,
4786        ])
4787    }
4788}
4789
4790impl DeviceMappedModelLoader for Gemma3nLoader {
4791    fn mapped_max_act_size_elems(
4792        &self,
4793        config: &str,
4794        params: &AutoDeviceMapParams,
4795        _prompt_chunksize: usize,
4796    ) -> Result<usize> {
4797        let AutoDeviceMapParams::Vision {
4798            max_seq_len,
4799            max_batch_size,
4800            max_image_shape: _,
4801            max_num_images,
4802        } = params
4803        else {
4804            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4805        };
4806
4807        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4808        let text_cfg = &cfg.text_config;
4809
4810        // Gemma3n is an "inject into the prompt" model, similar to Gemma3
4811        // We need to account for vision and audio tokens in the sequence length
4812
4813        let mut total_seq_len = *max_seq_len;
4814
4815        // Add vision tokens
4816        {
4817            // Vision tokens are injected into the prompt
4818            // MSFA outputs fixed 16x16 features regardless of input size
4819            let msfa_spatial_size = 16; // Fixed from vision.rs line 1115
4820            let vision_tokens_per_image = msfa_spatial_size * msfa_spatial_size; // 256 tokens
4821            total_seq_len += vision_tokens_per_image * max_num_images;
4822        }
4823
4824        // Add audio tokens
4825        {
4826            // Audio tokens are injected into the prompt
4827            // From config field audio_soft_tokens_per_image (typically 188)
4828            let audio_tokens = cfg.audio_soft_tokens_per_image;
4829            total_seq_len += audio_tokens;
4830        }
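        // Worked example (hypothetical numbers): max_seq_len = 4096, 2 images and an audio clip
        // with 188 soft tokens -> total_seq_len = 4096 + 2 * 256 + 188 = 4796.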
4831
4832        // Calculate max attention size for text model with all injected tokens
4833        let max_text_attn =
4834            max_batch_size * text_cfg.num_attention_heads * total_seq_len * total_seq_len;
4835
4836        Ok(max_text_attn)
4837    }
4838
4839    fn non_mapped_max_act_size_elems(
4840        &self,
4841        config: &str,
4842        params: &AutoDeviceMapParams,
4843    ) -> Result<usize> {
4844        let AutoDeviceMapParams::Vision {
4845            max_seq_len: _,
4846            max_batch_size,
4847            max_image_shape: _,
4848            max_num_images,
4849        } = params
4850        else {
4851            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4852        };
4853
4854        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4855
4856        // Calculate max activation sizes for each modality
4857        let mut max_activation = 0;
4858
4859        // Vision activation size
4860        {
4861            // Vision is Gemma3n's MobileNetV5 architecture with Multi-Query Attention
4862            // The peak activation is in the Multi-Query Attention layers
4863
4864            // From the architecture: stages 3 and 4 have MMQA blocks
4865            // Input images are 768x768 (from inputs_processor.rs)
4866            // Stage 3: 640 channels at 48x48 (768/16 downsampling), MMQA with num_heads=12, kv_dim=64
4867            // Stage 4: 1280 channels at 24x24 (768/32 downsampling), MMQA with num_heads=16, kv_dim=96
4868            // MSFA output: 2048 channels at fixed 16x16
4869
4870            let vision_tower_act = {
4871                // Peak is during MMQA attention computation in stage 4
4872                // Stage 4 has higher memory usage than Stage 3 due to more heads (16 vs 12)
4873                // From vision.rs: Stage 4 has num_heads=16, kv_dim=96, kv_stride=1
4874                let num_heads = 16; // Stage 4 configuration
4875                let spatial_size = 24; // 768 / 32 = 24 (input 768x768, stage 4 has 32x downsampling)
4876                let seq_len = spatial_size * spatial_size;
4877
4878                // Attention scores: [B * num_images, num_heads, seq_len, seq_len]
4879                max_batch_size * max_num_images * num_heads * seq_len * seq_len
4880            };
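            // e.g. (hypothetical) with max_batch_size = 1 and one image:
            // 1 * 1 * 16 * (24 * 24) * (24 * 24) ≈ 5.3M attention-score elements.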
4881
4882            // Vision embedder activations
4883            let vision_embed_act = {
4884                // MSFA output: 2048 channels at fixed 16x16 spatial (from vision.rs line 1115)
4885                let msfa_channels = 2048; // MSFA_OUT_CHANNELS from vision.rs
4886                let spatial_size = 16; // Fixed output resolution from MSFA
4887                let vision_features =
4888                    max_batch_size * max_num_images * msfa_channels * spatial_size * spatial_size;
4889
4890                // After embedding projection to text hidden size
4891                let projected = max_batch_size
4892                    * max_num_images
4893                    * spatial_size
4894                    * spatial_size
4895                    * cfg.text_config.hidden_size;
4896
4897                vision_features.max(projected)
4898            };
4899
4900            max_activation = max_activation.max(vision_tower_act).max(vision_embed_act);
4901        }
4902
4903        // Audio activation size
4904        {
4905            let audio_cfg = &cfg.audio_config;
4906
4907            // Calculate max audio sequence length based on config
4908            // Audio uses conformer with subsampling and reduction
4909
4910            // A rough estimate of max_audio_frames
4911            let max_audio_frames = 1280;
4912
4913            let subsample_factor: usize = audio_cfg
4914                .sscp_conv_stride_size
4915                .iter()
4916                .map(|stride| stride[0]) // Time dimension stride
4917                .product();
4918            let audio_seq_after_subsample = max_audio_frames / subsample_factor;
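            // e.g. (assuming typical stride sizes of [[2, 2], [2, 2]]): subsample_factor = 4,
            // so 1280 frames -> 320 frames after subsampling.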
4919
4920            // Audio encoder activations
4921            let audio_encoder_act = {
4922                // Conformer FFW layers have expansion factor from config
4923                let intermediate_size = audio_cfg.hidden_size * 4; // FFW expansion factor
4924
4925                // Peak is in the FFW layers before reduction
4926                max_batch_size * audio_seq_after_subsample * intermediate_size
4927            };
4928
4929            // Audio attention activations
4930            let audio_attn_act = {
4931                // Attention uses chunked processing with specific context sizes
4932                let chunk_size = audio_cfg.conf_attention_chunk_size;
4933                let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
4934                    + audio_cfg.conf_attention_context_right;
4935
4936                // Peak is attention scores: [B, num_heads, num_chunks, chunk_size, context_size]
4937                let num_chunks = audio_seq_after_subsample.div_ceil(chunk_size);
4938
4939                max_batch_size
4940                    * audio_cfg.conf_num_attention_heads
4941                    * num_chunks
4942                    * chunk_size
4943                    * context_size
4944            };
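            // e.g. (hypothetical) chunk_size = 12, context_left = 13, context_right = 0:
            // context_size = 12 + 13 - 1 + 0 = 24, and num_chunks = ceil(320 / 12) = 27.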
4945
4946            max_activation = max_activation.max(audio_encoder_act).max(audio_attn_act);
4947        }
4948
4949        Ok(max_activation)
4950    }
4951
4952    fn non_mapped_size_in_bytes(
4953        &self,
4954        config: &str,
4955        dtype: DType,
4956        weight_pack_factor: usize,
4957        matformer_config: Option<&MatformerSliceConfig>,
4958    ) -> Result<usize> {
4959        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4960
4961        // Apply matformer slicing if configured
4962        let text_cfg = if let Some(matformer_cfg) = matformer_config {
4963            use crate::device_map::DummyDeviceMapper;
4964            use crate::vision_models::gemma3n::text::handle_matformer_slicing;
4965
4966            let dummy_mapper = DummyDeviceMapper {
4967                nm_device: Device::Cpu,
4968            };
4969            let (adjusted_cfg, _, _, _, _) = handle_matformer_slicing(
4970                &cfg.text_config,
4971                &Some(matformer_cfg.clone()),
4972                &dummy_mapper,
4973            )?;
4974            adjusted_cfg
4975        } else {
4976            cfg.text_config.clone()
4977        };
4978
4979        let text_cfg = &text_cfg;
4980
4981        // Text components that are not device-mapped
4982        let text_elems = {
4983            // Embeddings
4984            let embed_tokens = text_cfg.hidden_size * text_cfg.vocab_size;
4985            let embed_tokens_per_layer = text_cfg.num_hidden_layers
4986                * text_cfg.hidden_size_per_layer_input
4987                * text_cfg.vocab_size_per_layer_input;
4988
4989            // LM head (counted when untied, or whenever the weights are packed/quantized)
4990            let lm_head = if !text_cfg.tie_word_embeddings || weight_pack_factor != 1 {
4991                text_cfg.hidden_size * text_cfg.vocab_size / weight_pack_factor
4992            } else {
4993                0
4994            };
4995
4996            // Final layer norm
4997            let norm = text_cfg.hidden_size;
4998
4999            // AltUp projections (not device-mapped)
5000            let altup_projections =
5001                (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
5002                    / weight_pack_factor;
5003            let altup_unembed_projections =
5004                (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
5005                    / weight_pack_factor;
5006
5007            // Per-layer model projection
5008            let per_layer_model_projection = text_cfg.num_hidden_layers
5009                * text_cfg.hidden_size
5010                * text_cfg.hidden_size_per_layer_input
5011                / weight_pack_factor;
5012            let per_layer_projection_norm = text_cfg.hidden_size;
5013
5014            embed_tokens
5015                + embed_tokens_per_layer
5016                + lm_head
5017                + norm
5018                + altup_projections
5019                + altup_unembed_projections
5020                + per_layer_model_projection
5021                + per_layer_projection_norm
5022        };
5023
5024        // Vision components
5025        let vision_elems = {
5026            let vision_cfg = &cfg.vision_config;
5027            // Vision tower - calculated from actual Gemma3n architecture
5028            // NOTE: Vision tower uses only Conv2d layers, NOT Arc<dyn QuantMethod>,
5029            // so NONE of these should be divided by weight_pack_factor
5030            let vision_tower_elems = {
5031                use crate::vision_models::gemma3n::vision::{
5032                    gemma3n_mobilenet_def, make_divisible, BlockType, INPUT_CHANNELS,
5033                    MSFA_EXPANSION_RATIO, MSFA_IN_CHANNELS, MSFA_OUT_CHANNELS, STEM_KERNEL_SIZE,
5034                    STEM_OUT_CHANNELS,
5035                };
5036
5037                // Stem: ConvNormAct (Conv2d + RMSNorm)
5038                let stem_conv =
5039                    INPUT_CHANNELS * STEM_OUT_CHANNELS * STEM_KERNEL_SIZE * STEM_KERNEL_SIZE;
5040                let stem_norm = STEM_OUT_CHANNELS; // RMSNorm weight
5041
5042                // Track input channels through the network
5043                let mut in_chs = STEM_OUT_CHANNELS;
5044                let mut total_elems = stem_conv + stem_norm;
5045
5046                // Process all stages from gemma3n_mobilenet_def
5047                let block_defs = gemma3n_mobilenet_def();
5048
5049                for stage_blocks in block_defs.iter() {
5050                    for block_type in stage_blocks.iter() {
5051                        match block_type {
5052                            BlockType::EdgeResidual {
5053                                out_channels,
5054                                kernel_size,
5055                                stride: _,
5056                                expand_ratio,
5057                                ..
5058                            } => {
5059                                #[allow(clippy::cast_precision_loss)]
5060                                let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
5061                                // EdgeResidual: all Conv2d layers, not quantizable
5062                                total_elems += in_chs * mid_chs * kernel_size * kernel_size; // conv_exp (Conv2d)
5063                                total_elems += mid_chs; // bn1 weight
5064                                total_elems += mid_chs * out_channels; // conv_pwl (Conv2d)
5065                                total_elems += out_channels; // bn2 weight
5066                                in_chs = *out_channels;
5067                            }
5068                            BlockType::UniversalInvertedResidual {
5069                                out_channels,
5070                                start_kernel_size,
5071                                mid_kernel_size,
5072                                stride: _,
5073                                expand_ratio,
5074                                ..
5075                            } => {
5076                                #[allow(clippy::cast_precision_loss)]
5077                                let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
5078                                // UniversalInvertedResidual: all Conv2d layers, not quantizable
5079                                if *expand_ratio != 1.0 {
5080                                    total_elems += in_chs * mid_chs; // expand conv (Conv2d)
5081                                    total_elems += mid_chs; // expand norm
5082                                }
5083                                if *start_kernel_size > 0 {
5084                                    total_elems += mid_chs * start_kernel_size * start_kernel_size; // depthwise start (Conv2d)
5085                                    total_elems += mid_chs; // norm
5086                                }
5087                                if *mid_kernel_size > 0 {
5088                                    total_elems += mid_chs * mid_kernel_size * mid_kernel_size; // depthwise mid (Conv2d)
5089                                    total_elems += mid_chs; // norm
5090                                }
5091                                total_elems += mid_chs * out_channels; // project conv (Conv2d)
5092                                total_elems += out_channels; // project norm
5093                                total_elems += out_channels; // layer scale gamma
5094                                in_chs = *out_channels;
5095                            }
5096                            BlockType::MultiQueryAttention {
5097                                num_heads,
5098                                kv_dim,
5099                                kv_stride: _,
5100                                ..
5101                            } => {
5102                                // MMQA: all Conv2d layers, not quantizable
5103                                let dw_kernel_size = 3; // Default dw_kernel_size for MMQA
5104                                total_elems += in_chs; // norm weight
5105                                total_elems += in_chs * num_heads * kv_dim; // query_proj (Conv2d)
5106                                total_elems += in_chs * kv_dim; // key_proj (Conv2d)
5107                                total_elems += in_chs * dw_kernel_size * dw_kernel_size; // key_dw_conv (Conv2d)
5108                                total_elems += *kv_dim; // value_down_conv (Conv2d)
5109                                total_elems += 1; // value_norm weight
5110                                total_elems += *kv_dim; // value_proj (Conv2d)
5111                                total_elems += num_heads * kv_dim * in_chs; // output_proj (Conv2d)
5112                                total_elems += in_chs; // layer scale
5113                            }
5114                        }
5115                    }
5116                }
5117
5118                // Multi-scale fusion adapter (msfa) - also uses Conv2d layers
5119                let msfa_in = MSFA_IN_CHANNELS.iter().sum::<usize>();
5120                let msfa_out = MSFA_OUT_CHANNELS;
5121                #[allow(clippy::cast_precision_loss)]
5122                let msfa_mid = make_divisible(msfa_in as f64 * MSFA_EXPANSION_RATIO, 8);
5123
5124                // MSFA FFN (UIR with expansion_ratio) - Conv2d layers, not quantizable
5125                total_elems += msfa_in * msfa_mid; // expand (Conv2d)
5126                total_elems += msfa_mid; // expand norm
5127                total_elems += msfa_mid * msfa_out; // project (Conv2d)
5128                total_elems += msfa_out; // project norm
5129                total_elems += msfa_out; // final norm
5130
5131                total_elems
5132            };
5133
5134            // Vision multimodal embedder components
5135            let embed_vision_elems = {
5136                // Embedding layer (not quantizable)
5137                let embedding = vision_cfg.vocab_size * vision_cfg.hidden_size;
5138
5139                // Normalization layers (not quantizable)
5140                let hard_norm = vision_cfg.hidden_size;
5141                let soft_norm = vision_cfg.hidden_size;
5142
5143                // Projection from vision to text hidden size (IS Arc<dyn QuantMethod>, so quantizable)
5144                let projection = vision_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5145
5146                // Post-projection norm (not quantizable)
5147                let post_norm = text_cfg.hidden_size;
5148
5149                embedding + hard_norm + soft_norm + projection + post_norm
5150            };
5151
5152            vision_tower_elems + embed_vision_elems
5153        };
5154
5155        // Audio components - based on actual audio.rs structure
5156        let audio_elems = {
5157            let audio_cfg = &cfg.audio_config;
5158
5159            // SubSampleConvProjection components
5160            let subsample_conv_projection_elems = {
5161                // Conv blocks (Conv2d layers - NOT quantizable)
5162                let mut conv_elems = 0;
5163
5164                // conv_0: Conv2d from 1 channel to first channel size
5165                let in_ch_0 = 1;
5166                let out_ch_0 = audio_cfg.sscp_conv_channel_size[0];
5167                let kernel_0 = &audio_cfg.sscp_conv_kernel_size[0];
5168                conv_elems += in_ch_0 * out_ch_0 * kernel_0[0] * kernel_0[1];
5169
5170                // conv_1: Conv2d from first to second channel size
5171                let in_ch_1 = out_ch_0;
5172                let out_ch_1 = audio_cfg.sscp_conv_channel_size[1];
5173                let kernel_1 = &audio_cfg.sscp_conv_kernel_size[1];
5174                conv_elems += in_ch_1 * out_ch_1 * kernel_1[0] * kernel_1[1];
5175
5176                // CumulativeGroupNorm for each conv block (weight only, no bias by default)
5177                let norm_0 = out_ch_0; // norm weight for conv_0
5178                let norm_1 = out_ch_1; // norm weight for conv_1
5179
5180                // input_proj_linear (Arc<dyn QuantMethod> - IS quantizable)
5181                let mut f_out = audio_cfg.input_feat_size;
5182                for i in 0..2 {
5183                    let kernel_w = audio_cfg.sscp_conv_kernel_size[i][1];
5184                    let stride_w = audio_cfg.sscp_conv_stride_size[i][1];
5185                    let pad_left = 1;
5186                    let pad_right = 1;
5187                    f_out = (f_out + pad_left + pad_right + stride_w - kernel_w) / stride_w;
5188                }
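                // e.g. (hypothetical) input_feat_size = 128 with kernel_w = 3, stride_w = 2 and
                // 1+1 padding on both convs: f_out goes 128 -> 64 -> 32.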
5189                let input_proj_in_features = out_ch_1 * f_out;
5190                let input_proj_linear =
5191                    input_proj_in_features * audio_cfg.hidden_size / weight_pack_factor;
5192
5193                conv_elems + norm_0 + norm_1 + input_proj_linear
5194            };
5195
5196            // Conformer blocks
5197            let conformer_elems = {
5198                let mut total = 0;
5199
5200                for _ in 0..audio_cfg.conf_num_hidden_layers {
5201                    // ConformerAttention
5202                    let attention_elems = {
5203                        // Norms (NOT quantizable)
5204                        let pre_attn_norm = audio_cfg.hidden_size;
5205                        let post_norm = audio_cfg.hidden_size;
5206
5207                        // Attention projections (Arc<dyn QuantMethod> - IS quantizable)
5208                        let q_proj =
5209                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5210                        let k_proj =
5211                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5212                        let v_proj =
5213                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5214                        let post =
5215                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5216
5217                        // RelativePositionEmbedding
5218                        let pos_proj =
5219                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5220                        let per_dim_scale =
5221                            audio_cfg.hidden_size / audio_cfg.conf_num_attention_heads; // head_dim
5222                        let inv_timescales = audio_cfg.hidden_size / 2; // num_timescales
5223                        let pos_indices = audio_cfg.conf_attention_context_left
5224                            + audio_cfg.conf_attention_context_right
5225                            + 1;
5226
5227                        // Local causal masks (precomputed tensors)
5228                        let chunk_size = audio_cfg.conf_attention_chunk_size;
5229                        let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
5230                            + audio_cfg.conf_attention_context_right;
5231                        let local_causal_valid_mask = chunk_size * context_size; // U8 tensor
5232                        let invalid_logits_tensor = 1; // single f32 value
5233
5234                        pre_attn_norm
5235                            + post_norm
5236                            + q_proj
5237                            + k_proj
5238                            + v_proj
5239                            + post
5240                            + pos_proj
5241                            + per_dim_scale
5242                            + inv_timescales
5243                            + pos_indices
5244                            + local_causal_valid_mask
5245                            + invalid_logits_tensor
5246                    };
5247
5248                    // ConformerFeedForward (start and end)
5249                    let ffw_elems = {
5250                        // Each FFW has:
5251                        // - pre_layer_norm (NOT quantizable)
5252                        // - ffw_layer_1 (Arc<dyn QuantMethod> - IS quantizable)
5253                        // - ffw_layer_2 (Arc<dyn QuantMethod> - IS quantizable)
5254                        // - post_layer_norm (NOT quantizable)
5255                        let intermediate_size = audio_cfg.hidden_size * 4;
5256
5257                        let ffw_start = {
5258                            let pre_norm = audio_cfg.hidden_size;
5259                            let layer_1 =
5260                                audio_cfg.hidden_size * intermediate_size / weight_pack_factor;
5261                            let layer_2 =
5262                                intermediate_size * audio_cfg.hidden_size / weight_pack_factor;
5263                            let post_norm = audio_cfg.hidden_size;
5264                            pre_norm + layer_1 + layer_2 + post_norm
5265                        };
5266
5267                        let ffw_end = ffw_start; // Same structure
5268
5269                        ffw_start + ffw_end
5270                    };
5271
5272                    // ConformerLightConv1d
5273                    let lconv1d_elems = {
5274                        // Norms (NOT quantizable)
5275                        let pre_layer_norm = audio_cfg.hidden_size;
5276                        let conv_norm = audio_cfg.hidden_size;
5277
5278                        // Linear layers (Arc<dyn QuantMethod> - IS quantizable)
5279                        let linear_start = audio_cfg.hidden_size * (audio_cfg.hidden_size * 2)
5280                            / weight_pack_factor;
5281                        let linear_end =
5282                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5283
5284                        // depthwise_conv1d (Conv1d - NOT quantizable)
5285                        let depthwise = audio_cfg.hidden_size * audio_cfg.conf_conv_kernel_size;
5286
5287                        pre_layer_norm + conv_norm + linear_start + linear_end + depthwise
5288                    };
5289
5290                    // Final norm for conformer block (NOT quantizable)
5291                    let block_norm = audio_cfg.hidden_size;
5292
5293                    total += attention_elems + ffw_elems + lconv1d_elems + block_norm;
5294                }
5295
5296                total
5297            };
5298
5299            // Audio multimodal embedder (embed_audio)
5300            let embed_audio_elems = {
5301                // Embedding layer (ScaledEmbedding - NOT quantizable)
5302                let embedding = audio_cfg.vocab_size * audio_cfg.hidden_size;
5303
5304                // RMS norms (NOT quantizable)
5305                let hard_embedding_norm = audio_cfg.hidden_size; // with scale
5306                let soft_embedding_norm = audio_cfg.hidden_size; // with scale
5307                let embedding_post_projection_norm = text_cfg.hidden_size; // without scale
5308
5309                // Projection (Arc<dyn QuantMethod> - IS quantizable)
5310                let embedding_projection =
5311                    audio_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5312
5313                embedding
5314                    + hard_embedding_norm
5315                    + soft_embedding_norm
5316                    + embedding_post_projection_norm
5317                    + embedding_projection
5318            };
5319
5320            subsample_conv_projection_elems + conformer_elems + embed_audio_elems
5321        };
5322
5323        let vision_dtype = if dtype == DType::F16 {
5324            // f16 -> f32 for vision model in particular.
5325            DType::F32
5326        } else {
5327            dtype
5328        };
5329
5330        let total_size_in_bytes = text_elems * dtype.size_in_bytes()
5331            + vision_elems * vision_dtype.size_in_bytes()
5332            + audio_elems * dtype.size_in_bytes();
5333
5334        Ok(total_size_in_bytes)
5335    }
5336
5337    fn layer_sizes_in_bytes(
5338        &self,
5339        config: &str,
5340        dtype: DType,
5341        weight_pack_factor: usize,
5342        matformer_config: Option<&MatformerSliceConfig>,
5343    ) -> Result<Vec<usize>> {
5344        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5345
5346        // Apply matformer slicing if configured
5347        let (text_cfg, _layer_rename_map, _layers_skipped) = if let Some(matformer_cfg) =
5348            matformer_config
5349        {
5350            use crate::device_map::DummyDeviceMapper;
5351            use crate::vision_models::gemma3n::text::handle_matformer_slicing;
5352
5353            let dummy_mapper = DummyDeviceMapper {
5354                nm_device: Device::Cpu,
5355            };
5356            let (adjusted_cfg, _, _, layer_rename_map, layers_skipped) = handle_matformer_slicing(
5357                &cfg.text_config,
5358                &Some(matformer_cfg.clone()),
5359                &dummy_mapper,
5360            )?;
5361            (adjusted_cfg, layer_rename_map, layers_skipped)
5362        } else {
5363            (cfg.text_config.clone(), None, None)
5364        };
5365
5366        let text_cfg = &text_cfg;
5367
5368        // When matformer slicing is applied, we only include the layers that are kept
5369        let mut layer_sizes = Vec::new();
5370
5371        // Note: We don't need orig_intermediate_sizes anymore since the adjusted config
5372        // already has the correct intermediate sizes after matformer slicing
5373
5374        for layer_idx in 0..text_cfg.num_hidden_layers {
5375            let per_layer_elems = {
5376                // Layer norms
5377                let input_layernorm = text_cfg.hidden_size;
5378                let post_attention_layernorm = text_cfg.hidden_size;
5379                let pre_feedforward_layernorm = text_cfg.hidden_size;
5380                let post_feedforward_layernorm = text_cfg.hidden_size;
5381                let post_per_layer_input_norm = text_cfg.hidden_size;
5382
5383                // Attention components
5384                let size_in = text_cfg.hidden_size;
5385                let size_q = text_cfg.num_attention_heads * text_cfg.head_dim;
5386                let size_kv = text_cfg.num_key_value_heads * text_cfg.head_dim;
5387
5388                let q_proj = size_in * size_q / weight_pack_factor;
5389                let k_proj = size_in * size_kv / weight_pack_factor;
5390                let v_proj = size_in * size_kv / weight_pack_factor;
5391                let o_proj = size_q * size_in / weight_pack_factor;
5392
5393                // Q, K, V norms
5394                let q_norm = text_cfg.head_dim;
5395                let k_norm = text_cfg.head_dim;
5396                let v_norm = text_cfg.head_dim; // No bias for v_norm
5397
5398                // MLP components - use the adjusted intermediate sizes from matformer
5399                let intermediate_size = text_cfg
5400                    .intermediate_size
5401                    .0
5402                    .clone()
5403                    .map_left(|left| left[layer_idx])
5404                    .left_or_else(|(sizes, _)| sizes[layer_idx]);
5405                let gate_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
5406                let up_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
5407                let down_proj = intermediate_size * text_cfg.hidden_size / weight_pack_factor;
5408
5409                // AltUp components (per layer)
5410                let altup_elems = {
5411                    let correct_output_scale = text_cfg.hidden_size;
5412                    let correction_coefs = text_cfg.altup_num_inputs * text_cfg.altup_num_inputs;
5413                    let prediction_coefs =
5414                        text_cfg.altup_num_inputs * text_cfg.altup_num_inputs.pow(2);
5415                    let modality_router = text_cfg.hidden_size * text_cfg.altup_num_inputs;
5416                    let router_norm = text_cfg.hidden_size;
5417
5418                    correct_output_scale
5419                        + correction_coefs
5420                        + prediction_coefs
5421                        + modality_router
5422                        + router_norm
5423                };
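                // e.g. (hypothetical) with altup_num_inputs = 4: correction_coefs = 4 * 4 = 16
                // and prediction_coefs = 4 * 4^2 = 64 elements per layer.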
5424
5425                // Laurel block components
5426                let laurel_elems = {
5427                    let left = text_cfg.hidden_size * text_cfg.laurel_rank;
5428                    let right = text_cfg.laurel_rank * text_cfg.hidden_size;
5429                    let post_norm = text_cfg.hidden_size;
5430
5431                    left + right + post_norm
5432                };
5433
5434                // Per-layer input components
5435                let per_layer_input_gate =
5436                    text_cfg.hidden_size * text_cfg.hidden_size_per_layer_input;
5437                let per_layer_projection =
5438                    text_cfg.hidden_size_per_layer_input * text_cfg.hidden_size;
5439
5440                input_layernorm
5441                    + post_attention_layernorm
5442                    + pre_feedforward_layernorm
5443                    + post_feedforward_layernorm
5444                    + post_per_layer_input_norm
5445                    + q_proj
5446                    + k_proj
5447                    + v_proj
5448                    + o_proj
5449                    + q_norm
5450                    + k_norm
5451                    + v_norm
5452                    + gate_proj
5453                    + up_proj
5454                    + down_proj
5455                    + altup_elems
5456                    + laurel_elems
5457                    + per_layer_input_gate
5458                    + per_layer_projection
5459            };
5460
5461            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
5462        }
5463
5464        Ok(layer_sizes)
5465    }
5466
5467    fn num_layers(&self, config: &str) -> Result<usize> {
5468        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5469        Ok(cfg.text_config.num_hidden_layers)
5470    }
5471
5472    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
5473        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5474        let cfg = cfg.text_config;
5475
5476        let cfg = ModelConfigMetadata {
5477            max_seq_len: cfg.max_position_embeddings,
5478            num_layers: cfg.num_hidden_layers,
5479            hidden_size: cfg.hidden_size,
5480            num_kv_heads: cfg.num_key_value_heads,
5481            num_attn_heads: cfg.num_attention_heads,
5482            sliding_window: None, // None to be more forgiving; some configs do not set it
5483            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5484            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5485        };
5486
5487        Ok(Box::new(cfg))
5488    }
5489
5490    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
5491        Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
5492    }
5493}