use std::any::Any;
use std::sync::Arc;
use std::{fmt::Debug, str::FromStr};

use anyhow::Result;
use candle_core::{DType, Device, Tensor, D};
use candle_nn::Conv2dConfig;
use image::{ColorType, DynamicImage};
use itertools::Itertools;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::ShardedVarBuilder;

#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};

use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
use crate::amoe::AnyMoeBaseModelMixin;
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::DeviceMapper;
use crate::layers::Conv3dConfig;
use crate::matformer::MatformerSliceConfig;
use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
use crate::pipeline::isq::IsqModelLoader;
use crate::pipeline::loaders::AutoDeviceMapParams;
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::{
    EitherCache, IsqModel, Modalities, MultimodalPromptPrefixer, Processor, ProcessorCreator,
    SupportedModality,
};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::vision_models::clip::ClipConfig;
use crate::vision_models::gemma3::config::Gemma3Config;
use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
use crate::vision_models::gemma3n::config::{Gemma3nConfig, IntermediateSize};
use crate::vision_models::gemma3n::{Gemma3nModel, Gemma3nProcessor};
use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
use crate::vision_models::idefics2_input_processor::Idefics2Processor;
use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
use crate::vision_models::image_processor::ImagePreProcessor;
use crate::vision_models::inputs_processor::Phi4MMProcessor;
use crate::vision_models::llama4::{
    self, Llama4Config, Llama4ImageProcessor, Llama4Model, Llama4Processor,
};
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::llava15::Model as LLaVA;
use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
use crate::vision_models::llava_next::Model as LLaVANext;
use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
use crate::vision_models::phi3_inputs_processor::Phi3Processor;
use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::qwen2_5_vl::{
    Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
};
use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
use crate::vision_models::qwen3_vl::{Config as Qwen3VLConfig, Qwen3VLModel, Qwen3VLProcessor};
use crate::vision_models::{minicpmo, phi4};

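/// A loaded vision model ready for inference: a forward pass over token ids and
/// optional pixel values, plus access to its KV cache and configuration metadata.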
pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
}

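/// Constructs a concrete [`VisionModel`] from its serialized `config.json` and
/// exposes per-architecture metadata: processor creation, supported modalities,
/// paged-attention and prefix-cacher support, and device placement for tensor loading.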
pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> bool;
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_processor(
        &self,
        model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync>;
    fn supports_paged_attention(&self, config: &str) -> bool;
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        false
    }
    fn modalities(&self, config: &str) -> Result<Modalities>;
    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}

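/// The architecture to load the vision model as.
///
/// The serde/CLI names below (e.g. `"qwen3vl"`) parse via [`FromStr`], and Hugging Face
/// `architectures` class names map to a variant via [`VisionLoaderType::from_causal_lm_name`].
///
/// A minimal usage sketch (marked `ignore` since it assumes this type is in scope):
///
/// ```ignore
/// use std::str::FromStr;
///
/// let tp = VisionLoaderType::from_str("qwen3vl").unwrap();
/// assert_eq!(tp.to_string(), "qwen3vl");
/// ```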
#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub enum VisionLoaderType {
    #[serde(rename = "phi3v")]
    Phi3V,
    #[serde(rename = "idefics2")]
    Idefics2,
    #[serde(rename = "llava_next")]
    LLaVANext,
    #[serde(rename = "llava")]
    LLaVA,
    #[serde(rename = "vllama")]
    VLlama,
    #[serde(rename = "qwen2vl")]
    Qwen2VL,
    #[serde(rename = "idefics3")]
    Idefics3,
    #[serde(rename = "minicpmo")]
    MiniCpmO,
    #[serde(rename = "phi4mm")]
    Phi4MM,
    #[serde(rename = "qwen2_5vl")]
    Qwen2_5VL,
    #[serde(rename = "gemma3")]
    Gemma3,
    #[serde(rename = "mistral3")]
    Mistral3,
    #[serde(rename = "llama4")]
    Llama4,
    #[serde(rename = "gemma3n")]
    Gemma3n,
    #[serde(rename = "qwen3vl")]
    Qwen3VL,
}

impl VisionLoaderType {
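    /// Infer the loader type from the `architectures` class name in a Hugging Face
    /// `config.json` (e.g. `"Qwen2VLForConditionalGeneration"` -> [`Self::Qwen2VL`]).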
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "Phi3VForCausalLM" => Ok(Self::Phi3V),
            "Idefics2ForConditionalGeneration" => Ok(Self::Idefics2),
            "LlavaNextForConditionalGeneration" => Ok(Self::LLaVANext),
            "LlavaForConditionalGeneration" => Ok(Self::LLaVA),
            "MllamaForConditionalGeneration" => Ok(Self::VLlama),
            "Qwen2VLForConditionalGeneration" => Ok(Self::Qwen2VL),
            "Idefics3ForConditionalGeneration" => Ok(Self::Idefics3),
            "MiniCPMO" => Ok(Self::MiniCpmO),
            "Phi4MMForCausalLM" => Ok(Self::Phi4MM),
            "Qwen2_5_VLForConditionalGeneration" => Ok(Self::Qwen2_5VL),
            "Gemma3ForConditionalGeneration" | "Gemma3ForCausalLM" => Ok(Self::Gemma3),
            "Mistral3ForConditionalGeneration" => Ok(Self::Mistral3),
            "Llama4ForConditionalGeneration" => Ok(Self::Llama4),
            "Gemma3nForConditionalGeneration" => Ok(Self::Gemma3n),
            "Qwen3VLForConditionalGeneration" => Ok(Self::Qwen3VL),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for VisionLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "phi3v" => Ok(Self::Phi3V),
            "idefics2" => Ok(Self::Idefics2),
            "llava_next" => Ok(Self::LLaVANext),
            "llava" => Ok(Self::LLaVA),
            "vllama" => Ok(Self::VLlama),
            "qwen2vl" => Ok(Self::Qwen2VL),
            "idefics3" => Ok(Self::Idefics3),
            "minicpmo" => Ok(Self::MiniCpmO),
            "phi4mm" => Ok(Self::Phi4MM),
            "qwen2_5vl" => Ok(Self::Qwen2_5VL),
            "gemma3" => Ok(Self::Gemma3),
            "mistral3" => Ok(Self::Mistral3),
            "llama4" => Ok(Self::Llama4),
            "gemma3n" => Ok(Self::Gemma3n),
            "qwen3vl" => Ok(Self::Qwen3VL),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`, `llama4`, `gemma3n`, `qwen3vl`.")),
        }
    }
}

impl std::fmt::Display for VisionLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            VisionLoaderType::Phi3V => "phi3v",
            VisionLoaderType::Idefics2 => "idefics2",
            VisionLoaderType::LLaVANext => "llava_next",
            VisionLoaderType::LLaVA => "llava",
            VisionLoaderType::VLlama => "vllama",
            VisionLoaderType::Qwen2VL => "qwen2vl",
            VisionLoaderType::Idefics3 => "idefics3",
            VisionLoaderType::MiniCpmO => "minicpmo",
            VisionLoaderType::Phi4MM => "phi4mm",
            VisionLoaderType::Qwen2_5VL => "qwen2_5vl",
            VisionLoaderType::Gemma3 => "gemma3",
            VisionLoaderType::Mistral3 => "mistral3",
            VisionLoaderType::Llama4 => "llama4",
            VisionLoaderType::Gemma3n => "gemma3n",
            VisionLoaderType::Qwen3VL => "qwen3vl",
        };
        write!(f, "{name}")
    }
}

#[derive(Deserialize)]
struct AutoVisionLoaderConfig {
    architectures: Vec<String>,
}

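/// Automatically selects the concrete [`VisionModelLoader`] by inspecting the
/// `architectures` field of the model's `config.json`, then delegates every
/// trait method to it.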
pub struct AutoVisionLoader;

impl AutoVisionLoader {
    fn get_loader(config: &str) -> Result<Box<dyn VisionModelLoader>> {
        let auto_cfg: AutoVisionLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one architecture in config");
        }

        let name = &auto_cfg.architectures[0];
        let tp = VisionLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        Ok(match tp {
            VisionLoaderType::Phi3V => Box::new(Phi3VLoader),
            VisionLoaderType::Idefics2 => Box::new(Idefics2Loader),
            VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader),
            VisionLoaderType::LLaVA => Box::new(LLaVALoader),
            VisionLoaderType::VLlama => Box::new(VLlamaLoader),
            VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
            VisionLoaderType::Idefics3 => Box::new(Idefics3Loader),
            VisionLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
            VisionLoaderType::Phi4MM => Box::new(Phi4MMLoader),
            VisionLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
            VisionLoaderType::Gemma3 => Box::new(Gemma3Loader),
            VisionLoaderType::Mistral3 => Box::new(Mistral3Loader),
            VisionLoaderType::Llama4 => Box::new(VLlama4Loader),
            VisionLoaderType::Gemma3n => Box::new(Gemma3nLoader),
            VisionLoaderType::Qwen3VL => Box::new(Qwen3VLLoader),
        })
    }
}

impl VisionModelLoader for AutoVisionLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }

    fn is_gptx(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader get_loader")
            .is_gptx(config)
    }

    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }

    fn get_processor(
        &self,
        model_config: &str,
        proc_cfg: Option<ProcessorConfig>,
        preproc_cfg: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Self::get_loader(model_config)
            .expect("AutoVisionLoader get_loader")
            .get_processor(model_config, proc_cfg, preproc_cfg, max_edge)
    }

    fn supports_paged_attention(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_paged_attention(config)
    }

    fn modalities(&self, config: &str) -> Result<Modalities> {
        Self::get_loader(config)?.modalities(config)
    }

    fn supports_prefix_cacher(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_prefix_cacher(config)
    }

    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .prefixer(config)
    }

    fn get_device_for_tensor(
        &self,
        config: &str,
        mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        Self::get_loader(config)?.get_device_for_tensor(config, mapper, loading_isq)
    }
}

impl IsqModelLoader for AutoVisionLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
}

impl DeviceMappedModelLoader for AutoVisionLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

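/// Expands to `$size` if `$cond` is true and `0` otherwise. Used in the size
/// estimates below to count optional bias elements, e.g.
/// `in_dim * out_dim + bias_if!(has_bias, out_dim)` for a linear layer.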
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}

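/// Estimate the number of parameter elements in a CLIP ViT vision tower:
/// class/position/patch embeddings, the pre and final layer norms, and
/// `num_hidden_layers` encoder blocks (attention + MLP, all with biases).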
fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
    let pre_layer_norm = cfg.hidden_size;
    let final_layer_norm = cfg.hidden_size;

    let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
    let num_positions = num_patches + 1;

    let class_embedding = cfg.hidden_size;

    let position_ids = num_positions;
    let position_embedding = num_positions * cfg.hidden_size;

    let conv2dconfig = Conv2dConfig {
        stride: cfg.patch_size,
        ..Default::default()
    };
    let patch_embedding =
        cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;

    let encoder_layer_elems = {
        let layer_norm1 = cfg.hidden_size;
        let layer_norm2 = cfg.hidden_size;

        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

        layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
    };

    pre_layer_norm
        + final_layer_norm
        + class_embedding
        + position_ids
        + position_embedding
        + patch_embedding
        + cfg.num_hidden_layers * encoder_layer_elems
}

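/// [`VisionModelLoader`] for a Phi 3 Vision model.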
pub struct Phi3VLoader;

pub struct Phi3VPrefixer;

impl MultimodalPromptPrefixer for Phi3VPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            image_indexes
                .into_iter()
                .map(|image_index| format!("<|image_{}|>", image_index + 1))
                .join("")
        )
    }
}

impl VisionModelLoader for Phi3VLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(Phi3::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi3Processor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Phi3VPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Phi3VLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi3VLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = {
                let projection_cls = cfg
                    .embd_layer
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator =
                    cfg.embd_layer.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
                let image_dim_out = cfg.img_processor.image_dim_out;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);

                proj + glb_gn + sub_gn + clip_vit
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

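/// [`VisionModelLoader`] for an Idefics 2 model.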
pub struct Idefics2Loader;

pub struct Idefics2Prefixer;

impl MultimodalPromptPrefixer for Idefics2Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(Idefics2::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics2Processor::new(
            processor_config.unwrap(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Idefics2Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Idefics2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Idefics2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
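            // Heuristic: Idefics2-style image splitting yields 4 crops plus the
            // original image, so budget 5 sub-images per input image.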
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let text_elems = {
            let tie_word_embeddings = cfg.tie_word_embeddings;
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let tcfg = &cfg.text_config;
            let vcfg = &cfg.vision_config;
            let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let down_proj = tcfg.intermediate_size * tcfg.hidden_size;

            let perceiver_elems = {
                let tcfg = &cfg.text_config;
                let pcfg = &cfg.perceiver_config;

                let n_latents = pcfg.resampler_n_latents;
                let hidden_size = tcfg.hidden_size;
                let depth = pcfg.resampler_depth;

                let norm = tcfg.hidden_size;
                let latents = n_latents * hidden_size;

                let layer_elems = {
                    let input_latents_norm = hidden_size;
                    let input_context_norm = hidden_size;
                    let post_attn_norm = hidden_size;

                    let num_heads = pcfg.resampler_n_heads;
                    let head_dim = pcfg.resampler_head_dim;
                    let num_key_value_heads = pcfg.num_key_value_heads;

                    let q_proj = hidden_size * num_heads * head_dim;
                    let k_proj = hidden_size * num_key_value_heads * head_dim;
                    let v_proj = hidden_size * num_key_value_heads * head_dim;
                    let o_proj = num_heads * head_dim * hidden_size;

                    let gate_proj = hidden_size * hidden_size * 4;
                    let up_proj = hidden_size * hidden_size * 4;
                    let down_proj = hidden_size * 4 * hidden_size;

                    input_latents_norm
                        + input_context_norm
                        + post_attn_norm
                        + q_proj
                        + k_proj
                        + v_proj
                        + o_proj
                        + gate_proj
                        + up_proj
                        + down_proj
                };

                norm + latents + layer_elems * depth
            };

            gate_proj + up_proj + down_proj + perceiver_elems
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + cfg.num_hidden_layers * layer_elems
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

1101
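/// [`VisionModelLoader`] for a LLaVA Next model.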
pub struct LLaVANextLoader;

pub struct LLaVANextPrefixer;

impl MultimodalPromptPrefixer for LLaVANextPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVANextLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(LLaVANext::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVANextProcessor::new(model_config))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(LLaVANextPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for LLaVANextLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVANextLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

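/// [`VisionModelLoader`] for a LLaVA model.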
pub struct LLaVALoader;

pub struct LLaVAPrefixer;

impl MultimodalPromptPrefixer for LLaVAPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVALoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(LLaVA::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVAProcessor::new(model_config))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(LLaVAPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for LLaVALoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVALoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        let img_seq_len =
            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        let img_seq_len =
            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

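    // Note on `weight_pack_factor` below: packed quantized formats store
    // several logical weights per stored element, so dividing the element
    // count by the pack factor approximates the post-quantization footprint.
    // E.g., under the assumption of a 4-bit scheme that packs 8 weights into
    // one 32-bit word, a 4096 x 4096 projection counts as 4096 * 4096 / 8
    // elements before the `dtype.size_in_bytes()` multiply.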
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

pub struct VLlamaLoader;

pub struct VLlamaPrefixer;

impl MultimodalPromptPrefixer for VLlamaPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<|image|>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for VLlamaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
        Ok(Box::new(MLlamaModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(MLlamaProcessor::new())
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        false
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(VLlamaPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

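/// ISQ target selection for MLlama (Llama 3.2 Vision).
///
/// Unlike the other loaders, this one is config-dependent: cross-attention
/// layers are skipped on the text side (they stay unquantized), so the
/// per-layer regexes are generated with `format!` for the self-attention
/// layers only. Sketch of the expansion for `layer = 0` (illustrative):
///
/// ```ignore
/// let layer = 0;
/// let re = format!(
///     r"language_model.model.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
/// );
/// assert_eq!(
///     re,
///     r"language_model.model.layers\.0\.self_attn\.q_proj\.(weight|bias)$"
/// );
/// ```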
impl IsqModelLoader for VLlamaLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let cross_attn_layers = &config.text_config.cross_attention_layers;
        let transformer_layers =
            (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
        let mut text_regexes = Vec::new();
        for layer in transformer_layers {
            text_regexes.extend(vec![
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
                ))?,
            ]);
        }
        let vision_regexes = vec![
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
            )?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
        ];

        Ok([text_regexes, vision_regexes].concat())
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

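/// Device-mapping cost model for MLlama.
///
/// The image sequence length rounds the per-tile patch count up to a
/// multiple of 8 (matching the model's tile padding) and multiplies by
/// `max_num_tiles`. A worked example, assuming a 560 px image with 14 px
/// patches and 4 tiles:
///
/// ```ignore
/// let num_patches = (560usize / 14).pow(2) + 1;           // 1601 (incl. class token)
/// let pad = (8 - (num_patches as isize % 8)) % 8;         // 7
/// let img_seq_len = 4 * (num_patches as isize + pad) as usize; // 6432
/// ```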
impl DeviceMappedModelLoader for VLlamaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: MLlamaConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &config.vision_config;
            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
        };
        let img_seq_len = img_seq_len * max_num_images;

        let max_cross_text_attn = {
            let cfg = &config.text_config;
            max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        let max_self_text_attn = {
            let cfg = &config.text_config;
            max_batch_size * cfg.num_attention_heads * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)
        };

        Ok(max_self_text_attn.max(max_cross_text_attn))
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: MLlamaConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &config.vision_config;
            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
        };
        let max_vision_attn = {
            let cfg = &config.vision_config;
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &config.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &config.vision_config;

            let conv_cfg = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
                * cfg.patch_size
                * cfg.patch_size;

            let class_embedding = cfg.hidden_size;

            let gated_positional_embedding = {
                let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
                let embedding = num_patches * cfg.hidden_size;
                let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
                    * (cfg.max_num_tiles * num_patches * cfg.hidden_size);

                embedding + tile_embedding
            };

            let pre_tile_positional_embedding =
                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
            let post_tile_positional_embedding =
                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);

            let layernorm_pre = cfg.hidden_size;
            let layernorm_post = cfg.hidden_size;

            let encoder_layer = {
                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
                let q_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let k_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let v_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let o_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;

                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
                    + cfg.intermediate_size;
                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
                    + cfg.hidden_size;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + fc1
                    + fc2
            };

            patch_embedding
                + class_embedding
                + gated_positional_embedding
                + pre_tile_positional_embedding
                + post_tile_positional_embedding
                + layernorm_pre
                + layernorm_post
                + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
        };

        let elems = text_elems + vision_elems;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let cfg = &config.text_config;

        let mut layer_sizes = Vec::new();

        for i in 0..cfg.num_hidden_layers {
            // Cross-attention layers are excluded from ISQ, so their weights
            // are counted unpacked.
            let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
                1
            } else {
                weight_pack_factor
            };

            let per_layer_elems = {
                let input_layernorm = cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size;

                let size_in = cfg.hidden_size;
                let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
                let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
                let q_proj = size_in * size_q / weight_pack_factor;
                let k_proj = size_in * size_kv / weight_pack_factor;
                let v_proj = size_in * size_kv / weight_pack_factor;
                let o_proj = size_q * size_in / weight_pack_factor;

                let h_size = cfg.hidden_size;
                let i_size = cfg.intermediate_size;
                let gate_proj = h_size * i_size / weight_pack_factor;
                let up_proj = h_size * i_size / weight_pack_factor;
                let down_proj = i_size * h_size / weight_pack_factor;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + gate_proj
                    + up_proj
                    + down_proj
            };

            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
        }

        Ok(layer_sizes)
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        Ok(config.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: MLlamaConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

pub struct Qwen2VLLoader;

pub struct Qwen2VLPrefixer;

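/// Prompt prefixing for Qwen2-VL.
///
/// Each image contributes one `VISION_START`/`IMAGE_PAD`/`VISION_END`
/// marker triple ahead of the user prompt; the input processor is expected
/// to expand the pad marker to the real image token count later. Sketch of
/// the produced string for two images (marker spellings abbreviated here
/// for intuition, the real ones are the processor's constants):
///
/// ```ignore
/// let prefixed = Qwen2VLPrefixer.prefix_image(vec![0, 1], "Describe both.");
/// // => "<START><PAD><END><START><PAD><END>Describe both."
/// ```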
impl MultimodalPromptPrefixer for Qwen2VLPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            format!(
                "{}{}{}",
                Qwen2VLProcessor::VISION_START,
                Qwen2VLProcessor::IMAGE_PAD,
                Qwen2VLProcessor::VISION_END
            )
            .repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for Qwen2VLLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen2VLModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Qwen2VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen2VLProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        false
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Qwen2VLPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Qwen2VLLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

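/// Device-mapping cost model for Qwen2-VL.
///
/// Image tokens are counted over a 3D grid: frames collapse by
/// `temporal_patch_size`, and each spatial axis by `patch_size`. A worked
/// example with assumed values (patch 14, temporal patch 2, a 1036 x 1036
/// image, 2 images):
///
/// ```ignore
/// let (grid_t, grid_h, grid_w) = (2usize / 2, 1036 / 14, 1036 / 14); // 1, 74, 74
/// let img_seq_len = grid_t * grid_h * grid_w;                        // 5476
/// ```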
impl DeviceMappedModelLoader for Qwen2VLLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

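    // The patch merger below concatenates a `spatial_merge_size`-squared
    // block of ViT outputs before projecting into the text model, so its
    // input width is `embed_dim * spatial_merge_size^2` (e.g., an assumed
    // embed_dim of 1280 with merge size 2 gives 1280 * 4 = 5120).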
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        let text_elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_channels * cfg.embed_dim / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
            let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);

            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
            let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
            let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
            let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;

            let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
            let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

pub struct Idefics3Loader;

pub struct Idefics3Prefixer;

impl MultimodalPromptPrefixer for Idefics3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        // Deliberately a no-op: the chat template inserts the image markers.
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics3Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
        Ok(Box::new(Idefics3Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics3Processor::new(
            processor_config.unwrap_or_default(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Idefics3Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

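/// ISQ target selection for Idefics 3.
///
/// Both regex sets cover the same text-decoder projections under the
/// `model.text_model` prefix plus `lm_head`; `isq_layer_regexes` leaves the
/// prefix dots unescaped while `immediate_isq_predicates` escapes them, but
/// the tensors they match are identical.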
impl IsqModelLoader for Idefics3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

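/// Device-mapping cost model for Idefics 3.
///
/// The vision budget multiplies batch size by an `images_factor` of 5,
/// presumably headroom for the image splitting (sub-crops plus a global
/// view) that the processor can produce per input image. The per-image
/// token count is the usual ViT formula:
///
/// ```ignore
/// // Assumed values: 364 px input, 14 px patches.
/// let num_patches = (364usize / 14).pow(2); // 676
/// let img_seq_len = num_patches + 1;        // one extra position, as in the code
/// ```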
impl DeviceMappedModelLoader for Idefics3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics3Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);

            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics3Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
            let out_dim = cfg.text_config.hidden_size;

            in_dim * out_dim
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

pub struct MiniCpmOLoader;

pub struct MiniCpmOPrefixer;

impl MultimodalPromptPrefixer for MiniCpmOPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            "(<image>./</image>)".repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for MiniCpmOLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
        Ok(Box::new(MiniCpmOModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(MiniCpmOProcessor::new(
            processor_config.unwrap_or_default(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(MiniCpmOPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

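/// ISQ target selection for MiniCPM-O.
///
/// MiniCPM-O nests its text decoder under an `llm` prefix, so the patterns
/// are the standard q/k/v/o + gate/up/down set with `llm.` prepended; e.g.
/// a tensor named `llm.layers.3.mlp.down_proj.weight` (illustrative) is
/// matched, while vision/audio tower weights are left alone.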
2699impl IsqModelLoader for MiniCpmOLoader {
2700    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2701        Ok(vec![
2702            Regex::new(r"llm.lm_head\.(weight|bias)$")?,
2703            Regex::new(r"llm.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2705            Regex::new(r"llm.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2706            Regex::new(r"llm.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2707            Regex::new(r"llm.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2708            Regex::new(r"llm.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2710            Regex::new(r"llm.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2711            Regex::new(r"llm.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2712        ])
2713    }
2714    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2715        self.isq_layer_regexes(config)
2716    }
2717}
2718
2719impl DeviceMappedModelLoader for MiniCpmOLoader {
2720    fn mapped_max_act_size_elems(
2721        &self,
2722        config: &str,
2723        params: &AutoDeviceMapParams,
2724    ) -> Result<usize> {
2725        let AutoDeviceMapParams::Vision {
2726            max_seq_len,
2727            max_batch_size,
2728            max_image_shape: _,
2729            max_num_images,
2730        } = params
2731        else {
2732            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2733        };
2734
2735        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2736
2737        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2738        let img_seq_len = (num_patches + 1) * max_num_images;
2739
2740        let max_text_attn = {
2741            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2743            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2744        };
2745
2746        Ok(max_text_attn)
2747    }
2748
2749    fn non_mapped_max_act_size_elems(
2750        &self,
2751        config: &str,
2752        params: &AutoDeviceMapParams,
2753    ) -> Result<usize> {
2754        let AutoDeviceMapParams::Vision {
2755            max_seq_len: _,
2756            max_batch_size,
2757            max_image_shape: _,
2758            max_num_images,
2759        } = params
2760        else {
2761            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2762        };
2763
2764        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2765
2766        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2767        let img_seq_len = num_patches + 1;
2768
2769        let max_vision_attn = {
2770            let images_factor = 5;
2772
2773            (max_batch_size * images_factor * max_num_images)
2774                * cfg.vision_config.num_attention_heads
2775                * img_seq_len
2776                * img_seq_len
2777        };
2778
2779        Ok(max_vision_attn)
2780    }
2781
2782    fn non_mapped_size_in_bytes(
2783        &self,
2784        config: &str,
2785        dtype: DType,
2786        weight_pack_factor: usize,
2787        _matformer_config: Option<&MatformerSliceConfig>,
2788    ) -> Result<usize> {
2789        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2790        let text_elems = {
2791            let cfg = &cfg.text_config;
2792
2793            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2794            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2795            let norm = cfg.hidden_size;
2796            embed_tokens + lm_head + norm
2797        };
2798
2799        let vision_transformer = {
2800            let cfg = &cfg.vision_config;
2801
2802            let post_layernorm = cfg.hidden_size;
2803
2804            let conv_config = Conv2dConfig {
2805                stride: cfg.patch_size,
2806                ..Default::default()
2807            };
2808            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2809                * cfg.patch_size
2810                * cfg.patch_size;
2811
2812            let num_patches_per_side = cfg.image_size / cfg.patch_size;
2813            let num_patches = num_patches_per_side.pow(2);
2814            let position_embedding = num_patches * cfg.hidden_size;
2815
2816            let layer_elems = {
2817                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2818                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2819
2820                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2821                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2822
2823                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2824                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2825                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2826                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2827
2828                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2829            };
2830
2831            post_layernorm
2832                + patch_embedding
2833                + position_embedding
2834                + layer_elems * cfg.num_hidden_layers
2835        };
2836
2837        let elems = text_elems + vision_transformer;
2838
2839        Ok(elems * dtype.size_in_bytes())
2840    }
2841
2842    fn layer_sizes_in_bytes(
2843        &self,
2844        config: &str,
2845        dtype: DType,
2846        weight_pack_factor: usize,
2847        _matformer_config: Option<&MatformerSliceConfig>,
2848    ) -> Result<Vec<usize>> {
2849        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2850        let cfg = cfg.text_config;
2851        let per_layer_elems = {
2852            let input_layernorm = cfg.hidden_size;
2853            let post_attention_layernorm = cfg.hidden_size;
2854
2855            let size_in = cfg.hidden_size;
2856            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2857            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2858            let q_proj = size_in * size_q / weight_pack_factor;
2859            let k_proj = size_in * size_kv / weight_pack_factor;
2860            let v_proj = size_in * size_kv / weight_pack_factor;
2861            let o_proj = size_q * size_in / weight_pack_factor;
2862
2863            let h_size = cfg.hidden_size;
2864            let i_size = cfg.intermediate_size;
2865            let gate_proj = h_size * i_size / weight_pack_factor;
2866            let up_proj = h_size * i_size / weight_pack_factor;
2867            let down_proj = i_size * h_size / weight_pack_factor;
2868
2869            input_layernorm
2870                + post_attention_layernorm
2871                + q_proj
2872                + k_proj
2873                + v_proj
2874                + o_proj
2875                + gate_proj
2876                + up_proj
2877                + down_proj
2878        };
2879        Ok(vec![
2880            per_layer_elems * dtype.size_in_bytes();
2881            cfg.num_hidden_layers
2882        ])
2883    }
2884
2885    fn num_layers(&self, config: &str) -> Result<usize> {
2886        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2887        Ok(cfg.text_config.num_hidden_layers)
2888    }
2889    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2890        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2891        let cfg = &cfg.text_config;
2892
2893        let cfg = ModelConfigMetadata {
2894            max_seq_len: cfg.max_position_embeddings,
2895            num_layers: cfg.num_hidden_layers,
2896            hidden_size: cfg.hidden_size,
2897            num_kv_heads: cfg.num_key_value_heads,
2898            num_attn_heads: cfg.num_attention_heads,
2899            sliding_window: None,
2900            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2901            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2902        };
2903
2904        Ok(Box::new(cfg))
2905    }
2906}
2907
2908pub struct Phi4MMLoader;
2914
2915pub struct Phi4MMPrefixer;
2916
2917impl MultimodalPromptPrefixer for Phi4MMPrefixer {
2918    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2919        format!(
2922            "{}{prompt}",
2923            image_indexes
2924                .into_iter()
2925                .map(|image_index| format!("<|image_{}|>", image_index + 1))
2926                .join("")
2927        )
2928    }
2929    fn prefix_audio(&self, audio_indexes: Vec<usize>, prompt: &str) -> String {
2930        format!(
2933            "{}{prompt}",
2934            audio_indexes
2935                .into_iter()
2936                .map(|audio_index| format!("<|audio_{}|>", audio_index + 1))
2937                .join("")
2938        )
2939    }
2940}

impl VisionModelLoader for Phi4MMLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
        Ok(Box::new(Phi4MMModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Phi4MMPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![
                SupportedModality::Text,
                SupportedModality::Vision,
                SupportedModality::Audio,
            ],
            output: vec![SupportedModality::Text],
        })
    }
}

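// ISQ targets Phi 4's fused projections (`qkv_proj`, `gate_up_proj`) along
// with `o_proj`, `down_proj`, and the LM head.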
impl IsqModelLoader for Phi4MMLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi4MMLoader {
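    // Peak activation on the mapped (text) devices: one attention-score tensor
    // of shape [batch, heads, seq, seq], where seq covers the image tokens plus
    // the chunk-capped text length.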
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi4MMConfig = serde_json::from_str(config)?;

        let vcfg = &PHI4_MM_VISION_CFG;

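        // Image token count: the ViT patch grid contributes
        // (image_size / patch_size)^2 tokens per image, plus one extra token
        // (e.g. a class token).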
        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let vcfg = &PHI4_MM_VISION_CFG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

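        // Dynamic-HD preprocessing splits each image into
        // ceil(H/base) x ceil(W/base) base-resolution crops plus one global
        // view, so scale the effective batch size accordingly.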
        let max_batch_size = max_batch_size
            * (max_image_shape
                .0
                .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
                * max_image_shape
                    .1
                    .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
                + 1);

        let max_vision_attn = (max_batch_size * max_num_images)
            * vcfg.num_attention_heads
            * img_seq_len
            * img_seq_len;
        let max_qkv = 3
            * (max_batch_size
                * vcfg.num_attention_heads
                * img_seq_len
                * (vcfg.hidden_size / vcfg.num_attention_heads));

        Ok(max_vision_attn + max_qkv)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
                let projection_cls = img_embed
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
                let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let vision_transformer = {
                    let cfg = &PHI4_MM_VISION_CFG;

                    let post_layernorm = cfg.hidden_size;

                    let conv_config = Conv2dConfig {
                        stride: cfg.patch_size,
                        ..Default::default()
                    };
                    let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                        * cfg.patch_size
                        * cfg.patch_size;

                    let num_patches_per_side = cfg.image_size / cfg.patch_size;
                    let num_patches = num_patches_per_side.pow(2);
                    let position_embedding = num_patches * cfg.hidden_size;

                    let layer_elems = {
                        let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                        let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                        layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
                    };

                    post_layernorm
                        + patch_embedding
                        + position_embedding
                        + layer_elems * cfg.num_hidden_layers
                };

                proj + glb_gn + sub_gn + vision_transformer
            } else {
                0
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

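    // Per-layer weight element counts for the text decoder; Phi 4 fuses Q/K/V
    // into `qkv_proj` and gate/up into `gate_up_proj`, so each is sized as a
    // single matrix.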
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads(),
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
    }
}

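/// Loader for Qwen 2.5 VL multimodal models (text and vision in, text out).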
pub struct Qwen2_5VLLoader;

pub struct Qwen2_5VLPrefixer;

impl MultimodalPromptPrefixer for Qwen2_5VLPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            format!(
                "{}{}{}",
                Qwen2_5VLProcessor::VISION_START,
                Qwen2_5VLProcessor::IMAGE_PAD,
                Qwen2_5VLProcessor::VISION_END
            )
            .repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for Qwen2_5VLLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen2_5VLModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen2_5VLProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        false
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Qwen2_5VLPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Qwen2_5VLLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Qwen2_5VLLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

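        // Vision tokens form a (t, h, w) patch grid: frames are grouped by the
        // temporal patch size, and each frame is split into
        // patch_size x patch_size patches.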
        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        let text_elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;

            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

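/// Loader for Gemma 3 models; the config may be text-only or include a vision
/// tower, and both variants are handled here.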
pub struct Gemma3Loader;

pub struct Gemma3Prefixer;

impl MultimodalPromptPrefixer for Gemma3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Gemma3Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;
        Ok(Box::new(Gemma3Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Gemma3Config = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        let config: Gemma3Config = serde_json::from_str(config).unwrap();
        Arc::new(Gemma3Processor::new(
            processor_config.unwrap_or_default(),
            matches!(config, Gemma3Config::WithVision { .. }),
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Gemma3Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Gemma3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Gemma3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3Config = serde_json::from_str(config)?;

        match cfg {
            Gemma3Config::Text(text_config) => Ok(max_batch_size
                * text_config.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)),
            Gemma3Config::WithVision {
                text_config,
                vision_config,
                ..
            } => {
                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
                let img_seq_len = (num_patches + 1) * max_num_images;

                let max_text_attn = {
                    let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
                    max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
                };
                Ok(max_text_attn)
            }
        }
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3Config = serde_json::from_str(config)?;

        match cfg {
            Gemma3Config::WithVision { vision_config, .. } => {
                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
                let img_seq_len = num_patches + 1;

                let max_vision_attn = {
                    (max_batch_size * max_num_images)
                        * vision_config.num_attention_heads
                        * img_seq_len
                        * img_seq_len
                };

                Ok(max_vision_attn)
            }
            Gemma3Config::Text(_) => Ok(0),
        }
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let text_elems = {
            let cfg = match &cfg {
                Gemma3Config::Text(cfg) => cfg,
                Gemma3Config::WithVision { text_config, .. } => text_config,
            };
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_transformer = if let Gemma3Config::WithVision {
            vision_config: cfg, ..
        } = &cfg
        {
            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        } else {
            0
        };

        let elems = text_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let txt_cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };
        let per_layer_elems = {
            let cfg = txt_cfg;

            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            txt_cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let txt_cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };

        Ok(txt_cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

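/// Loader for Mistral 3 multimodal models.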
pub struct Mistral3Loader;

pub struct Mistral3Prefixer;

impl MultimodalPromptPrefixer for Mistral3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Mistral3Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
        Ok(Box::new(Mistral3Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Mistral3Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Mistral3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
impl DeviceMappedModelLoader for Mistral3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let vcfg = &cfg.vision_config;
        let tcfg = &cfg.text_config;

        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: (mut height, mut width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

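        // Estimate post-resize image tokens: images are assumed to be scaled
        // down to fit within 1540x1540 (preserving aspect ratio), then counted
        // by patch grid, with one extra token per patch row.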
        let img_seq_len = {
            let (max_height, max_width) = (1540, 1540);
            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
            if ratio > 1. {
                height = (height as f64 / ratio).floor() as usize;
                width = (width as f64 / ratio).floor() as usize;
            }

            let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
            let num_width_tokens = (width - 1) / vcfg.patch_size + 1;

            height = num_height_tokens * vcfg.patch_size;
            width = num_width_tokens * vcfg.patch_size;

            let num_height_tokens = height / vcfg.patch_size;
            let num_width_tokens = width / vcfg.patch_size;

            (num_width_tokens + 1) * num_height_tokens
        };

        let max_seq_len = img_seq_len * max_num_images + *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
        Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.vision_config;

        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: (mut height, mut width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let img_seq_len = {
            let (max_height, max_width) = (1540, 1540);
            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
            if ratio > 1. {
                height = (height as f64 / ratio).floor() as usize;
                width = (width as f64 / ratio).floor() as usize;
            }

            let num_height_tokens = (height - 1) / cfg.patch_size + 1;
            let num_width_tokens = (width - 1) / cfg.patch_size + 1;

            height = num_height_tokens * cfg.patch_size;
            width = num_width_tokens * cfg.patch_size;

            let num_height_tokens = height / cfg.patch_size;
            let num_width_tokens = width / cfg.patch_size;

            (num_width_tokens + 1) * num_height_tokens
        };

        Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;

        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &cfg.vision_config;

            let patch_embed = {
                let conv_cfg = Conv2dConfig {
                    stride: cfg.patch_size,
                    ..Default::default()
                };
                cfg.num_channels * cfg.hidden_size / conv_cfg.groups
                    * cfg.patch_size
                    * cfg.patch_size
            };
            let ln_pre = cfg.hidden_size;
            let vision_layer = {
                let attn_norm = cfg.hidden_size;
                let ffn_norm = cfg.hidden_size;

                let gate = cfg.hidden_size * cfg.intermediate_size;
                let up = cfg.hidden_size * cfg.intermediate_size;
                let down = cfg.hidden_size * cfg.intermediate_size;

                let q = cfg.hidden_size * cfg.hidden_size;
                let k = cfg.hidden_size * cfg.hidden_size;
                let v = cfg.hidden_size * cfg.hidden_size;
                let o = cfg.hidden_size * cfg.hidden_size;

                attn_norm + ffn_norm + gate + up + down + q + k + v + o
            };

            patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_elems;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

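/// Loader for Llama 4 multimodal models: a mixture-of-experts text backbone
/// paired with a vision encoder.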
pub struct VLlama4Loader;

pub struct VLlama4Prefixer;

impl MultimodalPromptPrefixer for VLlama4Prefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            llama4::IMAGE_TOKEN.repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for VLlama4Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
        Ok(Box::new(Llama4Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Llama4Processor::new(&processor_config.unwrap()))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(VLlama4Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for VLlama4Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.router\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.down_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.router\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$",
            )?,
        ])
    }
}

impl VLlama4Loader {
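    // The number of image tokens after Llama 4's chunking depends on the
    // processor's resize/split logic, so run the actual image processor on
    // dummy images and measure the result.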
4346    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
4349    fn run_dummy_processing(
4350        &self,
4351        cfg: &Llama4Config,
4352        height: usize,
4353        width: usize,
4354        max_num_images: usize,
4355        max_batch_size: usize,
4356    ) -> Result<(usize, usize)> {
4357        let cfg = &cfg.vision_config;
4358
4359        let img_processor =
4360            Llama4ImageProcessor::new(Some(cfg.patch_size), Some(cfg.pixel_shuffle_ratio));
4361        let image = DynamicImage::new(width as u32, height as u32, ColorType::Rgb8);
4362        let res = img_processor.preprocess(
4363            vec![image; max_num_images],
4364            vec![],
4365            &PreProcessorConfig::default(),
4366            &Device::Cpu,
4367            (max_batch_size, max_num_images),
4368        )?;
4369
4370        let pixels_batch_size = res.pixel_values.dim(0)?;
4371        let pixels_max_batch_size = pixels_batch_size * max_batch_size;
4372
4373        let (image_h, image_w) = (
4374            res.pixel_values.dim(D::Minus2).unwrap(),
4375            res.pixel_values.dim(D::Minus1).unwrap(),
4376        );
4377        let num_patches_per_chunk = (image_h / img_processor.patch_size)
4378            * (image_w / img_processor.patch_size)
4379            / img_processor.downsample_ratio;
4380
4381        Ok((
4382            pixels_max_batch_size,
4383            num_patches_per_chunk * pixels_max_batch_size,
4384        ))
4385    }
4386}

impl DeviceMappedModelLoader for VLlama4Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: (height, width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Llama4Config = serde_json::from_str(config)?;

        let (_pixels_batch_size, num_text_image_toks) =
            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;

        let max_seq_len = max_seq_len.min(&ATTENTION_CHUNK_SIZE) + num_text_image_toks;

        Ok(max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len)
    }
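    // The mapped activation bound tracks the attention score matrix, which
    // grows as batch * heads * seq^2. With made-up numbers: batch 1, 32 heads,
    // seq_len 4096 gives 32 * 4096 * 4096 = 536,870,912 elements, about 1 GiB
    // at 2 bytes per element.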
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: (height, width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Llama4Config = serde_json::from_str(config)?;

        let (pixels_batch_size, _num_text_image_toks) =
            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
        let max_seq_len = cfg.vision_config.num_patches();

        Ok((max_batch_size * pixels_batch_size)
            * cfg.vision_config.num_attention_heads
            * max_seq_len
            * max_seq_len)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let tcfg = &cfg.text_config;

        let text_elems = {
            let embed_tokens = tcfg.hidden_size * tcfg.vocab_size / weight_pack_factor;
            let lm_head = if !tcfg.tie_word_embeddings {
                tcfg.hidden_size * tcfg.vocab_size
            } else {
                0
            };
            let norm = tcfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &cfg.vision_config;

            let num_patches = cfg.num_patches();

            let unfold_elems =
                (cfg.num_channels * cfg.patch_size * cfg.patch_size) * cfg.hidden_size;
            let class_embedding_elems = cfg.hidden_size;
            let positional_embedding_vlm_elems = num_patches * cfg.hidden_size;
            let layernorm_pre_elems = cfg.hidden_size;
            let layernorm_post_elems = cfg.hidden_size;

            let pixel_shuffle_elems = cfg.intermediate_size * cfg.projector_input_dim
                / weight_pack_factor
                + cfg.projector_input_dim * cfg.projector_output_dim / weight_pack_factor;

            let encoder_layer = {
                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
                let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let k_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let v_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let o_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;

                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
                    + cfg.intermediate_size;
                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
                    + cfg.hidden_size;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + fc1
                    + fc2
            };

            unfold_elems
                + class_embedding_elems
                + positional_embedding_vlm_elems
                + layernorm_post_elems
                + layernorm_pre_elems
                + pixel_shuffle_elems
                + encoder_layer * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_elems;

        Ok(elems * dtype.size_in_bytes())
    }
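    // Note on `weight_pack_factor`: it models packed quantized storage by
    // dividing the element count of quantizable projections. Illustrative
    // only: with a pack factor of 8 (e.g. 4-bit weights packed into 32-bit
    // words), a 4096x4096 projection counts as 4096 * 4096 / 8 = 2,097,152
    // dtype-sized slots instead of 16,777,216.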

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let tcfg = &cfg.text_config;

        let mut per_layer_elems = Vec::new();

        for layer_idx in 0..tcfg.num_hidden_layers {
            let input_layernorm = tcfg.hidden_size;
            let post_attention_layernorm = tcfg.hidden_size;

            let size_in = tcfg.hidden_size;
            let size_q = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_attention_heads;
            let size_kv = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let use_moe = tcfg.moe_layers().contains(&layer_idx);
            let moe_block = if use_moe {
                let h_size = tcfg.hidden_size;
                let i_size = tcfg.intermediate_size;
                let gate_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
                let up_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
                let down_proj = tcfg.num_local_experts * i_size * h_size / weight_pack_factor;

                gate_proj + up_proj + down_proj
            } else {
                let h_size = tcfg.hidden_size;
                let i_size = tcfg.intermediate_size_mlp;
                let gate_proj = h_size * i_size / weight_pack_factor;
                let up_proj = h_size * i_size / weight_pack_factor;
                let down_proj = i_size * h_size / weight_pack_factor;

                gate_proj + up_proj + down_proj
            };

            per_layer_elems.push(
                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + moe_block,
            );
        }

        Ok(per_layer_elems
            .into_iter()
            .map(|x| x * dtype.size_in_bytes())
            .collect())
    }
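    // MoE layers dominate the per-layer estimate above: every expert carries
    // its own gate/up/down projections, so the MLP weight count scales
    // linearly with `num_local_experts`. With made-up sizes (16 experts,
    // h_size = 5120, i_size = 8192) that is 3 * 16 * 5120 * 8192, roughly
    // 2.0e9 elements per MoE layer before packing.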

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

pub struct Gemma3nLoader;

#[allow(dead_code)]
pub struct Gemma3nPrefixer;

impl MultimodalPromptPrefixer for Gemma3nPrefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Gemma3nLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
        Ok(Box::new(Gemma3nModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Gemma3nConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Gemma3nProcessor::new(
            processor_config.unwrap_or_default(),
            true,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        false
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Gemma3Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![
                SupportedModality::Text,
                SupportedModality::Vision,
                SupportedModality::Audio,
            ],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Gemma3nLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
            Regex::new(r"conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$")?,
            Regex::new(
                r"conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
            )?,
            Regex::new(r"conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
            Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$")?,
            Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$")?,
            Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$")?,
            Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$")?,
            Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$")?,
            Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$")?,
            Regex::new(r"subsample_conv_projection\.input_proj_linear\.(weight|bias)$")?,
            Regex::new(r"embed_vision\.embedding_projection\.(weight|bias)$")?,
            Regex::new(r"embed_audio\.embedding_projection\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.per_layer_model_projection\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.altup_projections\.(\d+)\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.altup_unembed_projections\.(\d+)\.(weight|bias)$")?,
            Regex::new(
                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
            )?,
            Regex::new(r"model\.audio_tower\.conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
            Regex::new(
                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.audio_tower\.subsample_conv_projection\.input_proj_linear\.(weight|bias)$",
            )?,
            Regex::new(r"model\.embed_vision\.embedding_projection\.(weight|bias)$")?,
            Regex::new(r"model\.embed_audio\.embedding_projection\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Gemma3nLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
        let text_cfg = &cfg.text_config;

        let mut total_seq_len = *max_seq_len.min(&ATTENTION_CHUNK_SIZE);

        {
            // Vision contribution: each image is pooled by the MSFA adapter to
            // a fixed 16x16 grid, i.e. 256 soft tokens per image.
            let msfa_spatial_size = 16;
            let vision_tokens_per_image = msfa_spatial_size * msfa_spatial_size;
            total_seq_len += vision_tokens_per_image * max_num_images;
        }

        {
            // Audio contribution: a fixed number of soft tokens per clip.
            let audio_tokens = cfg.audio_soft_tokens_per_image;
            total_seq_len += audio_tokens;
        }

        let max_text_attn =
            max_batch_size * text_cfg.num_attention_heads * total_seq_len * total_seq_len;

        Ok(max_text_attn)
    }
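    // Illustrative tally (hypothetical numbers): a 2048-token chunked text
    // window, one image (256 vision tokens), and 188 audio tokens give
    // total_seq_len = 2048 + 256 + 188 = 2492, so the per-head score matrix
    // scales with 2492^2.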

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3nConfig = serde_json::from_str(config)?;

        let mut max_activation = 0;

        {
            let vision_tower_act = {
                // Worst case inside the MobileNet-style tower: 16 attention
                // heads over a 24x24 = 576 position spatial grid.
                let num_heads = 16;
                let spatial_size = 24;
                let seq_len = spatial_size * spatial_size;

                max_batch_size * max_num_images * num_heads * seq_len * seq_len
            };

            let vision_embed_act = {
                // MSFA output: 2048 channels on a 16x16 grid, before and after
                // projection into the text hidden size.
                let msfa_channels = 2048;
                let spatial_size = 16;
                let vision_features =
                    max_batch_size * max_num_images * msfa_channels * spatial_size * spatial_size;

                let projected = max_batch_size
                    * max_num_images
                    * spatial_size
                    * spatial_size
                    * cfg.text_config.hidden_size;

                vision_features.max(projected)
            };

            max_activation = max_activation.max(vision_tower_act).max(vision_embed_act);
        }

        {
            let audio_cfg = &cfg.audio_config;

            // Assume a generous upper bound on input mel frames per clip.
            let max_audio_frames = 1280;

            let subsample_factor: usize = audio_cfg
                .sscp_conv_stride_size
                .iter()
                .map(|stride| stride[0])
                .product();
            let audio_seq_after_subsample = max_audio_frames / subsample_factor;

            let audio_encoder_act = {
                // The conformer FFW blocks expand to 4x the hidden size.
                let intermediate_size = audio_cfg.hidden_size * 4;
                max_batch_size * audio_seq_after_subsample * intermediate_size
            };

            let audio_attn_act = {
                let chunk_size = audio_cfg.conf_attention_chunk_size;
                let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
                    + audio_cfg.conf_attention_context_right;

                let num_chunks = audio_seq_after_subsample.div_ceil(chunk_size);

                max_batch_size
                    * audio_cfg.conf_num_attention_heads
                    * num_chunks
                    * chunk_size
                    * context_size
            };

            max_activation = max_activation.max(audio_encoder_act).max(audio_attn_act);
        }

        Ok(max_activation)
    }
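    // The audio bound models chunked local attention. With made-up config
    // numbers: chunk_size = 12, context_left = 13, context_right = 0 give a
    // context of 12 + 13 - 1 + 0 = 24 positions per chunk, and 1280 frames
    // subsampled 4x leave 320 positions, i.e. ceil(320 / 12) = 27 chunks.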

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Gemma3nConfig = serde_json::from_str(config)?;

        let text_cfg = if let Some(matformer_cfg) = matformer_config {
            use crate::device_map::DummyDeviceMapper;
            use crate::vision_models::gemma3n::text::handle_matformer_slicing;

            let dummy_mapper = DummyDeviceMapper {
                nm_device: Device::Cpu,
            };
            let (adjusted_cfg, _, _, _, _) = handle_matformer_slicing(
                &cfg.text_config,
                &Some(matformer_cfg.clone()),
                &dummy_mapper,
            )?;
            adjusted_cfg
        } else {
            cfg.text_config.clone()
        };

        let text_cfg = &text_cfg;

        let text_elems = {
            let embed_tokens = text_cfg.hidden_size * text_cfg.vocab_size;
            let embed_tokens_per_layer = text_cfg.num_hidden_layers
                * text_cfg.hidden_size_per_layer_input
                * text_cfg.vocab_size_per_layer_input;

            let lm_head = if !text_cfg.tie_word_embeddings || weight_pack_factor != 1 {
                text_cfg.hidden_size * text_cfg.vocab_size / weight_pack_factor
            } else {
                0
            };

            let norm = text_cfg.hidden_size;

            let altup_projections =
                (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
                    / weight_pack_factor;
            let altup_unembed_projections =
                (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
                    / weight_pack_factor;

            let per_layer_model_projection = text_cfg.num_hidden_layers
                * text_cfg.hidden_size
                * text_cfg.hidden_size_per_layer_input
                / weight_pack_factor;
            let per_layer_projection_norm = text_cfg.hidden_size;

            embed_tokens
                + embed_tokens_per_layer
                + lm_head
                + norm
                + altup_projections
                + altup_unembed_projections
                + per_layer_model_projection
                + per_layer_projection_norm
        };

        let vision_elems = {
            let vision_cfg = &cfg.vision_config;
            let vision_tower_elems = {
                use crate::vision_models::gemma3n::vision::{
                    gemma3n_mobilenet_def, make_divisible, BlockType, INPUT_CHANNELS,
                    MSFA_EXPANSION_RATIO, MSFA_IN_CHANNELS, MSFA_OUT_CHANNELS, STEM_KERNEL_SIZE,
                    STEM_OUT_CHANNELS,
                };

                let stem_conv =
                    INPUT_CHANNELS * STEM_OUT_CHANNELS * STEM_KERNEL_SIZE * STEM_KERNEL_SIZE;
                let stem_norm = STEM_OUT_CHANNELS;

                let mut in_chs = STEM_OUT_CHANNELS;
                let mut total_elems = stem_conv + stem_norm;

                let block_defs = gemma3n_mobilenet_def();

                for stage_blocks in block_defs.iter() {
                    for block_type in stage_blocks.iter() {
                        match block_type {
                            BlockType::EdgeResidual {
                                out_channels,
                                kernel_size,
                                stride: _,
                                expand_ratio,
                                ..
                            } => {
                                #[allow(clippy::cast_precision_loss)]
                                let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
                                // Expansion conv + per-channel norm, then the
                                // pointwise projection + per-channel norm.
                                total_elems += in_chs * mid_chs * kernel_size * kernel_size;
                                total_elems += mid_chs;
                                total_elems += mid_chs * out_channels;
                                total_elems += out_channels;
                                in_chs = *out_channels;
                            }
                            BlockType::UniversalInvertedResidual {
                                out_channels,
                                start_kernel_size,
                                mid_kernel_size,
                                stride: _,
                                expand_ratio,
                                ..
                            } => {
                                #[allow(clippy::cast_precision_loss)]
                                let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
                                if *expand_ratio != 1.0 {
                                    // Pointwise expansion conv + norm.
                                    total_elems += in_chs * mid_chs;
                                    total_elems += mid_chs;
                                }
                                if *start_kernel_size > 0 {
                                    // Leading depthwise conv + norm.
                                    total_elems += mid_chs * start_kernel_size * start_kernel_size;
                                    total_elems += mid_chs;
                                }
                                if *mid_kernel_size > 0 {
                                    // Middle depthwise conv + norm.
                                    total_elems += mid_chs * mid_kernel_size * mid_kernel_size;
                                    total_elems += mid_chs;
                                }
                                // Pointwise projection conv plus two
                                // per-channel parameter vectors.
                                total_elems += mid_chs * out_channels;
                                total_elems += out_channels;
                                total_elems += out_channels;
                                in_chs = *out_channels;
                            }
                            BlockType::MultiQueryAttention {
                                num_heads,
                                kv_dim,
                                kv_stride: _,
                                ..
                            } => {
                                let dw_kernel_size = 3;
                                total_elems += in_chs;
                                total_elems += in_chs * num_heads * kv_dim;
                                total_elems += in_chs * kv_dim;
                                total_elems += in_chs * dw_kernel_size * dw_kernel_size;
                                total_elems += *kv_dim;
                                total_elems += 1;
                                total_elems += *kv_dim;
                                total_elems += num_heads * kv_dim * in_chs;
                                total_elems += in_chs;
                            }
                        }
                    }
                }

                let msfa_in = MSFA_IN_CHANNELS.iter().sum::<usize>();
                let msfa_out = MSFA_OUT_CHANNELS;
                #[allow(clippy::cast_precision_loss)]
                let msfa_mid = make_divisible(msfa_in as f64 * MSFA_EXPANSION_RATIO, 8);

                total_elems += msfa_in * msfa_mid;
                total_elems += msfa_mid;
                total_elems += msfa_mid * msfa_out;
                total_elems += msfa_out;
                total_elems += msfa_out;

                total_elems
            };

            let embed_vision_elems = {
                let embedding = vision_cfg.vocab_size * vision_cfg.hidden_size;

                let hard_norm = vision_cfg.hidden_size;
                let soft_norm = vision_cfg.hidden_size;

                let projection = vision_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;

                let post_norm = text_cfg.hidden_size;

                embedding + hard_norm + soft_norm + projection + post_norm
            };

            vision_tower_elems + embed_vision_elems
        };

        let audio_elems = {
            let audio_cfg = &cfg.audio_config;

            let subsample_conv_projection_elems = {
                let mut conv_elems = 0;

                let in_ch_0 = 1;
                let out_ch_0 = audio_cfg.sscp_conv_channel_size[0];
                let kernel_0 = &audio_cfg.sscp_conv_kernel_size[0];
                conv_elems += in_ch_0 * out_ch_0 * kernel_0[0] * kernel_0[1];

                let in_ch_1 = out_ch_0;
                let out_ch_1 = audio_cfg.sscp_conv_channel_size[1];
                let kernel_1 = &audio_cfg.sscp_conv_kernel_size[1];
                conv_elems += in_ch_1 * out_ch_1 * kernel_1[0] * kernel_1[1];

                let norm_0 = out_ch_0;
                let norm_1 = out_ch_1;

                // Track the frequency dimension through both strided convs to
                // size the input projection.
                let mut f_out = audio_cfg.input_feat_size;
                for i in 0..2 {
                    let kernel_w = audio_cfg.sscp_conv_kernel_size[i][1];
                    let stride_w = audio_cfg.sscp_conv_stride_size[i][1];
                    let pad_left = 1;
                    let pad_right = 1;
                    f_out = (f_out + pad_left + pad_right + stride_w - kernel_w) / stride_w;
                }
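                // This is the usual strided-conv output size,
                // floor((f_in + pad - kernel) / stride) + 1, folded into one
                // integer division. Illustrative only: f_in = 128, kernel_w = 3,
                // stride_w = 2, padding 1+1 gives (128 + 1 + 1 + 2 - 3) / 2 = 64,
                // and the second conv halves it again to 32.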
                let input_proj_in_features = out_ch_1 * f_out;
                let input_proj_linear =
                    input_proj_in_features * audio_cfg.hidden_size / weight_pack_factor;

                conv_elems + norm_0 + norm_1 + input_proj_linear
            };

            let conformer_elems = {
                let mut total = 0;

                for _ in 0..audio_cfg.conf_num_hidden_layers {
                    let attention_elems = {
                        let pre_attn_norm = audio_cfg.hidden_size;
                        let post_norm = audio_cfg.hidden_size;

                        let q_proj =
                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
                        let k_proj =
                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
                        let v_proj =
                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
                        let post =
                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;

                        let pos_proj =
                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
                        let per_dim_scale =
                            audio_cfg.hidden_size / audio_cfg.conf_num_attention_heads;
                        let inv_timescales = audio_cfg.hidden_size / 2;
                        let pos_indices = audio_cfg.conf_attention_context_left
                            + audio_cfg.conf_attention_context_right
                            + 1;

                        let chunk_size = audio_cfg.conf_attention_chunk_size;
                        let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
                            + audio_cfg.conf_attention_context_right;
                        let local_causal_valid_mask = chunk_size * context_size;
                        let invalid_logits_tensor = 1;

                        pre_attn_norm
                            + post_norm
                            + q_proj
                            + k_proj
                            + v_proj
                            + post
                            + pos_proj
                            + per_dim_scale
                            + inv_timescales
                            + pos_indices
                            + local_causal_valid_mask
                            + invalid_logits_tensor
                    };

                    let ffw_elems = {
                        // Both FFW blocks use a 4x expansion.
                        let intermediate_size = audio_cfg.hidden_size * 4;

                        let ffw_start = {
                            let pre_norm = audio_cfg.hidden_size;
                            let layer_1 =
                                audio_cfg.hidden_size * intermediate_size / weight_pack_factor;
                            let layer_2 =
                                intermediate_size * audio_cfg.hidden_size / weight_pack_factor;
                            let post_norm = audio_cfg.hidden_size;
                            pre_norm + layer_1 + layer_2 + post_norm
                        };

                        let ffw_end = ffw_start;

                        ffw_start + ffw_end
                    };

                    let lconv1d_elems = {
                        let pre_layer_norm = audio_cfg.hidden_size;
                        let conv_norm = audio_cfg.hidden_size;

                        let linear_start = audio_cfg.hidden_size * (audio_cfg.hidden_size * 2)
                            / weight_pack_factor;
                        let linear_end =
                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;

                        let depthwise = audio_cfg.hidden_size * audio_cfg.conf_conv_kernel_size;

                        pre_layer_norm + conv_norm + linear_start + linear_end + depthwise
                    };

                    let block_norm = audio_cfg.hidden_size;

                    total += attention_elems + ffw_elems + lconv1d_elems + block_norm;
                }

                total
            };

            let embed_audio_elems = {
                let embedding = audio_cfg.vocab_size * audio_cfg.hidden_size;

                let hard_embedding_norm = audio_cfg.hidden_size;
                let soft_embedding_norm = audio_cfg.hidden_size;
                let embedding_post_projection_norm = text_cfg.hidden_size;
                let embedding_projection =
                    audio_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;

                embedding
                    + hard_embedding_norm
                    + soft_embedding_norm
                    + embedding_post_projection_norm
                    + embedding_projection
            };

            subsample_conv_projection_elems + conformer_elems + embed_audio_elems
        };

        // Count vision weights at F32 when the target dtype is F16.
        let vision_dtype = if dtype == DType::F16 {
            DType::F32
        } else {
            dtype
        };

        let total_bytes = text_elems * dtype.size_in_bytes()
            + vision_elems * vision_dtype.size_in_bytes()
            + audio_elems * dtype.size_in_bytes();

        Ok(total_bytes)
    }
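    // Mixed-dtype example with made-up element counts: at F16 (2 bytes) with
    // the vision tower counted at F32 (4 bytes), 1.0e9 text + 0.3e9 vision +
    // 0.6e9 audio elements come to 2.0 + 1.2 + 1.2 = 4.4 GB.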

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Gemma3nConfig = serde_json::from_str(config)?;

        let (text_cfg, _layer_rename_map, _layers_skipped) = if let Some(matformer_cfg) =
            matformer_config
        {
            use crate::device_map::DummyDeviceMapper;
            use crate::vision_models::gemma3n::text::handle_matformer_slicing;

            let dummy_mapper = DummyDeviceMapper {
                nm_device: Device::Cpu,
            };
            let (adjusted_cfg, _, _, layer_rename_map, layers_skipped) = handle_matformer_slicing(
                &cfg.text_config,
                &Some(matformer_cfg.clone()),
                &dummy_mapper,
            )?;
            (adjusted_cfg, layer_rename_map, layers_skipped)
        } else {
            (cfg.text_config.clone(), None, None)
        };

        let text_cfg = &text_cfg;

        let mut layer_sizes = Vec::new();

        for layer_idx in 0..text_cfg.num_hidden_layers {
            let per_layer_elems = {
                let input_layernorm = text_cfg.hidden_size;
                let post_attention_layernorm = text_cfg.hidden_size;
                let pre_feedforward_layernorm = text_cfg.hidden_size;
                let post_feedforward_layernorm = text_cfg.hidden_size;
                let post_per_layer_input_norm = text_cfg.hidden_size;

                let size_in = text_cfg.hidden_size;
                let size_q = text_cfg.num_attention_heads * text_cfg.head_dim;
                let size_kv = text_cfg.num_key_value_heads * text_cfg.head_dim;

                let q_proj = size_in * size_q / weight_pack_factor;
                let k_proj = size_in * size_kv / weight_pack_factor;
                let v_proj = size_in * size_kv / weight_pack_factor;
                let o_proj = size_q * size_in / weight_pack_factor;

                let q_norm = text_cfg.head_dim;
                let k_norm = text_cfg.head_dim;
                let v_norm = text_cfg.head_dim;

                let intermediate_size = match &text_cfg.intermediate_size {
                    IntermediateSize::Single(size) => *size,
                    IntermediateSize::PerLayer(sizes) => sizes[layer_idx],
                    IntermediateSize::Matformer(sizes, _) => sizes[layer_idx],
                };
                let gate_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
                let up_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
                let down_proj = intermediate_size * text_cfg.hidden_size / weight_pack_factor;

                let altup_elems = {
                    let correct_output_scale = text_cfg.hidden_size;
                    let correction_coefs = text_cfg.altup_num_inputs * text_cfg.altup_num_inputs;
                    let prediction_coefs =
                        text_cfg.altup_num_inputs * text_cfg.altup_num_inputs.pow(2);
                    let modality_router = text_cfg.hidden_size * text_cfg.altup_num_inputs;
                    let router_norm = text_cfg.hidden_size;

                    correct_output_scale
                        + correction_coefs
                        + prediction_coefs
                        + modality_router
                        + router_norm
                };

                let laurel_elems = {
                    let left = text_cfg.hidden_size * text_cfg.laurel_rank;
                    let right = text_cfg.laurel_rank * text_cfg.hidden_size;
                    let post_norm = text_cfg.hidden_size;

                    left + right + post_norm
                };

                let per_layer_input_gate =
                    text_cfg.hidden_size * text_cfg.hidden_size_per_layer_input;
                let per_layer_projection =
                    text_cfg.hidden_size_per_layer_input * text_cfg.hidden_size;

                input_layernorm
                    + post_attention_layernorm
                    + pre_feedforward_layernorm
                    + post_feedforward_layernorm
                    + post_per_layer_input_norm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + q_norm
                    + k_norm
                    + v_norm
                    + gate_proj
                    + up_proj
                    + down_proj
                    + altup_elems
                    + laurel_elems
                    + per_layer_input_gate
                    + per_layer_projection
            };

            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
        }

        Ok(layer_sizes)
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
        let cfg = cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
    }
}

pub struct Qwen3VLLoader;

pub struct Qwen3VLPrefixer;

impl MultimodalPromptPrefixer for Qwen3VLPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            format!(
                "{}{}{}",
                Qwen3VLProcessor::VISION_START,
                Qwen3VLProcessor::IMAGE_PAD,
                Qwen3VLProcessor::VISION_END
            )
            .repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for Qwen3VLLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen3VLModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Qwen3VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen3VLProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Qwen3VLPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Qwen3VLLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Qwen3VLLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &cfg.text_config;
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }
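    // Grid math with purely illustrative numbers: patch_size = 14 on a 448x448
    // max image gives a 32x32 spatial grid; with temporal_patch_size = 2 and
    // 2 images, grid_t = 1, so the base img_seq_len is 1 * 32 * 32 = 1024
    // before the per-image multiplier above.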

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
        let tie = cfg.tie_word_embeddings;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !tie || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;

            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }
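    // The patch embed is a Conv3d over (temporal_patch_size, patch_size,
    // patch_size) blocks. Illustrative weight count with made-up dims:
    // in_chans = 3, hidden_size = 1152, temporal_patch_size = 2, patch_size = 14
    // gives 3 * 1152 * 2 * 14 * 14 = 1,354,752 elements (groups = 1, no bias).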

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim,
            v_head_dim: cfg.head_dim,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}