// mistralrs_core/pipeline/loaders/vision_loaders.rs

use std::any::Any;
use std::sync::Arc;
use std::{fmt::Debug, str::FromStr};

use anyhow::Result;
use candle_core::{DType, Device, Tensor, D};
use candle_nn::Conv2dConfig;
use image::{ColorType, DynamicImage};
use itertools::Itertools;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::ShardedVarBuilder;

#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};

use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
use crate::amoe::AnyMoeBaseModelMixin;
use crate::device_map::DeviceMapper;
use crate::layers::Conv3dConfig;
use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
use crate::pipeline::isq::IsqModelLoader;
use crate::pipeline::loaders::AutoDeviceMapParams;
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::{EitherCache, IsqModel, Processor, ProcessorCreator, VisionPromptPrefixer};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::vision_models::clip::ClipConfig;
use crate::vision_models::gemma3::config::Gemma3Config;
use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
use crate::vision_models::idefics2_input_processor::Idefics2Processor;
use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
use crate::vision_models::image_processor::ImagePreProcessor;
use crate::vision_models::inputs_processor::Phi4MMProcessor;
use crate::vision_models::llama4::{
    self, Llama4Config, Llama4ImageProcessor, Llama4Model, Llama4Processor,
};
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::llava15::Model as LLaVA;
use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
use crate::vision_models::llava_next::Model as LLaVANext;
use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
use crate::vision_models::phi3_inputs_processor::Phi3Processor;
use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::qwen2_5_vl::{
    Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
};
use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
use crate::vision_models::{minicpmo, phi4};

pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin {
    // pixel_values and pixel_attention_mask only specified for prompt seqs
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>, // pixel attention mask, or image sizes, or anything else
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
    /// For a prompt without images. Requires a batch size of 1!
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
}
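
// Illustrative sketch (not part of the crate API): callers thread
// `model_specific_args` through `forward` as a type-erased `Box<dyn Any>`,
// and each model downcasts it back to its own argument struct. The type
// below is hypothetical; real models define their own (e.g. image sizes or
// a pixel attention mask).
//
//     struct MyModelArgs {
//         image_sizes: Option<Vec<(u32, u32)>>,
//     }
//
//     fn unpack(args: Box<dyn std::any::Any>) -> Option<MyModelArgs> {
//         args.downcast::<MyModelArgs>().ok().map(|boxed| *boxed)
//     }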

pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> bool;
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_processor(
        &self,
        model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync>;
    fn supports_paged_attention(&self, config: &str) -> bool;
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        // Defaults to false; a specific loader must override this.
        false
    }
    fn prefixer(&self, config: &str) -> Arc<dyn VisionPromptPrefixer>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}
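
// Hedged example of the default `get_device_for_tensor` closure above: the
// regex `\.layers\.(\d+)\.` pulls the decoder-layer index out of a tensor
// name, so the device mapper can place each per-layer tensor on that layer's
// assigned device. The tensor names below are illustrative.
//
//     // "model.layers.12.self_attn.q_proj.weight" -> DeviceForLoadTensor::Idx(12)
//     // "model.embed_tokens.weight"               -> DeviceForLoadTensor::Base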

#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
/// The architecture to load the vision model as.
pub enum VisionLoaderType {
    #[serde(rename = "phi3v")]
    Phi3V,
    #[serde(rename = "idefics2")]
    Idefics2,
    #[serde(rename = "llava_next")]
    LLaVANext,
    #[serde(rename = "llava")]
    LLaVA,
    #[serde(rename = "vllama")]
    VLlama,
    #[serde(rename = "qwen2vl")]
    Qwen2VL,
    #[serde(rename = "idefics3")]
    Idefics3,
    #[serde(rename = "minicpmo")]
    MiniCpmO,
    #[serde(rename = "phi4mm")]
    Phi4MM,
    #[serde(rename = "qwen2_5vl")]
    Qwen2_5VL,
    #[serde(rename = "gemma3")]
    Gemma3,
    #[serde(rename = "mistral3")]
    Mistral3,
    #[serde(rename = "llama4")]
    Llama4,
}

// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
impl VisionLoaderType {
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "Phi3VForCausalLM" => Ok(Self::Phi3V),
            "Idefics2ForConditionalGeneration" => Ok(Self::Idefics2),
            "LlavaNextForConditionalGeneration" => Ok(Self::LLaVANext),
            "LlavaForConditionalGeneration" => Ok(Self::LLaVA),
            "MllamaForConditionalGeneration" => Ok(Self::VLlama),
            "Qwen2VLForConditionalGeneration" => Ok(Self::Qwen2VL),
            "Idefics3ForConditionalGeneration" => Ok(Self::Idefics3),
            "MiniCPMO" => Ok(Self::MiniCpmO),
            "Phi4MMForCausalLM" => Ok(Self::Phi4MM),
            "Qwen2_5_VLForConditionalGeneration" => Ok(Self::Qwen2_5VL),
            "Gemma3ForConditionalGeneration" | "Gemma3ForCausalLM" => Ok(Self::Gemma3),
            "Mistral3ForConditionalGeneration" => Ok(Self::Mistral3),
            "Llama4ForConditionalGeneration" => Ok(Self::Llama4),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for VisionLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "phi3v" => Ok(Self::Phi3V),
            "idefics2" => Ok(Self::Idefics2),
            "llava_next" => Ok(Self::LLaVANext),
            "llava" => Ok(Self::LLaVA),
            "vllama" => Ok(Self::VLlama),
            "qwen2vl" => Ok(Self::Qwen2VL),
            "idefics3" => Ok(Self::Idefics3),
            "minicpmo" => Ok(Self::MiniCpmO),
            "phi4mm" => Ok(Self::Phi4MM),
            "qwen2_5vl" => Ok(Self::Qwen2_5VL),
            "gemma3" => Ok(Self::Gemma3),
            "mistral3" => Ok(Self::Mistral3),
            "llama4" => Ok(Self::Llama4),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`, `llama4`.")),
        }
    }
}

impl std::fmt::Display for VisionLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            VisionLoaderType::Phi3V => "phi3v",
            VisionLoaderType::Idefics2 => "idefics2",
            VisionLoaderType::LLaVANext => "llava_next",
            VisionLoaderType::LLaVA => "llava",
            VisionLoaderType::VLlama => "vllama",
            VisionLoaderType::Qwen2VL => "qwen2vl",
            VisionLoaderType::Idefics3 => "idefics3",
            VisionLoaderType::MiniCpmO => "minicpmo",
            VisionLoaderType::Phi4MM => "phi4mm",
            VisionLoaderType::Qwen2_5VL => "qwen2_5vl",
            VisionLoaderType::Gemma3 => "gemma3",
            VisionLoaderType::Mistral3 => "mistral3",
            VisionLoaderType::Llama4 => "llama4",
        };
        write!(f, "{name}")
    }
}
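
// Example (derived from the two impls above): `FromStr` and `Display` are
// inverses for every variant, so loader types round-trip through strings.
//
//     let tp: VisionLoaderType = "qwen2_5vl".parse().unwrap();
//     assert_eq!(tp, VisionLoaderType::Qwen2_5VL);
//     assert_eq!(tp.to_string(), "qwen2_5vl");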

#[derive(Deserialize)]
struct AutoVisionLoaderConfig {
    architectures: Vec<String>,
}

/// Automatically selects a `VisionModelLoader` implementation based on the JSON `architectures` field.
pub struct AutoVisionLoader;

impl AutoVisionLoader {
    fn get_loader(config: &str) -> Result<Box<dyn VisionModelLoader>> {
        let auto_cfg: AutoVisionLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one architecture in config");
        }

        let name = &auto_cfg.architectures[0];
        let tp = VisionLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        // Delegate to the concrete loader
        Ok(match tp {
            VisionLoaderType::Phi3V => Box::new(Phi3VLoader),
            VisionLoaderType::Idefics2 => Box::new(Idefics2Loader),
            VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader),
            VisionLoaderType::LLaVA => Box::new(LLaVALoader),
            VisionLoaderType::VLlama => Box::new(VLlamaLoader),
            VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
            VisionLoaderType::Idefics3 => Box::new(Idefics3Loader),
            VisionLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
            VisionLoaderType::Phi4MM => Box::new(Phi4MMLoader),
            VisionLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
            VisionLoaderType::Gemma3 => Box::new(Gemma3Loader),
            VisionLoaderType::Mistral3 => Box::new(Mistral3Loader),
            VisionLoaderType::Llama4 => Box::new(VLlama4Loader),
        })
    }
}
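
// Hedged sketch of the config shape `AutoVisionLoader::get_loader` dispatches
// on: only the top-level `architectures` array is inspected, and it must hold
// exactly one entry. The JSON below is a minimal illustrative fragment, not a
// complete model config.
//
//     let config = r#"{ "architectures": ["LlavaForConditionalGeneration"] }"#;
//     // `from_causal_lm_name("LlavaForConditionalGeneration")` then selects
//     // `VisionLoaderType::LLaVA`, i.e. `LLaVALoader`.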

impl VisionModelLoader for AutoVisionLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }

    fn is_gptx(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader get_loader")
            .is_gptx(config)
    }

    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }

    fn get_processor(
        &self,
        model_config: &str,
        proc_cfg: Option<ProcessorConfig>,
        preproc_cfg: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Self::get_loader(model_config)
            .expect("AutoVisionLoader get_loader")
            .get_processor(model_config, proc_cfg, preproc_cfg, max_edge)
    }

    fn supports_paged_attention(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_paged_attention(config)
    }

    fn supports_prefix_cacher(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_prefix_cacher(config)
    }

    fn prefixer(&self, config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .prefixer(config)
    }

    fn get_device_for_tensor(
        &self,
        config: &str,
        mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        Self::get_loader(config)?.get_device_for_tensor(config, mapper, loading_isq)
    }
}

impl IsqModelLoader for AutoVisionLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
}

impl DeviceMappedModelLoader for AutoVisionLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params, prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(config, dtype, weight_pack_factor)
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(config, dtype, weight_pack_factor)
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}
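
// Usage sketch for `bias_if!`: it counts bias elements only when the layer
// actually carries a bias term, e.g. for a hypothetical linear layer:
//
//     let linear_elems = in_dim * out_dim + bias_if!(has_bias, out_dim);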

fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
    let pre_layer_norm = cfg.hidden_size;
    let final_layer_norm = cfg.hidden_size;

    let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
    let num_positions = num_patches + 1;

    let class_embedding = cfg.hidden_size;

    let position_ids = num_positions;
    let position_embedding = num_positions * cfg.hidden_size;

    let conv2dconfig = Conv2dConfig {
        stride: cfg.patch_size,
        ..Default::default()
    };
    let patch_embedding =
        cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;

    let encoder_layer_elems = {
        let layer_norm1 = cfg.hidden_size;
        let layer_norm2 = cfg.hidden_size;

        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

        layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
    };

    pre_layer_norm
        + final_layer_norm
        + class_embedding
        + position_ids
        + position_embedding
        + patch_embedding
        + cfg.num_hidden_layers * encoder_layer_elems
}
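
// Worked example with illustrative CLIP ViT-L/14-style numbers (image_size =
// 336, patch_size = 14; see the per-model `ClipConfig` constants such as
// PHI3V_CLIP_CONFIG for the authoritative values):
//
//     num_patches   = (336 / 14)^2 = 576
//     num_positions = 576 + 1      = 577
//
// so the position embedding alone contributes 577 * hidden_size elements.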

// ======================== Phi 3 loader

/// [`VisionLoader`] for a Phi 3 Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Phi3VLoader;

pub struct Phi3VPrefixer;

impl VisionPromptPrefixer for Phi3VPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        // Image indexing starts at 0.
        format!(
            "{}{prompt}",
            image_indexes
                .into_iter()
                .map(|image_index| format!("<|image_{}|>", image_index + 1))
                .join("")
        )
    }
}
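
// Example of the prefixing behavior above: indexes are 0-based on input and
// rendered 1-based in the tag, so
//
//     // prefix_image(vec![0, 1], "What changed?")
//     //   -> "<|image_1|><|image_2|>What changed?"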

impl VisionModelLoader for Phi3VLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(Phi3::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi3Processor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Phi3VPrefixer)
    }
}

impl IsqModelLoader for Phi3VLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}
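
// Hedged illustration of tensor names these ISQ patterns are meant to match
// (the paths are typical Phi-3-style names, shown for orientation only):
//
//     // "model.layers.3.self_attn.qkv_proj.weight" matches the qkv_proj rule
//     // "model.layers.3.mlp.down_proj.weight"      matches the down_proj rule
//     // vision-tower tensors match no rule and are left unquantized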

impl DeviceMappedModelLoader for Phi3VLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        // NOTE: we ignore max_num_images although it can only be one...
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }
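
    // Worked example of the bound above, under the illustrative assumption
    // that the CLIP tower yields num_patches = 576 (a 24 x 24 grid): then
    // img_seq_len = 577 * max_num_images and the estimate is
    //     max_batch_size * num_attention_heads * (577 * N + max_seq_len)^2
    // i.e. one full attention matrix at the image-extended sequence length.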

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        // NOTE: we ignore max_num_images although it can only be one...
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = {
                let projection_cls = cfg
                    .embd_layer
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator =
                    cfg.embd_layer.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
                let image_dim_out = cfg.img_processor.image_dim_out;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);

                proj + glb_gn + sub_gn + clip_vit
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }
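
    // Worked example with illustrative Phi-3-mini-style values (hidden_size =
    // 3072, 32 attention heads, 32 KV heads, head_dim = 96, intermediate_size
    // = 8192, weight_pack_factor = 1, F16 = 2 bytes per element):
    //     qkv_proj     = 3072 * (32*96 + 2*32*96) = 28_311_552
    //     o_proj       = 3072 * 3072              =  9_437_184
    //     gate_up_proj = 3072 * 16384             = 50_331_648
    //     down_proj    = 3072 * 8192              = 25_165_824
    //     layernorms   = 2 * 3072                 =      6_144
    // Total: 113_252_352 elems * 2 bytes, roughly 216 MiB per decoder layer.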

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Idefics 2 loader

/// [`VisionLoader`] for an Idefics 2 Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Idefics2Loader;

pub struct Idefics2Prefixer;

impl VisionPromptPrefixer for Idefics2Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        // Chat template does it
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(Idefics2::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics2Processor::new(
            processor_config.unwrap(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Idefics2Prefixer)
    }
}

impl IsqModelLoader for Idefics2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}
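
// Note the asymmetry above: `isq_layer_regexes` matches layer-name suffixes,
// while `immediate_isq_predicates` anchors the full `model.text_model.`
// prefix so only language-model tensors are quantized immediately at load
// time. The tensor names below are illustrative.
//
//     // "model.text_model.layers.0.self_attn.q_proj.weight" -> immediate ISQ
//     // vision-encoder tensors do not match and load unquantized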

impl DeviceMappedModelLoader for Idefics2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            // do_image_splitting = true
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let text_elems = {
            let tie_word_embeddings = cfg.tie_word_embeddings;
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let tcfg = &cfg.text_config;
            let vcfg = &cfg.vision_config;
            let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let down_proj = tcfg.intermediate_size * tcfg.hidden_size;

            let perceiver_elems = {
                let tcfg = &cfg.text_config;
                let pcfg = &cfg.perceiver_config;

                let n_latents = pcfg.resampler_n_latents;
                let hidden_size = tcfg.hidden_size;
                let depth = pcfg.resampler_depth;

                let norm = tcfg.hidden_size;
                let latents = n_latents * hidden_size;

                let layer_elems = {
                    let input_latents_norm = hidden_size;
                    let input_context_norm = hidden_size;
                    let post_attn_norm = hidden_size;

                    let num_heads = pcfg.resampler_n_heads;
                    let head_dim = pcfg.resampler_head_dim;
                    let num_key_value_heads = pcfg.num_key_value_heads;

                    let q_proj = hidden_size * num_heads * head_dim;
                    let k_proj = hidden_size * num_key_value_heads * head_dim;
                    let v_proj = hidden_size * num_key_value_heads * head_dim;
                    let o_proj = num_heads * head_dim * hidden_size;

                    let gate_proj = hidden_size * hidden_size * 4;
                    let up_proj = hidden_size * hidden_size * 4;
                    let down_proj = hidden_size * 4 * hidden_size;

                    input_latents_norm
                        + input_context_norm
                        + post_attn_norm
                        + q_proj
                        + k_proj
                        + v_proj
                        + o_proj
                        + gate_proj
                        + up_proj
                        + down_proj
                };

                norm + latents + layer_elems * depth
            };

            gate_proj + up_proj + down_proj + perceiver_elems
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm + patch_embedding + position_embedding + layer_elems
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== LLaVANext Loader

/// [`VisionLoader`] for an LLaVANext Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct LLaVANextLoader;

pub struct LLaVANextPrefixer;

impl VisionPromptPrefixer for LLaVANextPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}
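
// Example: with two images queued, the prompt gains one `<image>` tag per
// image, e.g.
//
//     // prefix_image(vec![0, 1], "Compare these.") -> "<image><image>Compare these."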
1066
1067impl VisionModelLoader for LLaVANextLoader {
1068    fn load(
1069        &self,
1070        config: &str,
1071        vb: ShardedVarBuilder,
1072        normal_loading_metadata: NormalLoadingMetadata,
1073        attention_mechanism: AttentionImplementation,
1074    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1075        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1076        Ok(Box::new(LLaVANext::new(
1077            &cfg,
1078            vb,
1079            self.is_gptx(config),
1080            normal_loading_metadata,
1081            attention_mechanism,
1082        )?))
1083    }
1084    fn is_gptx(&self, _config: &str) -> bool {
1085        false
1086    }
1087    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1088        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1089        Ok(Box::new(cfg))
1090    }
1091    fn get_processor(
1092        &self,
1093        model_config: &str,
1094        _processor_config: Option<ProcessorConfig>,
1095        _preprocessor_config: PreProcessorConfig,
1096        _max_edge: Option<u32>,
1097    ) -> Arc<dyn Processor + Send + Sync> {
1098        Arc::new(LLaVANextProcessor::new(model_config))
1099    }
1100    fn supports_paged_attention(&self, _config: &str) -> bool {
1101        true
1102    }
1103    fn supports_prefix_cacher(&self, _config: &str) -> bool {
1104        true
1105    }
1106    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
1107        Arc::new(LLaVANextPrefixer)
1108    }
1109}
1110
1111impl IsqModelLoader for LLaVANextLoader {
1112    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1113        Ok(vec![
1114            Regex::new(r"lm_head\.(weight|bias)$")?,
1115            // Attention
1116            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1117            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1118            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1119            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1120            // MLP
1121            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1122            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1123            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1124        ])
1125    }
1126    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1127        Ok(vec![
1128            Regex::new(r"lm_head\.(weight|bias)$")?,
1129            // Attention
1130            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1131            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1132            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1133            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1134            // MLP
1135            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1136            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1137            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1138        ])
1139    }
1140}
1141
1142impl DeviceMappedModelLoader for LLaVANextLoader {
1143    fn mapped_max_act_size_elems(
1144        &self,
1145        config: &str,
1146        params: &AutoDeviceMapParams,
1147        _prompt_chunksize: usize,
1148    ) -> Result<usize> {
1149        let AutoDeviceMapParams::Vision {
1150            max_seq_len,
1151            max_batch_size,
1152            max_image_shape,
1153            max_num_images,
1154        } = params
1155        else {
1156            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1157        };
1158
1159        let config: LLaVAConfig = serde_json::from_str(config)?;
1160
1161        #[allow(clippy::cast_possible_truncation)]
1162        let img_seq_len =
1163            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
1164                &config,
1165                (max_image_shape.0 as u32, max_image_shape.1 as u32),
1166            );
1167        let img_seq_len = img_seq_len * max_num_images;
1168
1169        let max_text_attn = {
1170            let cfg = &config.text_config;
1171            // This model injects the vision information directly into the input embeddings
1172            let max_seq_len = img_seq_len + max_seq_len;
1173
1174            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1175        };
1176
1177        Ok(max_text_attn)
1178    }
1179
1180    fn non_mapped_max_act_size_elems(
1181        &self,
1182        config: &str,
1183        params: &AutoDeviceMapParams,
1184    ) -> Result<usize> {
1185        let AutoDeviceMapParams::Vision {
1186            max_seq_len: _,
1187            max_batch_size,
1188            max_image_shape,
1189            max_num_images,
1190        } = params
1191        else {
1192            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1193        };
1194
1195        let config: LLaVAConfig = serde_json::from_str(config)?;
1196
1197        #[allow(clippy::cast_possible_truncation)]
1198        let img_seq_len =
1199            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
1200                &config,
1201                (max_image_shape.0 as u32, max_image_shape.1 as u32),
1202            );
1203
1204        let max_vision_attn = {
1205            (max_batch_size * max_num_images)
1206                * config.vision_config.num_attention_heads
1207                * img_seq_len
1208                * img_seq_len
1209        };
1210
1211        Ok(max_vision_attn)
1212    }
1213
1214    fn non_mapped_size_in_bytes(
1215        &self,
1216        config: &str,
1217        dtype: DType,
1218        weight_pack_factor: usize,
1219    ) -> Result<usize> {
1220        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1221        let text_elems = {
1222            let cfg = &cfg.text_config;
1223            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1224            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1225            let norm = cfg.hidden_size;
1226            embed_tokens + lm_head + norm
1227        };
1228
1229        let image_newline = cfg.text_config.hidden_size;
1230        let mmproj = {
1231            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
1232                + cfg.text_config.hidden_size;
1233            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
1234                + cfg.text_config.hidden_size;
1235
1236            linear_1 + linear_2
1237        };
1238        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());
1239
1240        let elems = text_elems + image_newline + mmproj + vision_tower;
1241        Ok(elems * dtype.size_in_bytes())
1242    }
1243
1244    fn layer_sizes_in_bytes(
1245        &self,
1246        config: &str,
1247        dtype: DType,
1248        weight_pack_factor: usize,
1249    ) -> Result<Vec<usize>> {
1250        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1251        let per_layer_elems = {
1252            let cfg = &cfg.text_config;
1253            let input_layernorm = cfg.hidden_size;
1254            let post_attention_layernorm = cfg.hidden_size;
1255
1256            let size_in = cfg.hidden_size;
1257            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1258            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1259            let q_proj = size_in * size_q / weight_pack_factor;
1260            let k_proj = size_in * size_kv / weight_pack_factor;
1261            let v_proj = size_in * size_kv / weight_pack_factor;
1262            let o_proj = size_q * size_in / weight_pack_factor;
1263
1264            let h_size = cfg.hidden_size;
1265            let i_size = cfg.intermediate_size;
1266            let gate_proj = h_size * i_size / weight_pack_factor;
1267            let up_proj = h_size * i_size / weight_pack_factor;
1268            let down_proj = i_size * h_size / weight_pack_factor;
1269
1270            input_layernorm
1271                + post_attention_layernorm
1272                + q_proj
1273                + k_proj
1274                + v_proj
1275                + o_proj
1276                + gate_proj
1277                + up_proj
1278                + down_proj
1279        };
1280        Ok(vec![
1281            per_layer_elems * dtype.size_in_bytes();
1282            cfg.text_config.num_hidden_layers
1283        ])
1284    }
1285
1286    fn num_layers(&self, config: &str) -> Result<usize> {
1287        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1288        Ok(cfg.text_config.num_hidden_layers)
1289    }
1290
1291    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1292        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1293        let cfg = &cfg.text_config;
1294
1295        let cfg = ModelConfigMetadata {
1296            max_seq_len: cfg.max_position_embeddings,
1297            num_layers: cfg.num_hidden_layers,
1298            hidden_size: cfg.hidden_size,
1299            num_kv_heads: cfg.num_key_value_heads,
1300            num_attn_heads: cfg.num_attention_heads,
1301            sliding_window: cfg.sliding_window,
1302            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1303            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1304        };
1305
1306        Ok(Box::new(cfg))
1307    }
1308
1309    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1310        Some(vec![NonMappedSubModel::Vision])
1311    }
1312}
1313
1314// ======================== LLaVA Loader
1315
1316/// [`VisionLoader`] for an LLaVA Vision model.
1317///
1318/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
1319pub struct LLaVALoader;
1320
1321pub struct LLaVAPrefixer;
1322
1323impl VisionPromptPrefixer for LLaVAPrefixer {
1324    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1325        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
1326    }
1327}
1328
1329impl VisionModelLoader for LLaVALoader {
1330    fn load(
1331        &self,
1332        config: &str,
1333        vb: ShardedVarBuilder,
1334        normal_loading_metadata: NormalLoadingMetadata,
1335        attention_mechanism: AttentionImplementation,
1336    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1337        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1338        Ok(Box::new(LLaVA::new(
1339            &cfg,
1340            vb,
1341            self.is_gptx(config),
1342            normal_loading_metadata,
1343            attention_mechanism,
1344        )?))
1345    }
1346    fn is_gptx(&self, _config: &str) -> bool {
1347        false
1348    }
1349    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1350        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1351        Ok(Box::new(cfg))
1352    }
1353    fn get_processor(
1354        &self,
1355        model_config: &str,
1356        _processor_config: Option<ProcessorConfig>,
1357        _preprocessor_config: PreProcessorConfig,
1358        _max_edge: Option<u32>,
1359    ) -> Arc<dyn Processor + Send + Sync> {
1360        Arc::new(LLaVAProcessor::new(model_config))
1361    }
1362    fn supports_paged_attention(&self, _config: &str) -> bool {
1363        true
1364    }
1365    fn supports_prefix_cacher(&self, _config: &str) -> bool {
1366        true
1367    }
1368    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
1369        Arc::new(LLaVAPrefixer)
1370    }
1371}
1372
1373impl IsqModelLoader for LLaVALoader {
1374    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1375        Ok(vec![
1376            Regex::new(r"lm_head\.(weight|bias)$")?,
1377            // Attention
1378            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1379            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1380            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1381            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1382            // MLP
1383            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1384            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1385            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1386        ])
1387    }
1388    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1389        Ok(vec![
1390            Regex::new(r"lm_head\.(weight|bias)$")?,
1391            // Attention
1392            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1393            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1394            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1395            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1396            // MLP
1397            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1398            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1399            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1400        ])
1401    }
1402}
1403
1404impl DeviceMappedModelLoader for LLaVALoader {
1405    fn mapped_max_act_size_elems(
1406        &self,
1407        config: &str,
1408        params: &AutoDeviceMapParams,
1409        _prompt_chunksize: usize,
1410    ) -> Result<usize> {
1411        let AutoDeviceMapParams::Vision {
1412            max_seq_len,
1413            max_batch_size,
1414            max_image_shape: _,
1415            max_num_images,
1416        } = params
1417        else {
1418            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1419        };
1420
1421        let config: LLaVAConfig = serde_json::from_str(config)?;
1422
1423        let img_seq_len =
1424            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1425        let img_seq_len = img_seq_len * max_num_images;
1426
1427        let max_text_attn = {
1428            let cfg = &config.text_config;
1429            // This model injects the vision information directly into the input embeddings
1430            let max_seq_len = img_seq_len + max_seq_len;
1431
1432            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1433        };
1434
1435        Ok(max_text_attn)
1436    }
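
    // Illustrative magnitude (hypothetical numbers, not from any real config): with
    // max_batch_size = 1, 32 attention heads, img_seq_len = 576, and max_seq_len = 4096,
    // the bound above is 1 * 32 * (4096 + 576)^2 ≈ 6.98e8 activation elements.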
1437
1438    fn non_mapped_max_act_size_elems(
1439        &self,
1440        config: &str,
1441        params: &AutoDeviceMapParams,
1442    ) -> Result<usize> {
1443        let AutoDeviceMapParams::Vision {
1444            max_seq_len: _,
1445            max_batch_size,
1446            max_image_shape: _,
1447            max_num_images,
1448        } = params
1449        else {
1450            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1451        };
1452
1453        let config: LLaVAConfig = serde_json::from_str(config)?;
1454
1455        let img_seq_len =
1456            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1457
1458        let max_vision_attn = {
1459            (max_batch_size * max_num_images)
1460                * config.vision_config.num_attention_heads
1461                * img_seq_len
1462                * img_seq_len
1463        };
1464
1465        Ok(max_vision_attn)
1466    }
1467
1468    fn non_mapped_size_in_bytes(
1469        &self,
1470        config: &str,
1471        dtype: DType,
1472        weight_pack_factor: usize,
1473    ) -> Result<usize> {
1474        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1475        let text_elems = {
1476            let cfg = &cfg.text_config;
1477            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1478            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1479            let norm = cfg.hidden_size;
1480            embed_tokens + lm_head + norm
1481        };
1482
1483        let image_newline = cfg.text_config.hidden_size;
1484        let mmproj = {
1485            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
1486                + cfg.text_config.hidden_size;
1487            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
1488                + cfg.text_config.hidden_size;
1489
1490            linear_1 + linear_2
1491        };
1492        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());
1493
1494        let elems = text_elems + image_newline + mmproj + vision_tower;
1495        Ok(elems * dtype.size_in_bytes())
1496    }
1497
1498    fn layer_sizes_in_bytes(
1499        &self,
1500        config: &str,
1501        dtype: DType,
1502        weight_pack_factor: usize,
1503    ) -> Result<Vec<usize>> {
1504        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1505        let per_layer_elems = {
1506            let cfg = &cfg.text_config;
1507            let input_layernorm = cfg.hidden_size;
1508            let post_attention_layernorm = cfg.hidden_size;
1509
1510            let size_in = cfg.hidden_size;
1511            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1512            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1513            let q_proj = size_in * size_q / weight_pack_factor;
1514            let k_proj = size_in * size_kv / weight_pack_factor;
1515            let v_proj = size_in * size_kv / weight_pack_factor;
1516            let o_proj = size_q * size_in / weight_pack_factor;
1517
1518            let h_size = cfg.hidden_size;
1519            let i_size = cfg.intermediate_size;
1520            let gate_proj = h_size * i_size / weight_pack_factor;
1521            let up_proj = h_size * i_size / weight_pack_factor;
1522            let down_proj = i_size * h_size / weight_pack_factor;
1523
1524            input_layernorm
1525                + post_attention_layernorm
1526                + q_proj
1527                + k_proj
1528                + v_proj
1529                + o_proj
1530                + gate_proj
1531                + up_proj
1532                + down_proj
1533        };
1534        Ok(vec![
1535            per_layer_elems * dtype.size_in_bytes();
1536            cfg.text_config.num_hidden_layers
1537        ])
1538    }
1539
1540    fn num_layers(&self, config: &str) -> Result<usize> {
1541        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1542        Ok(cfg.text_config.num_hidden_layers)
1543    }
1544
1545    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1546        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1547        let cfg = &cfg.text_config;
1548
1549        let cfg = ModelConfigMetadata {
1550            max_seq_len: cfg.max_position_embeddings,
1551            num_layers: cfg.num_hidden_layers,
1552            hidden_size: cfg.hidden_size,
1553            num_kv_heads: cfg.num_key_value_heads,
1554            num_attn_heads: cfg.num_attention_heads,
1555            sliding_window: cfg.sliding_window,
1556            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1557            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1558        };
1559
1560        Ok(Box::new(cfg))
1561    }
1562
1563    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1564        Some(vec![NonMappedSubModel::Vision])
1565    }
1566}
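
// Returning `NonMappedSubModel::Vision` keeps the vision tower out of the layer device
// map; it is loaded whole on the non-mapped device instead of being split per-layer.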
1567
1568// ======================== MLlama Loader
1569
1570/// [`VisionLoader`] for a Llama Vision model.
1571///
1572/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
1573pub struct VLlamaLoader;
1574
1575pub struct VLlamaPrefixer;
1576
1577impl VisionPromptPrefixer for VLlamaPrefixer {
1578    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1579        format!("{}{prompt}", "<|image|>".repeat(image_indexes.len()))
1580    }
1581}
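
#[cfg(test)]
mod vllama_prefixer_tests {
    use super::*;

    // Hypothetical test (not in the original source): the prefixer above repeats the
    // `<|image|>` tag once per image index.
    #[test]
    fn repeats_image_tag_per_index() {
        let out = VLlamaPrefixer.prefix_image(vec![0, 1], "Describe the images.");
        assert_eq!(out, "<|image|><|image|>Describe the images.");
    }
}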
1582
1583impl VisionModelLoader for VLlamaLoader {
1584    fn load(
1585        &self,
1586        config: &str,
1587        vb: ShardedVarBuilder,
1588        normal_loading_metadata: NormalLoadingMetadata,
1589        attention_mechanism: AttentionImplementation,
1590    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1591        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1592        Ok(Box::new(MLlamaModel::new(
1593            &cfg,
1594            vb,
1595            self.is_gptx(config),
1596            normal_loading_metadata,
1597            attention_mechanism,
1598        )?))
1599    }
1600    fn is_gptx(&self, _config: &str) -> bool {
1601        true
1602    }
1603    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1604        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1605        Ok(Box::new(cfg))
1606    }
1607    fn get_processor(
1608        &self,
1609        _model_config: &str,
1610        _processor_config: Option<ProcessorConfig>,
1611        _preprocessor_config: PreProcessorConfig,
1612        _max_edge: Option<u32>,
1613    ) -> Arc<dyn Processor + Send + Sync> {
1614        Arc::new(MLlamaProcessor::new())
1615    }
1616    fn supports_paged_attention(&self, _config: &str) -> bool {
1617        false
1618    }
1619    fn supports_prefix_cacher(&self, _config: &str) -> bool {
1620        true
1621    }
1622    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
1623        Arc::new(VLlamaPrefixer)
1624    }
1625}
1626
1627impl IsqModelLoader for VLlamaLoader {
1628    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
1629        let config: MLlamaConfig = serde_json::from_str(config)?;
1630        let cross_attn_layers = &config.text_config.cross_attention_layers;
1631        let transformer_layers =
1632            (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
1633        let mut text_regexes = Vec::new();
1634        for layer in transformer_layers {
1635            text_regexes.extend(vec![
1636                // Attention text
1637                Regex::new(&format!(
1638                    r"language_model\.model\.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
1639                ))?,
1640                Regex::new(&format!(
1641                    r"language_model\.model\.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
1642                ))?,
1643                Regex::new(&format!(
1644                    r"language_model\.model\.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
1645                ))?,
1646                Regex::new(&format!(
1647                    r"language_model\.model\.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
1648                ))?,
1649                // MLP text
1650                Regex::new(&format!(
1651                    r"language_model\.model\.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
1652                ))?,
1653                Regex::new(&format!(
1654                    r"language_model\.model\.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
1655                ))?,
1656                Regex::new(&format!(
1657                    r"language_model\.model\.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
1658                ))?,
1659            ]);
1660        }
1661        let vision_regexes = vec![
1662            // Vision attention (transformer)
1663            Regex::new(
1664                r"vision_model\.transformer\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1665            )?,
1666            Regex::new(
1667                r"vision_model\.transformer\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1668            )?,
1669            Regex::new(
1670                r"vision_model\.transformer\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1671            )?,
1672            Regex::new(
1673                r"vision_model\.transformer\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1674            )?,
1675            // Vision attention (global transformer)
1676            Regex::new(
1677                r"vision_model\.global_transformer\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1678            )?,
1679            Regex::new(
1680                r"vision_model\.global_transformer\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1681            )?,
1682            Regex::new(
1683                r"vision_model\.global_transformer\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1684            )?,
1685            Regex::new(
1686                r"vision_model\.global_transformer\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1687            )?,
1688            // MLP vision
1689            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1690            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1691        ];
1692
1693        Ok([text_regexes, vision_regexes].concat())
1694    }
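
    // For the construction above: if `cross_attention_layers` were `[3, 8]`
    // (hypothetical), layer 0 would get regexes matching e.g.
    // `language_model.model.layers.0.self_attn.q_proj.weight`, while layers 3 and 8
    // would be skipped so the cross-attention weights are never quantized.
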
1695    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1696        self.isq_layer_regexes(config)
1697    }
1698}
1699
1700impl DeviceMappedModelLoader for VLlamaLoader {
1701    fn mapped_max_act_size_elems(
1702        &self,
1703        config: &str,
1704        params: &AutoDeviceMapParams,
1705        _prompt_chunksize: usize,
1706    ) -> Result<usize> {
1707        let AutoDeviceMapParams::Vision {
1708            max_seq_len,
1709            max_batch_size,
1710            max_image_shape: _,
1711            max_num_images,
1712        } = params
1713        else {
1714            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1715        };
1716
1717        let config: MLlamaConfig = serde_json::from_str(config)?;
1718
1719        let img_seq_len = {
1720            let cfg = &config.vision_config;
1721            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1722            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1723            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1724        };
1725        let img_seq_len = img_seq_len * max_num_images;
1726
1727        let max_cross_text_attn = {
1728            let cfg = &config.text_config;
1729            max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
1730        };
1731
1732        let max_self_text_attn = {
1733            let cfg = &config.text_config;
1734            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1735        };
1736
1737        Ok(max_self_text_attn.max(max_cross_text_attn))
1738    }
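
    // Worked example for the bound above, assuming (hypothetically) image_size = 448,
    // patch_size = 14, and max_num_tiles = 4: num_patches = (448 / 14)^2 + 1 = 1025,
    // padded up to 1032, giving img_seq_len = 4 * 1032 = 4128 tokens per image.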
1739
1740    fn non_mapped_max_act_size_elems(
1741        &self,
1742        config: &str,
1743        params: &AutoDeviceMapParams,
1744    ) -> Result<usize> {
1745        let AutoDeviceMapParams::Vision {
1746            max_seq_len: _,
1747            max_batch_size,
1748            max_image_shape: _,
1749            max_num_images,
1750        } = params
1751        else {
1752            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1753        };
1754
1755        let config: MLlamaConfig = serde_json::from_str(config)?;
1756
1757        let img_seq_len = {
1758            let cfg = &config.vision_config;
1759            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1760            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1761            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1762        };
1763        let max_vision_attn = {
1764            let cfg = &config.vision_config;
1765            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
1766        };
1767
1768        Ok(max_vision_attn)
1769    }
1770
1771    fn non_mapped_size_in_bytes(
1772        &self,
1773        config: &str,
1774        dtype: DType,
1775        weight_pack_factor: usize,
1776    ) -> Result<usize> {
1777        let config: MLlamaConfig = serde_json::from_str(config)?;
1778        let text_elems = {
1779            let cfg = &config.text_config;
1780            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1781            let lm_head = if !cfg.tie_word_embeddings {
1782                cfg.hidden_size * cfg.vocab_size
1783            } else {
1784                0
1785            };
1786            let norm = cfg.hidden_size;
1787            embed_tokens + lm_head + norm
1788        };
1789
1790        let vision_elems = {
1791            let cfg = &config.vision_config;
1792
1793            let conv_cfg = Conv2dConfig {
1794                stride: cfg.patch_size,
1795                ..Default::default()
1796            };
1797            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
1798                * cfg.patch_size
1799                * cfg.patch_size;
1800
1801            let class_embedding = cfg.hidden_size;
1802
1803            let gated_positional_embedding = {
1804                let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1805                let embedding = num_patches * cfg.hidden_size;
1806                let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
1807                    * (cfg.max_num_tiles * num_patches * cfg.hidden_size);
1808
1809                embedding + tile_embedding
1810            };
1811
1812            let pre_tile_positional_embedding =
1813                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1814            let post_tile_positional_embedding =
1815                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1816
1817            let layernorm_pre = cfg.hidden_size;
1818            let layernorm_post = cfg.hidden_size;
1819
1820            let encoder_layer = {
1821                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1822                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
1823
1824                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1825                let q_proj =
1826                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1827                let k_proj =
1828                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1829                let v_proj =
1830                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1831                let o_proj =
1832                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1833
1834                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
1835                    + cfg.intermediate_size;
1836                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
1837                    + cfg.hidden_size;
1838
1839                input_layernorm
1840                    + post_attention_layernorm
1841                    + q_proj
1842                    + k_proj
1843                    + v_proj
1844                    + o_proj
1845                    + fc1
1846                    + fc2
1847            };
1848
1849            patch_embedding
1850                + class_embedding
1851                + gated_positional_embedding
1852                + pre_tile_positional_embedding
1853                + post_tile_positional_embedding
1854                + layernorm_pre
1855                + layernorm_post
1856                + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
1857        };
1858
1859        let elems = text_elems + vision_elems;
1860        Ok(elems * dtype.size_in_bytes())
1861    }
1862
1863    fn layer_sizes_in_bytes(
1864        &self,
1865        config: &str,
1866        dtype: DType,
1867        weight_pack_factor: usize,
1868    ) -> Result<Vec<usize>> {
1869        let config: MLlamaConfig = serde_json::from_str(config)?;
1870        let cfg = &config.text_config;
1871
1872        let mut layer_sizes = Vec::new();
1873
1874        for i in 0..cfg.num_hidden_layers {
1875            let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
1876                // No isq for cross attention
1877                1
1878            } else {
1879                weight_pack_factor
1880            };
1881
1882            let per_layer_elems = {
1883                let input_layernorm = cfg.hidden_size;
1884                let post_attention_layernorm = cfg.hidden_size;
1885
1886                let size_in = cfg.hidden_size;
1887                let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1888                let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1889                let q_proj = size_in * size_q / weight_pack_factor;
1890                let k_proj = size_in * size_kv / weight_pack_factor;
1891                let v_proj = size_in * size_kv / weight_pack_factor;
1892                let o_proj = size_q * size_in / weight_pack_factor;
1893
1894                let h_size = cfg.hidden_size;
1895                let i_size = cfg.intermediate_size;
1896                let gate_proj = h_size * i_size / weight_pack_factor;
1897                let up_proj = h_size * i_size / weight_pack_factor;
1898                let down_proj = i_size * h_size / weight_pack_factor;
1899
1900                input_layernorm
1901                    + post_attention_layernorm
1902                    + q_proj
1903                    + k_proj
1904                    + v_proj
1905                    + o_proj
1906                    + gate_proj
1907                    + up_proj
1908                    + down_proj
1909            };
1910
1911            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
1912        }
1913
1914        Ok(layer_sizes)
1915    }
1916
1917    fn num_layers(&self, config: &str) -> Result<usize> {
1918        let config: MLlamaConfig = serde_json::from_str(config)?;
1919        Ok(config.text_config.num_hidden_layers)
1920    }
1921
1922    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1923        let cfg: MLlamaConfig = serde_json::from_str(config)?;
1924        let cfg = &cfg.text_config;
1925
1926        let cfg = ModelConfigMetadata {
1927            max_seq_len: cfg.max_position_embeddings,
1928            num_layers: cfg.num_hidden_layers,
1929            hidden_size: cfg.hidden_size,
1930            num_kv_heads: cfg.num_key_value_heads,
1931            num_attn_heads: cfg.num_attention_heads,
1932            sliding_window: None,
1933            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1934            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1935        };
1936
1937        Ok(Box::new(cfg))
1938    }
1939
1940    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1941        Some(vec![NonMappedSubModel::Vision])
1942    }
1943}
1944
1945// ======================== Qwen2VL Loader
1946
1947/// [`VisionLoader`] for a Qwen2-VL model.
1948///
1949/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
1950pub struct Qwen2VLLoader;
1951
1952pub struct Qwen2VLPrefixer;
1953
1954impl VisionPromptPrefixer for Qwen2VLPrefixer {
1955    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1956        format!(
1957            "{}{prompt}",
1958            format!(
1959                "{}{}{}",
1960                Qwen2VLProcessor::VISION_START,
1961                Qwen2VLProcessor::IMAGE_PAD,
1962                Qwen2VLProcessor::VISION_END
1963            )
1964            .repeat(image_indexes.len())
1965        )
1966    }
1967}
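
// Sketch of the output above, assuming the usual Qwen2-VL marker tokens
// (`<|vision_start|>`, `<|image_pad|>`, `<|vision_end|>`) for the three constants:
// prefix_image(vec![0], "What is this?")
//     == "<|vision_start|><|image_pad|><|vision_end|>What is this?"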
1968
1969impl VisionModelLoader for Qwen2VLLoader {
1970    fn load(
1971        &self,
1972        config: &str,
1973        vb: ShardedVarBuilder,
1974        normal_loading_metadata: NormalLoadingMetadata,
1975        attention_mechanism: AttentionImplementation,
1976    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1977        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
1978        Ok(Box::new(Qwen2VLModel::new(
1979            &cfg,
1980            vb,
1981            self.is_gptx(config),
1982            normal_loading_metadata,
1983            attention_mechanism,
1984        )?))
1985    }
1986    fn is_gptx(&self, _config: &str) -> bool {
1987        true
1988    }
1989    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1990        let config: Qwen2VLConfig = serde_json::from_str(config)?;
1991        Ok(Box::new(config))
1992    }
1993    fn get_processor(
1994        &self,
1995        _model_config: &str,
1996        _processor_config: Option<ProcessorConfig>,
1997        _preprocessor_config: PreProcessorConfig,
1998        max_edge: Option<u32>,
1999    ) -> Arc<dyn Processor + Send + Sync> {
2000        Arc::new(Qwen2VLProcessor::new(max_edge))
2001    }
2002    fn supports_paged_attention(&self, _config: &str) -> bool {
2003        false
2004    }
2005    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
2006        Arc::new(Qwen2VLPrefixer)
2007    }
2008}
2009
2010impl IsqModelLoader for Qwen2VLLoader {
2011    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2012        Ok(vec![
2013            Regex::new(r"lm_head\.(weight|bias)$")?,
2014            // Attention
2015            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2016            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2017            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2018            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2019            // MLP
2020            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2021            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2022            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2023        ])
2024    }
2025    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2026        self.isq_layer_regexes(config)
2027    }
2028}
2029
2030impl DeviceMappedModelLoader for Qwen2VLLoader {
2031    fn mapped_max_act_size_elems(
2032        &self,
2033        config: &str,
2034        params: &AutoDeviceMapParams,
2035        _prompt_chunksize: usize,
2036    ) -> Result<usize> {
2037        let AutoDeviceMapParams::Vision {
2038            max_seq_len,
2039            max_batch_size,
2040            max_image_shape,
2041            max_num_images,
2042        } = params
2043        else {
2044            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2045        };
2046
2047        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2048
2049        let img_seq_len = {
2050            let cfg = &cfg.vision_config;
2051            let grid_t = max_num_images / cfg.temporal_patch_size;
2052            let grid_h = max_image_shape.0 / cfg.patch_size;
2053            let grid_w = max_image_shape.1 / cfg.patch_size;
2054            grid_t * grid_h * grid_w
2055        };
2056        let img_seq_len = img_seq_len * max_num_images;
2057
2058        let max_text_attn = {
2059            // This model injects the vision information directly into the input embeddings
2060            let max_seq_len = img_seq_len + max_seq_len;
2061            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
2062        };
2063
2064        Ok(max_text_attn)
2065    }
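
    // Illustrative grid math (hypothetical values): with patch_size = 14 and a
    // 1008x1008 max image shape, grid_h = grid_w = 1008 / 14 = 72, so each temporal
    // slice contributes 72 * 72 = 5184 visual positions.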
2066
2067    fn non_mapped_max_act_size_elems(
2068        &self,
2069        config: &str,
2070        params: &AutoDeviceMapParams,
2071    ) -> Result<usize> {
2072        let AutoDeviceMapParams::Vision {
2073            max_seq_len: _,
2074            max_batch_size,
2075            max_image_shape,
2076            max_num_images,
2077        } = params
2078        else {
2079            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2080        };
2081
2082        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2083
2084        let img_seq_len = {
2085            let cfg = &cfg.vision_config;
2086            let grid_t = max_num_images / cfg.temporal_patch_size;
2087            let grid_h = max_image_shape.0 / cfg.patch_size;
2088            let grid_w = max_image_shape.1 / cfg.patch_size;
2089            grid_t * grid_h * grid_w
2090        };
2091
2092        let max_vision_attn = {
2093            let cfg = &cfg.vision_config;
2094            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
2095        };
2096
2097        Ok(max_vision_attn)
2098    }
2099
2100    fn non_mapped_size_in_bytes(
2101        &self,
2102        config: &str,
2103        dtype: DType,
2104        weight_pack_factor: usize,
2105    ) -> Result<usize> {
2106        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2107        let text_elems = {
2108            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2109            let lm_head = if !cfg.tie_word_embeddings {
2110                cfg.hidden_size * cfg.vocab_size
2111            } else {
2112                0
2113            };
2114            let norm = cfg.hidden_size;
2115            embed_tokens + lm_head + norm
2116        };
2117
2118        let patch_merger = {
2119            let cfg = &cfg.vision_config;
2120            let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);
2121
2122            let mlp0 = hidden_size * hidden_size + hidden_size;
2123            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
2124
2125            let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2126
2127            mlp0 + mlp2 + ln_q
2128        };
2129
2130        let patch_embed = {
2131            let cfg = &cfg.vision_config;
2132            let conv_cfg = Conv3dConfig {
2133                stride: cfg.patch_size,
2134                ..Default::default()
2135            };
2136            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
2137            cfg.in_channels * cfg.embed_dim / conv_cfg.groups
2138                * kernel_sizes[0]
2139                * kernel_sizes[1]
2140                * kernel_sizes[2]
2141        };
2142
2143        let encoder_layer = {
2144            let cfg = &cfg.vision_config;
2145            let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2146            let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2147
2148            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2149            let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
2150            let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
2151            let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;
2152
2153            let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
2154            let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;
2155
2156            norm1 + norm2 + fc1 + fc2 + qkv + out
2157        };
2158
2159        let elems =
2160            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
2161
2162        Ok(elems * dtype.size_in_bytes())
2163    }
2164
2165    fn layer_sizes_in_bytes(
2166        &self,
2167        config: &str,
2168        dtype: DType,
2169        weight_pack_factor: usize,
2170    ) -> Result<Vec<usize>> {
2171        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2172        let per_layer_elems = {
2173            let input_layernorm = cfg.hidden_size;
2174            let post_attention_layernorm = cfg.hidden_size;
2175
2176            let size_in = cfg.hidden_size;
2177            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2178            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2179            let q_proj = size_in * size_q / weight_pack_factor + size_q;
2180            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
2181            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
2182            let o_proj = size_q * size_in / weight_pack_factor;
2183
2184            let h_size = cfg.hidden_size;
2185            let i_size = cfg.intermediate_size;
2186            let gate_proj = h_size * i_size / weight_pack_factor;
2187            let up_proj = h_size * i_size / weight_pack_factor;
2188            let down_proj = i_size * h_size / weight_pack_factor;
2189
2190            input_layernorm
2191                + post_attention_layernorm
2192                + q_proj
2193                + k_proj
2194                + v_proj
2195                + o_proj
2196                + gate_proj
2197                + up_proj
2198                + down_proj
2199        };
2200        Ok(vec![
2201            per_layer_elems * dtype.size_in_bytes();
2202            cfg.num_hidden_layers
2203        ])
2204    }
2205
2206    fn num_layers(&self, config: &str) -> Result<usize> {
2207        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2208        Ok(cfg.num_hidden_layers)
2209    }
2210
2211    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2212        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2213
2214        let cfg = ModelConfigMetadata {
2215            max_seq_len: cfg.max_position_embeddings,
2216            num_layers: cfg.num_hidden_layers,
2217            hidden_size: cfg.hidden_size,
2218            num_kv_heads: cfg.num_key_value_heads,
2219            num_attn_heads: cfg.num_attention_heads,
2220            sliding_window: cfg.sliding_window,
2221            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2222            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2223        };
2224
2225        Ok(Box::new(cfg))
2226    }
2227
2228    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2229        Some(vec![NonMappedSubModel::Vision])
2230    }
2231}
2232
2233// ======================== Idefics 3 Loader
2234
2235/// [`VisionLoader`] for an Idefics 3 Vision model.
2236///
2237/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2238pub struct Idefics3Loader;
2239
2240pub struct Idefics3Prefixer;
2241
2242impl VisionPromptPrefixer for Idefics3Prefixer {
2243    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
2244        // The chat template inserts the image tokens, so no prefix is needed here
2245        prompt.to_string()
2246    }
2247}
2248
2249impl VisionModelLoader for Idefics3Loader {
2250    fn load(
2251        &self,
2252        config: &str,
2253        vb: ShardedVarBuilder,
2254        normal_loading_metadata: NormalLoadingMetadata,
2255        attention_mechanism: AttentionImplementation,
2256    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2257        let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2258        Ok(Box::new(Idefics3Model::new(
2259            &cfg,
2260            vb,
2261            self.is_gptx(config),
2262            normal_loading_metadata,
2263            attention_mechanism,
2264        )?))
2265    }
2266    fn is_gptx(&self, _config: &str) -> bool {
2267        true
2268    }
2269    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2270        let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2271        Ok(Box::new(cfg))
2272    }
2273    fn get_processor(
2274        &self,
2275        _model_config: &str,
2276        processor_config: Option<ProcessorConfig>,
2277        preprocessor_config: PreProcessorConfig,
2278        max_edge: Option<u32>,
2279    ) -> Arc<dyn Processor + Send + Sync> {
2280        Arc::new(Idefics3Processor::new(
2281            processor_config.unwrap_or_default(),
2282            preprocessor_config,
2283            max_edge,
2284        ))
2285    }
2286    fn supports_paged_attention(&self, _config: &str) -> bool {
2287        true
2288    }
2289    fn supports_prefix_cacher(&self, _config: &str) -> bool {
2290        true
2291    }
2292    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
2293        Arc::new(Idefics3Prefixer)
2294    }
2295}
2296
2297impl IsqModelLoader for Idefics3Loader {
2298    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2299        Ok(vec![
2300            Regex::new(r"lm_head\.(weight|bias)$")?,
2301            // Attention
2302            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2303            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2304            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2305            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2306            // MLP
2307            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2308            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2309            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2310        ])
2311    }
2312    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
2313        Ok(vec![
2314            Regex::new(r"lm_head\.(weight|bias)$")?,
2315            // Attention
2316            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2317            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2318            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2319            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2320            // MLP
2321            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2322            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2323            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2324            // // Attention (vision)
2325            // Regex::new(
2326            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2327            // )?,
2328            // Regex::new(
2329            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
2330            // )?,
2331            // Regex::new(
2332            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
2333            // )?,
2334            // Regex::new(
2335            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.(weight|bias)$",
2336            // )?,
2337            // // MLP (vision)
2338            // Regex::new(r"model\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2339            // Regex::new(r"model\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
2340        ])
2341    }
2342}
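
// The vision regexes above are deliberately commented out: immediate ISQ is applied to
// the text model only, leaving the vision encoder in its original precision.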
2343
2344impl DeviceMappedModelLoader for Idefics3Loader {
2345    fn mapped_max_act_size_elems(
2346        &self,
2347        config: &str,
2348        params: &AutoDeviceMapParams,
2349        _prompt_chunksize: usize,
2350    ) -> Result<usize> {
2351        let AutoDeviceMapParams::Vision {
2352            max_seq_len,
2353            max_batch_size,
2354            max_image_shape: _,
2355            max_num_images,
2356        } = params
2357        else {
2358            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2359        };
2360
2361        let cfg: Idefics3Config = serde_json::from_str(config)?;
2362
2363        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2364        let img_seq_len = (num_patches + 1) * max_num_images;
2365
2366        let max_text_attn = {
2367            // This model injects the vision information directly into the input embeddings
2368            let max_seq_len = img_seq_len + max_seq_len;
2369            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2370        };
2371
2372        Ok(max_text_attn)
2373    }
2374
2375    fn non_mapped_max_act_size_elems(
2376        &self,
2377        config: &str,
2378        params: &AutoDeviceMapParams,
2379    ) -> Result<usize> {
2380        let AutoDeviceMapParams::Vision {
2381            max_seq_len: _,
2382            max_batch_size,
2383            max_image_shape: _,
2384            max_num_images,
2385        } = params
2386        else {
2387            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2388        };
2389
2390        let cfg: Idefics3Config = serde_json::from_str(config)?;
2391
2392        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2393        let img_seq_len = num_patches + 1;
2394
2395        let max_vision_attn = {
2396            // do_image_splitting = true: sub-images plus a global view per image (hence the factor of 5)
2397            let images_factor = 5;
2398
2399            (max_batch_size * images_factor * max_num_images)
2400                * cfg.vision_config.num_attention_heads
2401                * img_seq_len
2402                * img_seq_len
2403        };
2404
2405        Ok(max_vision_attn)
2406    }
2407
2408    fn non_mapped_size_in_bytes(
2409        &self,
2410        config: &str,
2411        dtype: DType,
2412        weight_pack_factor: usize,
2413    ) -> Result<usize> {
2414        let cfg: Idefics3Config = serde_json::from_str(config)?;
2415        let text_elems = {
2416            let cfg = &cfg.text_config;
2417
2418            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2419            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2420            let norm = cfg.hidden_size;
2421            embed_tokens + lm_head + norm
2422        };
2423
2424        let connector_elems = {
2425            let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
2426            let out_dim = cfg.text_config.hidden_size;
2427
2428            in_dim * out_dim
2429        };
2430
2431        let vision_transformer = {
2432            let cfg = &cfg.vision_config;
2433
2434            let post_layernorm = cfg.hidden_size;
2435
2436            let conv_config = Conv2dConfig {
2437                stride: cfg.patch_size,
2438                ..Default::default()
2439            };
2440            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2441                * cfg.patch_size
2442                * cfg.patch_size;
2443
2444            let num_patches_per_side = cfg.image_size / cfg.patch_size;
2445            let num_patches = num_patches_per_side.pow(2);
2446            let position_embedding = num_patches * cfg.hidden_size;
2447
2448            let layer_elems = {
2449                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2450                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2451
2452                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2453                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2454
2455                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2456                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2457                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2458                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2459
2460                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2461            };
2462
2463            post_layernorm
2464                + patch_embedding
2465                + position_embedding
2466                + layer_elems * cfg.num_hidden_layers
2467        };
2468
2469        let elems = text_elems + connector_elems + vision_transformer;
2470
2471        Ok(elems * dtype.size_in_bytes())
2472    }
2473
2474    fn layer_sizes_in_bytes(
2475        &self,
2476        config: &str,
2477        dtype: DType,
2478        weight_pack_factor: usize,
2479    ) -> Result<Vec<usize>> {
2480        let cfg: Idefics3Config = serde_json::from_str(config)?;
2481        let cfg = cfg.text_config;
2482        let per_layer_elems = {
2483            let input_layernorm = cfg.hidden_size;
2484            let post_attention_layernorm = cfg.hidden_size;
2485
2486            let size_in = cfg.hidden_size;
2487            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2488            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2489            let q_proj = size_in * size_q / weight_pack_factor;
2490            let k_proj = size_in * size_kv / weight_pack_factor;
2491            let v_proj = size_in * size_kv / weight_pack_factor;
2492            let o_proj = size_q * size_in / weight_pack_factor;
2493
2494            let h_size = cfg.hidden_size;
2495            let i_size = cfg.intermediate_size;
2496            let gate_proj = h_size * i_size / weight_pack_factor;
2497            let up_proj = h_size * i_size / weight_pack_factor;
2498            let down_proj = i_size * h_size / weight_pack_factor;
2499
2500            input_layernorm
2501                + post_attention_layernorm
2502                + q_proj
2503                + k_proj
2504                + v_proj
2505                + o_proj
2506                + gate_proj
2507                + up_proj
2508                + down_proj
2509        };
2510        Ok(vec![
2511            per_layer_elems * dtype.size_in_bytes();
2512            cfg.num_hidden_layers
2513        ])
2514    }
2515
2516    fn num_layers(&self, config: &str) -> Result<usize> {
2517        let cfg: Idefics3Config = serde_json::from_str(config)?;
2518        Ok(cfg.text_config.num_hidden_layers)
2519    }
2520    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2521        let cfg: Idefics3Config = serde_json::from_str(config)?;
2522        let cfg = &cfg.text_config;
2523
2524        let cfg = ModelConfigMetadata {
2525            max_seq_len: cfg.max_position_embeddings,
2526            num_layers: cfg.num_hidden_layers,
2527            hidden_size: cfg.hidden_size,
2528            num_kv_heads: cfg.num_key_value_heads,
2529            num_attn_heads: cfg.num_attention_heads,
2530            sliding_window: None,
2531            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2532            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2533        };
2534
2535        Ok(Box::new(cfg))
2536    }
2537
2538    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2539        Some(vec![NonMappedSubModel::Vision])
2540    }
2541}
2542
2543// ======================== MiniCpm-O Loader
2544
2545/// [`VisionLoader`] for a MiniCpm-O model.
2546///
2547/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2548pub struct MiniCpmOLoader;
2549
2550pub struct MiniCpmOPrefixer;
2551
2552impl VisionPromptPrefixer for MiniCpmOPrefixer {
2553    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2554        format!(
2555            "{}{prompt}",
2556            "(<image>./</image>)".repeat(image_indexes.len())
2557        )
2558    }
2559}
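
#[cfg(test)]
mod minicpmo_prefixer_tests {
    use super::*;

    // Hypothetical test (not in the original source): one `(<image>./</image>)` tag is
    // prepended per image index.
    #[test]
    fn prepends_one_tag_per_image() {
        let out = MiniCpmOPrefixer.prefix_image(vec![0], "Describe this.");
        assert_eq!(out, "(<image>./</image>)Describe this.");
    }
}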
2560
2561impl VisionModelLoader for MiniCpmOLoader {
2562    fn load(
2563        &self,
2564        config: &str,
2565        vb: ShardedVarBuilder,
2566        normal_loading_metadata: NormalLoadingMetadata,
2567        attention_mechanism: AttentionImplementation,
2568    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2569        let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2570        Ok(Box::new(MiniCpmOModel::new(
2571            &cfg,
2572            vb,
2573            self.is_gptx(config),
2574            normal_loading_metadata,
2575            attention_mechanism,
2576        )?))
2577    }
2578    fn is_gptx(&self, _config: &str) -> bool {
2579        true
2580    }
2581    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2582        let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2583        Ok(Box::new(cfg))
2584    }
2585    fn get_processor(
2586        &self,
2587        _model_config: &str,
2588        processor_config: Option<ProcessorConfig>,
2589        preprocessor_config: PreProcessorConfig,
2590        max_edge: Option<u32>,
2591    ) -> Arc<dyn Processor + Send + Sync> {
2592        Arc::new(MiniCpmOProcessor::new(
2593            processor_config.unwrap_or_default(),
2594            preprocessor_config,
2595            max_edge,
2596        ))
2597    }
2598    fn supports_paged_attention(&self, _config: &str) -> bool {
2599        true
2600    }
2601    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
2602        Arc::new(MiniCpmOPrefixer)
2603    }
2604}
2605
2606impl IsqModelLoader for MiniCpmOLoader {
2607    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2608        Ok(vec![
2609            Regex::new(r"llm.lm_head\.(weight|bias)$")?,
2610            // Attention
2611            Regex::new(r"llm.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2612            Regex::new(r"llm.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2613            Regex::new(r"llm.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2614            Regex::new(r"llm.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2615            // MLP
2616            Regex::new(r"llm.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2617            Regex::new(r"llm.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2618            Regex::new(r"llm.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2619        ])
2620    }
2621    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2622        self.isq_layer_regexes(config)
2623    }
2624}
2625
2626impl DeviceMappedModelLoader for MiniCpmOLoader {
2627    fn mapped_max_act_size_elems(
2628        &self,
2629        config: &str,
2630        params: &AutoDeviceMapParams,
2631        _prompt_chunksize: usize,
2632    ) -> Result<usize> {
2633        let AutoDeviceMapParams::Vision {
2634            max_seq_len,
2635            max_batch_size,
2636            max_image_shape: _,
2637            max_num_images,
2638        } = params
2639        else {
2640            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2641        };
2642
2643        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2644
2645        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2646        let img_seq_len = (num_patches + 1) * max_num_images;
2647
2648        let max_text_attn = {
2649            // This model injects the vision information directly into the input embeddings
2650            let max_seq_len = img_seq_len + max_seq_len;
2651            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2652        };
2653
2654        Ok(max_text_attn)
2655    }
2656
2657    fn non_mapped_max_act_size_elems(
2658        &self,
2659        config: &str,
2660        params: &AutoDeviceMapParams,
2661    ) -> Result<usize> {
2662        let AutoDeviceMapParams::Vision {
2663            max_seq_len: _,
2664            max_batch_size,
2665            max_image_shape: _,
2666            max_num_images,
2667        } = params
2668        else {
2669            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2670        };
2671
2672        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2673
2674        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2675        let img_seq_len = num_patches + 1;
2676
2677        let max_vision_attn = {
2678            // do_image_splitting = true: sub-images plus a global view per image (hence the factor of 5)
2679            let images_factor = 5;
2680
2681            (max_batch_size * images_factor * max_num_images)
2682                * cfg.vision_config.num_attention_heads
2683                * img_seq_len
2684                * img_seq_len
2685        };
2686
2687        Ok(max_vision_attn)
2688    }
2689
2690    fn non_mapped_size_in_bytes(
2691        &self,
2692        config: &str,
2693        dtype: DType,
2694        weight_pack_factor: usize,
2695    ) -> Result<usize> {
2696        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2697        let text_elems = {
2698            let cfg = &cfg.text_config;
2699
2700            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2701            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2702            let norm = cfg.hidden_size;
2703            embed_tokens + lm_head + norm
2704        };
2705
2706        let vision_transformer = {
2707            let cfg = &cfg.vision_config;
2708
2709            let post_layernorm = cfg.hidden_size;
2710
2711            let conv_config = Conv2dConfig {
2712                stride: cfg.patch_size,
2713                ..Default::default()
2714            };
2715            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2716                * cfg.patch_size
2717                * cfg.patch_size;
2718
2719            let num_patches_per_side = cfg.image_size / cfg.patch_size;
2720            let num_patches = num_patches_per_side.pow(2);
2721            let position_embedding = num_patches * cfg.hidden_size;
2722
2723            let layer_elems = {
2724                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2725                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2726
2727                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2728                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2729
2730                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2731                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2732                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2733                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2734
2735                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2736            };
2737
2738            post_layernorm
2739                + patch_embedding
2740                + position_embedding
2741                + layer_elems * cfg.num_hidden_layers
2742        };
2743
2744        let elems = text_elems + vision_transformer;
2745
2746        Ok(elems * dtype.size_in_bytes())
2747    }
2748
2749    fn layer_sizes_in_bytes(
2750        &self,
2751        config: &str,
2752        dtype: DType,
2753        weight_pack_factor: usize,
2754    ) -> Result<Vec<usize>> {
2755        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2756        let cfg = cfg.text_config;
2757        let per_layer_elems = {
2758            let input_layernorm = cfg.hidden_size;
2759            let post_attention_layernorm = cfg.hidden_size;
2760
2761            let size_in = cfg.hidden_size;
2762            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2763            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2764            let q_proj = size_in * size_q / weight_pack_factor;
2765            let k_proj = size_in * size_kv / weight_pack_factor;
2766            let v_proj = size_in * size_kv / weight_pack_factor;
2767            let o_proj = size_q * size_in / weight_pack_factor;
2768
2769            let h_size = cfg.hidden_size;
2770            let i_size = cfg.intermediate_size;
2771            let gate_proj = h_size * i_size / weight_pack_factor;
2772            let up_proj = h_size * i_size / weight_pack_factor;
2773            let down_proj = i_size * h_size / weight_pack_factor;
2774
2775            input_layernorm
2776                + post_attention_layernorm
2777                + q_proj
2778                + k_proj
2779                + v_proj
2780                + o_proj
2781                + gate_proj
2782                + up_proj
2783                + down_proj
2784        };
2785        Ok(vec![
2786            per_layer_elems * dtype.size_in_bytes();
2787            cfg.num_hidden_layers
2788        ])
2789    }
2790
2791    fn num_layers(&self, config: &str) -> Result<usize> {
2792        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2793        Ok(cfg.text_config.num_hidden_layers)
2794    }
2795    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2796        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2797        let cfg = &cfg.text_config;
2798
2799        let cfg = ModelConfigMetadata {
2800            max_seq_len: cfg.max_position_embeddings,
2801            num_layers: cfg.num_hidden_layers,
2802            hidden_size: cfg.hidden_size,
2803            num_kv_heads: cfg.num_key_value_heads,
2804            num_attn_heads: cfg.num_attention_heads,
2805            sliding_window: None,
2806            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2807            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2808        };
2809
2810        Ok(Box::new(cfg))
2811    }
2812}
2813
2814// ======================== Phi 4MM Loader
2815
2816/// [`VisionLoader`] for a Phi 4MM Vision model.
2817///
2818/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2819pub struct Phi4MMLoader;
2820
2821pub struct Phi4MMPrefixer;
2822
2823impl VisionPromptPrefixer for Phi4MMPrefixer {
2824    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2825        // Image indexes are 0-based, but the Phi 4MM image tags are 1-based: <|image_1|>, <|image_2|>, ...
2826
2827        format!(
2828            "{}{prompt}",
2829            image_indexes
2830                .into_iter()
2831                .map(|image_index| format!("<|image_{}|>", image_index + 1))
2832                .join("")
2833        )
2834    }
2835}
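
#[cfg(test)]
mod phi4mm_prefixer_tests {
    use super::*;

    // Hypothetical test (not in the original source): indexes are 0-based, tags 1-based.
    #[test]
    fn tags_are_one_based() {
        let out = Phi4MMPrefixer.prefix_image(vec![0, 1], "Compare them.");
        assert_eq!(out, "<|image_1|><|image_2|>Compare them.");
    }
}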
2836
2837impl VisionModelLoader for Phi4MMLoader {
2838    fn load(
2839        &self,
2840        config: &str,
2841        vb: ShardedVarBuilder,
2842        normal_loading_metadata: NormalLoadingMetadata,
2843        attention_mechanism: AttentionImplementation,
2844    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2845        let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
2846        Ok(Box::new(Phi4MMModel::new(
2847            &cfg,
2848            vb,
2849            self.is_gptx(config),
2850            normal_loading_metadata,
2851            attention_mechanism,
2852        )?))
2853    }
2854    fn is_gptx(&self, _config: &str) -> bool {
2855        true
2856    }
2857    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2858        let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
2859        Ok(Box::new(cfg))
2860    }
2861    fn get_processor(
2862        &self,
2863        _model_config: &str,
2864        processor_config: Option<ProcessorConfig>,
2865        preprocessor_config: PreProcessorConfig,
2866        _max_edge: Option<u32>,
2867    ) -> Arc<dyn Processor + Send + Sync> {
2868        Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
2869    }
2870    fn supports_paged_attention(&self, _config: &str) -> bool {
2871        true
2872    }
2873    fn supports_prefix_cacher(&self, _config: &str) -> bool {
2874        true
2875    }
2876    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
2877        Arc::new(Phi4MMPrefixer)
2878    }
2879}
2880
2881impl IsqModelLoader for Phi4MMLoader {
2882    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2883        Ok(vec![
2884            Regex::new(r"lm_head\.(weight|bias)$")?,
2885            // Attention
2886            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
2887            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2888            // MLP
2889            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
2890            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2891        ])
2892    }
2893    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2894        self.isq_layer_regexes(config)
2895    }
2896}
2897
2898impl DeviceMappedModelLoader for Phi4MMLoader {
2899    fn mapped_max_act_size_elems(
2900        &self,
2901        config: &str,
2902        params: &AutoDeviceMapParams,
2903        _prompt_chunksize: usize,
2904    ) -> Result<usize> {
        // NOTE: max_image_shape is ignored here; each image contributes a fixed (num_patches + 1) tokens.
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi4MMConfig = serde_json::from_str(config)?;

        let vcfg = &PHI4_MM_VISION_CFG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;
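        // One extra token per image on top of the patch grid (the + 1). Illustratively, an
        // image_size of 448 with patch_size 14 gives 32 * 32 = 1024 patches, so 1025 tokens
        // per image (actual values come from PHI4_MM_VISION_CFG).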

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let vcfg = &PHI4_MM_VISION_CFG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_batch_size = max_batch_size
            * (max_image_shape
                .0
                .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
                * max_image_shape
                    .1
                    .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
                + 1);
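        // Dynamic HD transform: the image is tiled into base-resolution crops, plus one global
        // crop (the + 1), so the effective batch grows with the number of tiles.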

        let max_vision_attn = (max_batch_size * max_num_images)
            * vcfg.num_attention_heads
            * img_seq_len
            * img_seq_len;
        let max_qkv = 3
            * (max_batch_size
                * vcfg.num_attention_heads
                * img_seq_len
                * (vcfg.hidden_size / vcfg.num_attention_heads));

        Ok(max_vision_attn + max_qkv)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
                let projection_cls = img_embed
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
                let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let vision_transformer = {
                    let cfg = &PHI4_MM_VISION_CFG;

                    let post_layernorm = cfg.hidden_size;

                    let conv_config = Conv2dConfig {
                        stride: cfg.patch_size,
                        ..Default::default()
                    };
                    let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                        * cfg.patch_size
                        * cfg.patch_size;

                    let num_patches_per_side = cfg.image_size / cfg.patch_size;
                    let num_patches = num_patches_per_side.pow(2);
                    let position_embedding = num_patches * cfg.hidden_size;

                    let layer_elems = {
                        let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                        let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                        layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
                    };

                    post_layernorm
                        + patch_embedding
                        + position_embedding
                        + layer_elems * cfg.num_hidden_layers
                };

                proj + glb_gn + sub_gn + vision_transformer
            } else {
                0
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
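            // Fused QKV projection: query width plus the key and value widths (grouped-query attention).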
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads(),
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Qwen2_5VL Loader

/// [`VisionLoader`] for a Qwen2_5VL model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Qwen2_5VLLoader;

pub struct Qwen2_5VLPrefixer;

impl VisionPromptPrefixer for Qwen2_5VLPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
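        // Each image contributes one vision span (start token, image pad, end token) before the
        // prompt; the pad is later expanded by the input processor to the real image token count.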
        format!(
            "{}{prompt}",
            format!(
                "{}{}{}",
                Qwen2_5VLProcessor::VISION_START,
                Qwen2_5VLProcessor::IMAGE_PAD,
                Qwen2_5VLProcessor::VISION_END
            )
            .repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for Qwen2_5VLLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen2_5VLModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen2_5VLProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        false
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Qwen2_5VLPrefixer)
    }
}

impl IsqModelLoader for Qwen2_5VLLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Qwen2_5VLLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };
        let img_seq_len = img_seq_len * max_num_images;
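        // Deliberately conservative upper bound: max_num_images is folded in again on top of grid_t.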

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        let text_elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let patch_merger = {
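            // The merger flattens a spatial_merge_size x spatial_merge_size window of patch
            // embeddings, so the first MLP operates at hidden_size * spatial_merge_size^2 width.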
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;

            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Gemma 3 Loader

/// [`VisionLoader`] for a Gemma 3 model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Gemma3Loader;

pub struct Gemma3Prefixer;

impl VisionPromptPrefixer for Gemma3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Gemma3Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;
        Ok(Box::new(Gemma3Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Gemma3Config = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        let config: Gemma3Config = serde_json::from_str(config).unwrap();
        // Handle the text-only Gemma 3 (e.g. 1B) case here: only `WithVision` configs get image support
        Arc::new(Gemma3Processor::new(
            processor_config.unwrap_or_default(),
            matches!(config, Gemma3Config::WithVision { .. }),
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Gemma3Prefixer)
    }
}

impl IsqModelLoader for Gemma3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Gemma3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3Config = serde_json::from_str(config)?;

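        // Text-only configs (e.g. Gemma 3 1B) have no vision tower, so the bound comes from the
        // prompt chunk size rather than an image-augmented sequence length.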
        match cfg {
            Gemma3Config::Text(text_config) => Ok(max_batch_size
                * text_config.num_attention_heads
                * prompt_chunksize
                * prompt_chunksize),
            Gemma3Config::WithVision {
                text_config,
                vision_config,
                ..
            } => {
                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
                let img_seq_len = (num_patches + 1) * max_num_images;

                let max_text_attn = {
                    // This model injects the vision information directly into the input embeddings
                    let max_seq_len = img_seq_len + *max_seq_len;
                    max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
                };
                Ok(max_text_attn)
            }
        }
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3Config = serde_json::from_str(config)?;

        match cfg {
            Gemma3Config::WithVision { vision_config, .. } => {
                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
                let img_seq_len = num_patches + 1;

                let max_vision_attn = {
                    (max_batch_size * max_num_images)
                        * vision_config.num_attention_heads
                        * img_seq_len
                        * img_seq_len
                };

                Ok(max_vision_attn)
            }
            Gemma3Config::Text(_) => Ok(0),
        }
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let text_elems = {
            let cfg = match &cfg {
                Gemma3Config::Text(cfg) => cfg,
                Gemma3Config::WithVision { text_config, .. } => text_config,
            };
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_transformer = if let Gemma3Config::WithVision {
            vision_config: cfg, ..
        } = &cfg
        {
            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        } else {
            0
        };

        let elems = text_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let txt_cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };
        let per_layer_elems = {
            let cfg = txt_cfg;

            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            txt_cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let txt_cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };

        Ok(txt_cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None, // None to be more forgiving; some configs do not set it
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Mistral 3 Loader

/// [`VisionLoader`] for a Mistral 3 model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Mistral3Loader;

pub struct Mistral3Prefixer;

impl VisionPromptPrefixer for Mistral3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Mistral3Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
        Ok(Box::new(Mistral3Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Mistral3Prefixer)
    }
}

impl IsqModelLoader for Mistral3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
impl DeviceMappedModelLoader for Mistral3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let vcfg = &cfg.vision_config;
        let tcfg = &cfg.text_config;

        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: (mut height, mut width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let img_seq_len = {
            // Reshaping algorithm

            // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/preprocessor_config.json#L29
            let (max_height, max_width) = (1540, 1540);
            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
            if ratio > 1. {
                height = (height as f64 / ratio).floor() as usize;
                width = (width as f64 / ratio).floor() as usize;
            }

            let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
            let num_width_tokens = (width - 1) / vcfg.patch_size + 1;

            height = num_height_tokens * vcfg.patch_size;
            width = num_width_tokens * vcfg.patch_size;

            let num_height_tokens = height / vcfg.patch_size;
            let num_width_tokens = width / vcfg.patch_size;

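            // Each row of patches is followed by an image-break token, hence width + 1 per row.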
            (num_width_tokens + 1) * num_height_tokens
        };

        // This model injects the vision information directly into the input embeddings
        let max_seq_len = img_seq_len * max_num_images + *max_seq_len;
        Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.vision_config;

        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: (mut height, mut width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let img_seq_len = {
            // Reshaping algorithm

            // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/preprocessor_config.json#L29
            let (max_height, max_width) = (1540, 1540);
            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
            if ratio > 1. {
                height = (height as f64 / ratio).floor() as usize;
                width = (width as f64 / ratio).floor() as usize;
            }

            let num_height_tokens = (height - 1) / cfg.patch_size + 1;
            let num_width_tokens = (width - 1) / cfg.patch_size + 1;

            height = num_height_tokens * cfg.patch_size;
            width = num_width_tokens * cfg.patch_size;

            let num_height_tokens = height / cfg.patch_size;
            let num_width_tokens = width / cfg.patch_size;

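            // As above: width + 1 accounts for the per-row image-break token.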
            (num_width_tokens + 1) * num_height_tokens
        };

        Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;

        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &cfg.vision_config;

            let patch_embed = {
                let conv_cfg = Conv2dConfig {
                    stride: cfg.patch_size,
                    ..Default::default()
                };
                // Conv2d patch embedding: the kernel is patch_size x patch_size.
                cfg.num_channels * cfg.hidden_size / conv_cfg.groups
                    * cfg.patch_size
                    * cfg.patch_size
            };
            let ln_pre = cfg.hidden_size;
            let vision_layer = {
                let attn_norm = cfg.hidden_size;
                let ffn_norm = cfg.hidden_size;

                let gate = cfg.hidden_size * cfg.intermediate_size;
                let up = cfg.hidden_size * cfg.intermediate_size;
                let down = cfg.hidden_size * cfg.intermediate_size;

                let q = cfg.hidden_size * cfg.hidden_size;
                let k = cfg.hidden_size * cfg.hidden_size;
                let v = cfg.hidden_size * cfg.hidden_size;
                let o = cfg.hidden_size * cfg.hidden_size;

                attn_norm + ffn_norm + gate + up + down + q + k + v + o
            };

            patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_elems;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Llama 4 Loader

/// [`VisionLoader`] for a Llama 4 Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct VLlama4Loader;

pub struct VLlama4Prefixer;

impl VisionPromptPrefixer for VLlama4Prefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            llama4::IMAGE_TOKEN.repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for VLlama4Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
        Ok(Box::new(Llama4Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Llama4Processor::new(&processor_config.unwrap()))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(VLlama4Prefixer)
    }
}

impl IsqModelLoader for VLlama4Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // FF MoE
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.router\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
            // FF MLP
            Regex::new(r"layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // FF MoE
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.router\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$",
            )?,
            // FF MLP
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$",
            )?,
        ])
    }
}

impl VLlama4Loader {
    /// Runs a dummy preprocessing pass to size the vision activations; this incorporates the
    /// max batch size.
    /// Returns `(pixels max batch size, num text image tokens)`.
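    ///
    /// A minimal sketch of the intended use (illustrative values, not part of the public API):
    /// ```ignore
    /// let (pixels_bs, text_img_toks) =
    ///     loader.run_dummy_processing(&cfg, 1024, 1024, /* max_num_images */ 1, /* max_batch_size */ 1)?;
    /// ```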
    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
    fn run_dummy_processing(
        &self,
        cfg: &Llama4Config,
        height: usize,
        width: usize,
        max_num_images: usize,
        max_batch_size: usize,
    ) -> Result<(usize, usize)> {
        let cfg = &cfg.vision_config;

        let img_processor =
            Llama4ImageProcessor::new(Some(cfg.patch_size), Some(cfg.pixel_shuffle_ratio));
        let image = DynamicImage::new(width as u32, height as u32, ColorType::Rgb8);
        let res = img_processor.preprocess(
            vec![image; max_num_images],
            vec![],
            &PreProcessorConfig::default(),
            &Device::Cpu,
            (max_batch_size, max_num_images),
        )?;

        let pixels_batch_size = res.pixel_values.dim(0)?;
        let pixels_max_batch_size = pixels_batch_size * max_batch_size;

        let (image_h, image_w) = (
            res.pixel_values.dim(D::Minus2).unwrap(),
            res.pixel_values.dim(D::Minus1).unwrap(),
        );
        let num_patches_per_chunk = (image_h / img_processor.patch_size)
            * (image_w / img_processor.patch_size)
            / img_processor.downsample_ratio;
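        // Pixel shuffle merges `downsample_ratio` patches into a single text-side image token.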
4234
4235        Ok((
4236            pixels_max_batch_size,
4237            num_patches_per_chunk * pixels_max_batch_size,
4238        ))
4239    }
4240}
4241
4242impl DeviceMappedModelLoader for VLlama4Loader {
4243    fn mapped_max_act_size_elems(
4244        &self,
4245        config: &str,
4246        params: &AutoDeviceMapParams,
4247        _prompt_chunksize: usize,
4248    ) -> Result<usize> {
4249        let AutoDeviceMapParams::Vision {
4250            max_seq_len,
4251            max_batch_size,
4252            max_image_shape: (height, width),
4253            max_num_images,
4254        } = params
4255        else {
4256            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4257        };
4258
4259        let cfg: Llama4Config = serde_json::from_str(config)?;
4260
4261        let (_pixels_batch_size, num_text_image_toks) =
4262            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4263
4264        let max_seq_len = max_seq_len + num_text_image_toks;
4265
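        // Peak mapped activation: the language model's attention score matrix,
        // batch * heads * seq_len^2, with seq_len inflated by the image tokens.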
        Ok(max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: (height, width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Llama4Config = serde_json::from_str(config)?;

        let (pixels_batch_size, _num_text_image_toks) =
            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
        let max_seq_len = cfg.vision_config.num_patches();

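        // The vision tower is not device-mapped; its peak activation is the attention
        // scores over all patch tokens, for every image chunk in the batch.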
        Ok((max_batch_size * pixels_batch_size)
            * cfg.vision_config.num_attention_heads
            * max_seq_len
            * max_seq_len)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let tcfg = &cfg.text_config;

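        // Weights that are not split across the device map: token embeddings, the
        // lm_head (when untied), the final norm, and the entire vision tower.
        // `weight_pack_factor` divides element counts for quantized (packed) weights.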
        let text_elems = {
            let embed_tokens = tcfg.hidden_size * tcfg.vocab_size / weight_pack_factor;
            let lm_head = if !tcfg.tie_word_embeddings {
                tcfg.hidden_size * tcfg.vocab_size
            } else {
                0
            };
            let norm = tcfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &cfg.vision_config;

            let num_patches = cfg.num_patches();

            let unfold_elems =
                (cfg.num_channels * cfg.patch_size * cfg.patch_size) * cfg.hidden_size;
            let class_embedding_elems = cfg.hidden_size;
            let positional_embedding_vlm_elems = num_patches * cfg.hidden_size;
            let layernorm_pre_elems = cfg.hidden_size;
            let layernorm_post_elems = cfg.hidden_size;

            let pixel_shuffle_elems = cfg.intermediate_size * cfg.projector_input_dim
                / weight_pack_factor
                + cfg.projector_input_dim * cfg.projector_output_dim / weight_pack_factor;

            let encoder_layer = {
                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
                let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let k_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let v_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let o_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;

                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
                    + cfg.intermediate_size;
                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
                    + cfg.hidden_size;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + fc1
                    + fc2
            };

            unfold_elems
                + class_embedding_elems
                + positional_embedding_vlm_elems
                + layernorm_post_elems
                + layernorm_pre_elems
                + pixel_shuffle_elems
                + encoder_layer * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_elems;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let tcfg = &cfg.text_config;

        let mut per_layer_elems = Vec::new();

        for layer_idx in 0..tcfg.num_hidden_layers {
            let input_layernorm = tcfg.hidden_size;
            let post_attention_layernorm = tcfg.hidden_size;

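            // Attention projections; K/V are sized for grouped-query attention via
            // `num_key_value_heads`.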
            let size_in = tcfg.hidden_size;
            let size_q = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_attention_heads;
            let size_kv = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

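            // Llama 4 interleaves MoE and dense MLP layers; `moe_layers()` reports
            // which layer indices use the expert block.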
            let use_moe = tcfg.moe_layers().contains(&layer_idx);
            let moe_block = if use_moe {
                let h_size = tcfg.hidden_size;
                let i_size = tcfg.intermediate_size;
                let gate_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
                let up_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
                let down_proj = tcfg.num_local_experts * i_size * h_size / weight_pack_factor;

                gate_proj + up_proj + down_proj
            } else {
                let h_size = tcfg.hidden_size;
                let i_size = tcfg.intermediate_size_mlp;
                let gate_proj = h_size * i_size / weight_pack_factor;
                let up_proj = h_size * i_size / weight_pack_factor;
                let down_proj = i_size * h_size / weight_pack_factor;

                gate_proj + up_proj + down_proj
            };

            per_layer_elems.push(
                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + moe_block,
            );
        }

        Ok(per_layer_elems
            .into_iter()
            .map(|x| x * dtype.size_in_bytes())
            .collect())
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

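    // KV-cache/paged-attention metadata is derived from the text config only; the
    // vision tower has no KV cache of its own.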
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

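    // Keep the vision tower whole on a single device instead of splitting it across
    // the device map.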
    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}