mistralrs_core/pipeline/loaders/vision_loaders.rs

use std::any::Any;
use std::sync::Arc;
use std::{fmt::Debug, str::FromStr};

use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use candle_nn::Conv2dConfig;
use mistralrs_quant::ShardedVarBuilder;

#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};

use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
use crate::amoe::AnyMoeBaseModelMixin;
use crate::device_map::DeviceMapper;
use crate::layers::Conv3dConfig;
use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
use crate::pipeline::isq::IsqModelLoader;
use crate::pipeline::loaders::AutoDeviceMapParams;
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::{EitherCache, IsqModel, Processor, ProcessorCreator, VisionPromptPrefixer};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::vision_models::clip::ClipConfig;
use crate::vision_models::gemma3::config::Gemma3Config;
use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
use crate::vision_models::idefics2_input_processor::Idefics2Processor;
use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
use crate::vision_models::inputs_processor::Phi4MMProcessor;
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::llava15::Model as LLaVA;
use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
use crate::vision_models::llava_next::Model as LLaVANext;
use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
use crate::vision_models::phi3_inputs_processor::Phi3Processor;
use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::qwen2_5_vl::{
    Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
};
use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
use crate::vision_models::{minicpmo, phi4};

pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin {
    // pixel_values and pixel_attention_mask only specified for prompt seqs
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>, // pixel attention mask, or image sizes, or anything else
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn has_conv2d(&self) -> bool;
    fn config(&self) -> &ModelConfigMetadata;
    /// For a prompt without images. Requires batch size of 1!
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
}

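// A minimal usage sketch (hypothetical, not part of this crate) of driving a
// `VisionModel` for a text-only step, where `default_model_specific_args`
// supplies the per-model extras (pixel attention mask, image sizes, ...):
//
//     let args = model.default_model_specific_args(&input_ids);
//     let logits = model.forward(
//         &input_ids, None, &seqlen_offsets, context_lens, position_ids,
//         args, None, &flash_params,
//     )?;
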
pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>>;
    fn is_gptx(&self) -> bool;
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>>;
    fn get_processor(
        &self,
        model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync>;
    fn supports_paged_attention(&self) -> bool;
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}

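// A small sanity check (not part of the upstream crate) of the layer-index regex
// used by the default `get_device_for_tensor`; the tensor names are hypothetical
// but follow the usual `*.layers.N.*` checkpoint convention.
#[cfg(test)]
mod device_for_tensor_regex_tests {
    use regex::Regex;

    #[test]
    fn extracts_layer_index_from_tensor_names() {
        let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
        let caps = re.captures("model.layers.12.self_attn.q_proj.weight").unwrap();
        assert_eq!(caps[1].parse::<usize>().unwrap(), 12);
        // Non-layer tensors fall back to `DeviceForLoadTensor::Base`.
        assert!(re.captures("model.embed_tokens.weight").is_none());
    }
}
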
#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
/// The architecture to load the vision model as.
pub enum VisionLoaderType {
    #[serde(rename = "phi3v")]
    Phi3V,
    #[serde(rename = "idefics2")]
    Idefics2,
    #[serde(rename = "llava_next")]
    LLaVANext,
    #[serde(rename = "llava")]
    LLaVA,
    #[serde(rename = "vllama")]
    VLlama,
    #[serde(rename = "qwen2vl")]
    Qwen2VL,
    #[serde(rename = "idefics3")]
    Idefics3,
    #[serde(rename = "minicpmo")]
    MiniCpmO,
    #[serde(rename = "phi4mm")]
    Phi4MM,
    #[serde(rename = "qwen2_5vl")]
    Qwen2_5VL,
    #[serde(rename = "gemma3")]
    Gemma3,
    #[serde(rename = "mistral3")]
    Mistral3,
}

impl FromStr for VisionLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "phi3v" => Ok(Self::Phi3V),
            "idefics2" => Ok(Self::Idefics2),
            "llava_next" => Ok(Self::LLaVANext),
            "llava" => Ok(Self::LLaVA),
            "vllama" => Ok(Self::VLlama),
            "qwen2vl" => Ok(Self::Qwen2VL),
            "idefics3" => Ok(Self::Idefics3),
            "minicpmo" => Ok(Self::MiniCpmO),
            "phi4mm" => Ok(Self::Phi4MM),
            "qwen2_5vl" => Ok(Self::Qwen2_5VL),
            "gemma3" => Ok(Self::Gemma3),
            "mistral3" => Ok(Self::Mistral3),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`.")),
        }
    }
}

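// A quick sketch (not part of the upstream crate) of parsing loader-type names,
// mirroring what a CLI architecture flag would feed into `FromStr`.
#[cfg(test)]
mod vision_loader_type_tests {
    use super::VisionLoaderType;
    use std::str::FromStr;

    #[test]
    fn parses_known_names_and_rejects_unknown_ones() {
        assert_eq!(
            VisionLoaderType::from_str("phi3v"),
            Ok(VisionLoaderType::Phi3V)
        );
        assert_eq!(
            VisionLoaderType::from_str("qwen2_5vl"),
            Ok(VisionLoaderType::Qwen2_5VL)
        );
        assert!(VisionLoaderType::from_str("not_a_model").is_err());
    }
}
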
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}

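// A tiny sanity check (not part of the upstream crate) of `bias_if!`: it counts
// bias elements only when the layer actually carries a bias term.
#[cfg(test)]
mod bias_if_tests {
    #[test]
    fn counts_bias_elems_only_when_present() {
        assert_eq!(bias_if!(true, 64), 64);
        assert_eq!(bias_if!(false, 64), 0);
    }
}
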
fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
    let pre_layer_norm = cfg.hidden_size;
    let final_layer_norm = cfg.hidden_size;

    let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
    let num_positions = num_patches + 1;

    let class_embedding = cfg.hidden_size;

    let position_ids = num_positions;
    let position_embedding = num_positions * cfg.hidden_size;

    let conv2dconfig = Conv2dConfig {
        stride: cfg.patch_size,
        ..Default::default()
    };
    let patch_embedding =
        cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;

    let encoder_layer_elems = {
        let layer_norm1 = cfg.hidden_size;
        let layer_norm2 = cfg.hidden_size;

        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

        layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
    };

    pre_layer_norm
        + final_layer_norm
        + class_embedding
        + position_ids
        + position_embedding
        + patch_embedding
        + cfg.num_hidden_layers * encoder_layer_elems
}

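// Worked example (a sketch, assuming a CLIP ViT-L/14 tower at 336px resolution:
// hidden_size = 1024, intermediate_size = 4096, patch_size = 14, image_size = 336,
// num_channels = 3, num_hidden_layers = 24):
//
//   num_patches     = (336 / 14)^2             = 576  (577 positions)
//   patch_embedding = 3 * 1024 * 14 * 14       = 602_112
//   encoder layer   = 2 * 1024                   (layer norms)
//                   + 4 * (1024 * 1024 + 1024)   (q/k/v/o projections)
//                   + (1024 * 4096 + 4096)       (fc1)
//                   + (4096 * 1024 + 1024)       (fc2)
//                   = 12_594_176
//   total           ≈ 24 * 12_594_176 + 602_112 + 577 * 1024 + ... ≈ 3.03e8,
//
// in line with the ~304M parameters usually quoted for CLIP ViT-L.
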
// ======================== Phi 3 loader

/// [`VisionLoader`] for a Phi 3 Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Phi3VLoader;

pub struct Phi3VPrefixer;

impl VisionPromptPrefixer for Phi3VPrefixer {
    fn prefix_image(&self, image_index: usize, prompt: &str) -> String {
        // The incoming image index is 0-based; Phi 3 Vision expects 1-based <|image_N|> tags.
        format!("<|image_{}|>{prompt}", image_index + 1)
    }
}

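// For example, `Phi3VPrefixer.prefix_image(0, "Describe the image.")` yields
// `"<|image_1|>Describe the image."`.
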
impl VisionModelLoader for Phi3VLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: Phi3Config = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(Phi3::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: Phi3Config = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi3Processor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Phi3VPrefixer)
    }
}

impl IsqModelLoader for Phi3VLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

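// A minimal check (not part of the upstream crate) that the ISQ patterns above
// match the fused-projection tensor names; the names here are hypothetical but
// follow the `layers.N.self_attn.qkv_proj.weight` layout Phi 3 uses.
#[cfg(test)]
mod phi3v_isq_regex_tests {
    use regex::Regex;

    #[test]
    fn qkv_proj_weights_are_isq_candidates() {
        let re = Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$").unwrap();
        assert!(re.is_match("model.layers.7.self_attn.qkv_proj.weight"));
        assert!(!re.is_match("model.layers.7.self_attn.qkv_proj.scales"));
    }
}
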
impl DeviceMappedModelLoader for Phi3VLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        // NOTE: this model accepts at most one image per prompt, but we keep the
        // max_num_images factor so the bound stays conservative.
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

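    // Rough worked example for the bound above (a sketch, assuming the 336px CLIP
    // tower, max_batch_size = 1, max_num_images = 1, a hypothetical
    // max_seq_len = 4096, and 32 attention heads):
    //
    //   img_seq_len = 576 + 1 = 577
    //   max_seq_len = 577 + 4096 = 4673
    //   elems       = 1 * 32 * 4673 * 4673 ≈ 6.99e8
    //
    // i.e. the peak attention-score matrix dominates the mapped activation memory.
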
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        // NOTE: this model accepts at most one image per prompt, but we keep the
        // max_num_images factor so the bound stays conservative.
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = {
                let projection_cls = cfg
                    .embd_layer
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator =
                    cfg.embd_layer.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
                let image_dim_out = cfg.img_processor.image_dim_out;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);

                proj + glb_gn + sub_gn + clip_vit
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            // Fused QKV: q is num_attention_heads * head_dim wide, k and v are
            // num_key_value_heads * head_dim each.
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Idefics 2 loader

/// [`VisionLoader`] for an Idefics 2 Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Idefics2Loader;

pub struct Idefics2Prefixer;

impl VisionPromptPrefixer for Idefics2Prefixer {
    fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
        // The chat template inserts the image tokens, so no prefix is needed here.
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics2Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: Idefics2Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(Idefics2::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: Idefics2Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics2Processor::new(
            processor_config.unwrap(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Idefics2Prefixer)
    }
}

impl IsqModelLoader for Idefics2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Idefics2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            // do_image_splitting = true: every image is split into 4 crops plus the
            // downscaled original, so each input image contributes 5 frames.
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let text_elems = {
            let tie_word_embeddings = cfg.tie_word_embeddings;
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let tcfg = &cfg.text_config;
            let vcfg = &cfg.vision_config;
            let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let down_proj = tcfg.intermediate_size * tcfg.hidden_size;

            let perceiver_elems = {
                let tcfg = &cfg.text_config;
                let pcfg = &cfg.perceiver_config;

                let n_latents = pcfg.resampler_n_latents;
                let hidden_size = tcfg.hidden_size;
                let depth = pcfg.resampler_depth;

                let norm = tcfg.hidden_size;
                let latents = n_latents * hidden_size;

                let layer_elems = {
                    let input_latents_norm = hidden_size;
                    let input_context_norm = hidden_size;
                    let post_attn_norm = hidden_size;

                    let num_heads = pcfg.resampler_n_heads;
                    let head_dim = pcfg.resampler_head_dim;
                    let num_key_value_heads = pcfg.num_key_value_heads;

                    let q_proj = hidden_size * num_heads * head_dim;
                    let k_proj = hidden_size * num_key_value_heads * head_dim;
                    let v_proj = hidden_size * num_key_value_heads * head_dim;
                    let o_proj = num_heads * head_dim * hidden_size;

                    let gate_proj = hidden_size * hidden_size * 4;
                    let up_proj = hidden_size * hidden_size * 4;
                    let down_proj = hidden_size * 4 * hidden_size;

                    input_latents_norm
                        + input_context_norm
                        + post_attn_norm
                        + q_proj
                        + k_proj
                        + v_proj
                        + o_proj
                        + gate_proj
                        + up_proj
                        + down_proj
                };

                norm + latents + layer_elems * depth
            };

            gate_proj + up_proj + down_proj + perceiver_elems
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            // Count one encoder layer's elements for every hidden layer.
            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== LLaVANext Loader

/// [`VisionLoader`] for an LLaVANext Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct LLaVANextLoader;

pub struct LLaVANextPrefixer;

impl VisionPromptPrefixer for LLaVANextPrefixer {
    fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
        format!("<image>{prompt}")
    }
}

impl VisionModelLoader for LLaVANextLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: LLaVAConfig = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(LLaVANext::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: LLaVAConfig = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVANextProcessor::new(model_config))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(LLaVANextPrefixer)
    }
}

impl IsqModelLoader for LLaVANextLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVANextLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== LLaVA Loader

/// [`VisionLoader`] for an LLaVA Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct LLaVALoader;

pub struct LLaVAPrefixer;

impl VisionPromptPrefixer for LLaVAPrefixer {
    fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
        format!("<image>{prompt}")
    }
}

impl VisionModelLoader for LLaVALoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: LLaVAConfig = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(LLaVA::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: LLaVAConfig = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVAProcessor::new(model_config))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(LLaVAPrefixer)
    }
}

impl IsqModelLoader for LLaVALoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVALoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        let img_seq_len =
            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        let img_seq_len =
            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== MLlama Loader

/// [`VisionLoader`] for a Llama Vision model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct VLlamaLoader;

pub struct VLlamaPrefixer;

impl VisionPromptPrefixer for VLlamaPrefixer {
    fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
        format!("<|image|>{prompt}")
    }
}

impl VisionModelLoader for VLlamaLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: MLlamaConfig = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(MLlamaModel::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: MLlamaConfig = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(MLlamaProcessor::new())
    }
    fn supports_paged_attention(&self) -> bool {
        false
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(VLlamaPrefixer)
    }
}

impl IsqModelLoader for VLlamaLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let cross_attn_layers = &config.text_config.cross_attention_layers;
        let transformer_layers =
            (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
        let mut text_regexes = Vec::new();
        for layer in transformer_layers {
            text_regexes.extend(vec![
                // Attention text
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
                ))?,
                // MLP text
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
                ))?,
            ]);
        }
        let vision_regexes = vec![
            // Vision attention (transformer)
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
            )?,
            // Vision attention (global transformer)
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
            )?,
            // MLP vision
            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
        ];

        Ok([text_regexes, vision_regexes].concat())
    }
}

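// For example (a sketch; the layer list is hypothetical), with
// cross_attention_layers = [3, 8, 13] and num_hidden_layers = 16, the text
// regexes above are generated for layers 0..16 excluding {3, 8, 13}, so the
// cross-attention blocks are left unquantized.
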
1439impl DeviceMappedModelLoader for VLlamaLoader {
1440    fn mapped_max_act_size_elems(
1441        &self,
1442        config: &str,
1443        params: &AutoDeviceMapParams,
1444        _prompt_chunksize: usize,
1445    ) -> Result<usize> {
1446        let AutoDeviceMapParams::Vision {
1447            max_seq_len,
1448            max_batch_size,
1449            max_image_shape: _,
1450            max_num_images,
1451        } = params
1452        else {
1453            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1454        };
1455
1456        let config: MLlamaConfig = serde_json::from_str(config)?;
1457
1458        let img_seq_len = {
1459            let cfg = &config.vision_config;
1460            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1461            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1462            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1463        };
1464        let img_seq_len = img_seq_len * max_num_images;
1465
1466        let max_cross_text_attn = {
1467            let cfg = &config.text_config;
1468            max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
1469        };
1470
1471        let max_self_text_attn = {
1472            let cfg = &config.text_config;
1473            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1474        };
1475
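        // The mapped activation peak is whichever attention matrix is larger:
        // text self-attention (max_seq_len^2) or cross-attention over image
        // tokens (img_seq_len^2, used here as an upper bound).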
1476        Ok(max_self_text_attn.max(max_cross_text_attn))
1477    }
1478
1479    fn non_mapped_max_act_size_elems(
1480        &self,
1481        config: &str,
1482        params: &AutoDeviceMapParams,
1483    ) -> Result<usize> {
1484        let AutoDeviceMapParams::Vision {
1485            max_seq_len: _,
1486            max_batch_size,
1487            max_image_shape: _,
1488            max_num_images,
1489        } = params
1490        else {
1491            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1492        };
1493
1494        let config: MLlamaConfig = serde_json::from_str(config)?;
1495
1496        let img_seq_len = {
1497            let cfg = &config.vision_config;
1498            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1499            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1500            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1501        };
1502        let max_vision_attn = {
1503            let cfg = &config.vision_config;
1504            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
1505        };
1506
1507        Ok(max_vision_attn)
1508    }
1509
1510    fn non_mapped_size_in_bytes(
1511        &self,
1512        config: &str,
1513        dtype: DType,
1514        weight_pack_factor: usize,
1515    ) -> Result<usize> {
1516        let config: MLlamaConfig = serde_json::from_str(config)?;
1517        let text_elems = {
1518            let cfg = &config.text_config;
1519            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
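            // With tied word embeddings the LM head reuses embed_tokens, so it
            // adds no extra storage.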
1520            let lm_head = if !cfg.tie_word_embeddings {
1521                cfg.hidden_size * cfg.vocab_size
1522            } else {
1523                0
1524            };
1525            let norm = cfg.hidden_size;
1526            embed_tokens + lm_head + norm
1527        };
1528
1529        let vision_elems = {
1530            let cfg = &config.vision_config;
1531
1532            let conv_cfg = Conv2dConfig {
1533                stride: cfg.patch_size,
1534                ..Default::default()
1535            };
1536            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
1537                * cfg.patch_size
1538                * cfg.patch_size;
1539
1540            let class_embedding = cfg.hidden_size;
1541
1542            let gated_positional_embedding = {
1543                let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1544                let embedding = num_patches * cfg.hidden_size;
1545                let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
1546                    * (cfg.max_num_tiles * num_patches * cfg.hidden_size);
1547
1548                embedding + tile_embedding
1549            };
1550
1551            let pre_tile_positional_embedding =
1552                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1553            let post_tile_positional_embedding =
1554                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1555
1556            let layernorm_pre = cfg.hidden_size;
1557            let layernorm_post = cfg.hidden_size;
1558
1559            let encoder_layer = {
1560                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1561                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
1562
1563                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1564                let q_proj =
1565                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1566                let k_proj =
1567                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1568                let v_proj =
1569                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1570                let o_proj =
1571                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1572
1573                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
1574                    + cfg.intermediate_size;
1575                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
1576                    + cfg.hidden_size;
1577
1578                input_layernorm
1579                    + post_attention_layernorm
1580                    + q_proj
1581                    + k_proj
1582                    + v_proj
1583                    + o_proj
1584                    + fc1
1585                    + fc2
1586            };
1587
1588            patch_embedding
1589                + class_embedding
1590                + gated_positional_embedding
1591                + pre_tile_positional_embedding
1592                + post_tile_positional_embedding
1593                + layernorm_pre
1594                + layernorm_post
1595                + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
1596        };
1597
1598        let elems = text_elems + vision_elems;
1599        Ok(elems * dtype.size_in_bytes())
1600    }
1601
1602    fn layer_sizes_in_bytes(
1603        &self,
1604        config: &str,
1605        dtype: DType,
1606        weight_pack_factor: usize,
1607    ) -> Result<Vec<usize>> {
1608        let config: MLlamaConfig = serde_json::from_str(config)?;
1609        let cfg = &config.text_config;
1610
1611        let mut layer_sizes = Vec::new();
1612
1613        for i in 0..cfg.num_hidden_layers {
1614            let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
1615                // No isq for cross attention
1616                1
1617            } else {
1618                weight_pack_factor
1619            };
1620
1621            let per_layer_elems = {
1622                let input_layernorm = cfg.hidden_size;
1623                let post_attention_layernorm = cfg.hidden_size;
1624
1625                let size_in = cfg.hidden_size;
1626                let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1627                let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1628                let q_proj = size_in * size_q / weight_pack_factor;
1629                let k_proj = size_in * size_kv / weight_pack_factor;
1630                let v_proj = size_in * size_kv / weight_pack_factor;
1631                let o_proj = size_q * size_in / weight_pack_factor;
1632
1633                let h_size = cfg.hidden_size;
1634                let i_size = cfg.intermediate_size;
1635                let gate_proj = h_size * i_size / weight_pack_factor;
1636                let up_proj = h_size * i_size / weight_pack_factor;
1637                let down_proj = i_size * h_size / weight_pack_factor;
1638
1639                input_layernorm
1640                    + post_attention_layernorm
1641                    + q_proj
1642                    + k_proj
1643                    + v_proj
1644                    + o_proj
1645                    + gate_proj
1646                    + up_proj
1647                    + down_proj
1648            };
1649
1650            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
1651        }
1652
1653        Ok(layer_sizes)
1654    }
1655
1656    fn num_layers(&self, config: &str) -> Result<usize> {
1657        let config: MLlamaConfig = serde_json::from_str(config)?;
1658        Ok(config.text_config.num_hidden_layers)
1659    }
1660
1661    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1662        let cfg: MLlamaConfig = serde_json::from_str(config)?;
1663        let cfg = &cfg.text_config;
1664
1665        let cfg = ModelConfigMetadata {
1666            max_seq_len: cfg.max_position_embeddings,
1667            num_layers: cfg.num_hidden_layers,
1668            hidden_size: cfg.hidden_size,
1669            num_kv_heads: cfg.num_key_value_heads,
1670            num_attn_heads: cfg.num_attention_heads,
1671            sliding_window: None,
1672            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1673            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1674        };
1675
1676        Ok(Box::new(cfg))
1677    }
1678
1679    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1680        Some(vec![NonMappedSubModel::Vision])
1681    }
1682}
1683
1684// ======================== Qwen2VL Loader
1685
1686/// [`VisionLoader`] for a Qwen2-VL model.
1687///
1688/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
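///
/// A minimal sketch of direct use (normally this is driven by the generic
/// vision pipeline builder; `config_json` here is a hypothetical `config.json`
/// string for a Qwen2-VL checkpoint):
/// ```ignore
/// let loader = Qwen2VLLoader;
/// let n_layers = DeviceMappedModelLoader::num_layers(&loader, config_json)?;
/// let isq_regexes = IsqModelLoader::isq_layer_regexes(&loader, config_json)?;
/// ```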
1689pub struct Qwen2VLLoader;
1690
1691pub struct Qwen2VLPrefixer;
1692
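// Prepends one vision placeholder per image. Assuming the usual Qwen2-VL
// special-token strings ("<|vision_start|>", "<|image_pad|>", "<|vision_end|>"),
// prefix_image(0, "Describe this.") yields
// "<|vision_start|><|image_pad|><|vision_end|>Describe this.".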
1693impl VisionPromptPrefixer for Qwen2VLPrefixer {
1694    fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
1695        format!(
1696            "{}{}{}{prompt}",
1697            Qwen2VLProcessor::VISION_START,
1698            Qwen2VLProcessor::IMAGE_PAD,
1699            Qwen2VLProcessor::VISION_END
1700        )
1701    }
1702}
1703
1704impl VisionModelLoader for Qwen2VLLoader {
1705    fn load(
1706        &self,
1707        config: &str,
1708        _use_flash_attn: bool,
1709        vb: ShardedVarBuilder,
1710        normal_loading_metadata: NormalLoadingMetadata,
1711        attention_mechanism: AttentionImplementation,
1712    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1713        let config: Qwen2VLConfig = serde_json::from_str(config)?;
1714        Ok(Box::new(Qwen2VLModel::new(
1715            &config,
1716            vb,
1717            self.is_gptx(),
1718            normal_loading_metadata,
1719            attention_mechanism,
1720        )?))
1721    }
1722    fn is_gptx(&self) -> bool {
1723        true
1724    }
1725    fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
1726        let config: Qwen2VLConfig = serde_json::from_str(config)?;
1727        Ok(Box::new(config))
1728    }
1729    fn get_processor(
1730        &self,
1731        _model_config: &str,
1732        _processor_config: Option<ProcessorConfig>,
1733        _preprocessor_config: PreProcessorConfig,
1734        max_edge: Option<u32>,
1735    ) -> Arc<dyn Processor + Send + Sync> {
1736        Arc::new(Qwen2VLProcessor::new(max_edge))
1737    }
1738    fn supports_paged_attention(&self) -> bool {
1739        false
1740    }
1741    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
1742        Arc::new(Qwen2VLPrefixer)
1743    }
1744}
1745
1746impl IsqModelLoader for Qwen2VLLoader {
1747    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1748        Ok(vec![
1749            Regex::new(r"lm_head\.(weight|bias)$")?,
1750            // Attention
1751            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1752            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1753            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1754            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
1755            // MLP
1756            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1757            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1758            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1759        ])
1760    }
1761}
1762
1763impl DeviceMappedModelLoader for Qwen2VLLoader {
1764    fn mapped_max_act_size_elems(
1765        &self,
1766        config: &str,
1767        params: &AutoDeviceMapParams,
1768        _prompt_chunksize: usize,
1769    ) -> Result<usize> {
1770        let AutoDeviceMapParams::Vision {
1771            max_seq_len,
1772            max_batch_size,
1773            max_image_shape,
1774            max_num_images,
1775        } = params
1776        else {
1777            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1778        };
1779
1780        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
1781
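        // Illustrative grid math (values are hypothetical): with patch_size = 14,
        // temporal_patch_size = 2, max_num_images = 2, and a 1036x1036 max image
        // shape, grid_t = 1 and grid_h = grid_w = 74, giving img_seq_len = 5476
        // before the per-image multiplier below.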
1782        let img_seq_len = {
1783            let cfg = &cfg.vision_config;
1784            let grid_t = max_num_images / cfg.temporal_patch_size;
1785            let grid_h = max_image_shape.0 / cfg.patch_size;
1786            let grid_w = max_image_shape.1 / cfg.patch_size;
1787            grid_t * grid_h * grid_w
1788        };
1789        let img_seq_len = img_seq_len * max_num_images;
1790
1791        let max_text_attn = {
1792            // This model injects the vision information directly into the input embeddings
1793            let max_seq_len = img_seq_len + max_seq_len;
1794            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1795        };
1796
1797        Ok(max_text_attn)
1798    }
1799
1800    fn non_mapped_max_act_size_elems(
1801        &self,
1802        config: &str,
1803        params: &AutoDeviceMapParams,
1804    ) -> Result<usize> {
1805        let AutoDeviceMapParams::Vision {
1806            max_seq_len: _,
1807            max_batch_size,
1808            max_image_shape,
1809            max_num_images,
1810        } = params
1811        else {
1812            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1813        };
1814
1815        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
1816
1817        let img_seq_len = {
1818            let cfg = &cfg.vision_config;
1819            let grid_t = max_num_images / cfg.temporal_patch_size;
1820            let grid_h = max_image_shape.0 / cfg.patch_size;
1821            let grid_w = max_image_shape.1 / cfg.patch_size;
1822            grid_t * grid_h * grid_w
1823        };
1824
1825        let max_vision_attn = {
1826            let cfg = &cfg.vision_config;
1827            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
1828        };
1829
1830        Ok(max_vision_attn)
1831    }
1832
1833    fn non_mapped_size_in_bytes(
1834        &self,
1835        config: &str,
1836        dtype: DType,
1837        weight_pack_factor: usize,
1838    ) -> Result<usize> {
1839        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
1840        let text_elems = {
1841            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1842            let lm_head = if !cfg.tie_word_embeddings {
1843                cfg.hidden_size * cfg.vocab_size
1844            } else {
1845                0
1846            };
1847            let norm = cfg.hidden_size;
1848            embed_tokens + lm_head + norm
1849        };
1850
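        // The patch merger concatenates a spatial_merge_size x spatial_merge_size
        // block of patch embeddings (hence embed_dim * merge_size^2 inputs)
        // before projecting into the text hidden size.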
1851        let patch_merger = {
1852            let cfg = &cfg.vision_config;
1853            let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);
1854
1855            let mlp0 = hidden_size * hidden_size + hidden_size;
1856            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
1857
1858            let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
1859
1860            mlp0 + mlp2 + ln_q
1861        };
1862
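        // Conv3d parameter count: in_channels * out_channels / groups * kD * kH * kW;
        // no bias term is counted for the patch embedding here.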
1863        let patch_embed = {
1864            let cfg = &cfg.vision_config;
1865            let conv_cfg = Conv3dConfig {
1866                stride: cfg.patch_size,
1867                ..Default::default()
1868            };
1869            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
1870            cfg.in_channels * cfg.embed_dim / conv_cfg.groups
1871                * kernel_sizes[0]
1872                * kernel_sizes[1]
1873                * kernel_sizes[2]
1874        };
1875
1876        let encoder_layer = {
1877            let cfg = &cfg.vision_config;
1878            let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
1879            let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
1880
1881            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
1882            let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
1883            let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
1884            let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;
1885
1886            let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
1887            let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;
1888
1889            norm1 + norm2 + fc1 + fc2 + qkv + out
1890        };
1891
1892        let elems =
1893            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
1894
1895        Ok(elems * dtype.size_in_bytes())
1896    }
1897
1898    fn layer_sizes_in_bytes(
1899        &self,
1900        config: &str,
1901        dtype: DType,
1902        weight_pack_factor: usize,
1903    ) -> Result<Vec<usize>> {
1904        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
1905        let per_layer_elems = {
1906            let input_layernorm = cfg.hidden_size;
1907            let post_attention_layernorm = cfg.hidden_size;
1908
1909            let size_in = cfg.hidden_size;
1910            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1911            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1912            let q_proj = size_in * size_q / weight_pack_factor + size_q;
1913            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1914            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1915            let o_proj = size_q * size_in / weight_pack_factor;
1916
1917            let h_size = cfg.hidden_size;
1918            let i_size = cfg.intermediate_size;
1919            let gate_proj = h_size * i_size / weight_pack_factor;
1920            let up_proj = h_size * i_size / weight_pack_factor;
1921            let down_proj = i_size * h_size / weight_pack_factor;
1922
1923            input_layernorm
1924                + post_attention_layernorm
1925                + q_proj
1926                + k_proj
1927                + v_proj
1928                + o_proj
1929                + gate_proj
1930                + up_proj
1931                + down_proj
1932        };
1933        Ok(vec![
1934            per_layer_elems * dtype.size_in_bytes();
1935            cfg.num_hidden_layers
1936        ])
1937    }
1938
1939    fn num_layers(&self, config: &str) -> Result<usize> {
1940        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
1941        Ok(cfg.num_hidden_layers)
1942    }
1943
1944    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1945        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
1946
1947        let cfg = ModelConfigMetadata {
1948            max_seq_len: cfg.max_position_embeddings,
1949            num_layers: cfg.num_hidden_layers,
1950            hidden_size: cfg.hidden_size,
1951            num_kv_heads: cfg.num_key_value_heads,
1952            num_attn_heads: cfg.num_attention_heads,
1953            sliding_window: cfg.sliding_window,
1954            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1955            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1956        };
1957
1958        Ok(Box::new(cfg))
1959    }
1960
1961    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1962        Some(vec![NonMappedSubModel::Vision])
1963    }
1964}
1965
1966// ======================== Idefics 3 loader
1967
1968/// [`VisionLoader`] for an Idefics 3 Vision model.
1969///
1970/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
1971pub struct Idefics3Loader;
1972
1973pub struct Idefics3Prefixer;
1974
1975impl VisionPromptPrefixer for Idefics3Prefixer {
1976    fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
1977        // The chat template inserts the image tokens, so no prefix is needed here.
1978        prompt.to_string()
1979    }
1980}
1981
1982impl VisionModelLoader for Idefics3Loader {
1983    fn load(
1984        &self,
1985        config: &str,
1986        use_flash_attn: bool,
1987        vb: ShardedVarBuilder,
1988        normal_loading_metadata: NormalLoadingMetadata,
1989        attention_mechanism: AttentionImplementation,
1990    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1991        let mut config: Idefics3Config = serde_json::from_str(config)?;
1992        config.text_config.use_flash_attn = use_flash_attn;
1993        Ok(Box::new(Idefics3Model::new(
1994            &config,
1995            vb,
1996            self.is_gptx(),
1997            normal_loading_metadata,
1998            attention_mechanism,
1999        )?))
2000    }
2001    fn is_gptx(&self) -> bool {
2002        true
2003    }
2004    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
2005        let mut config: Idefics3Config = serde_json::from_str(config)?;
2006        config.text_config.use_flash_attn = use_flash_attn;
2007        Ok(Box::new(config))
2008    }
2009    fn get_processor(
2010        &self,
2011        _model_config: &str,
2012        processor_config: Option<ProcessorConfig>,
2013        preprocessor_config: PreProcessorConfig,
2014        max_edge: Option<u32>,
2015    ) -> Arc<dyn Processor + Send + Sync> {
2016        Arc::new(Idefics3Processor::new(
2017            processor_config.unwrap_or_default(),
2018            preprocessor_config,
2019            max_edge,
2020        ))
2021    }
2022    fn supports_paged_attention(&self) -> bool {
2023        true
2024    }
2025    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
2026        Arc::new(Idefics3Prefixer)
2027    }
2028}
2029
2030impl IsqModelLoader for Idefics3Loader {
2031    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2032        Ok(vec![
2033            Regex::new(r"lm_head\.(weight|bias)$")?,
2034            // Attention
2035            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2036            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2037            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2038            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2039            // MLP
2040            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2041            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2042            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2043        ])
2044    }
2045}
2046
2047impl DeviceMappedModelLoader for Idefics3Loader {
2048    fn mapped_max_act_size_elems(
2049        &self,
2050        config: &str,
2051        params: &AutoDeviceMapParams,
2052        _prompt_chunksize: usize,
2053    ) -> Result<usize> {
2054        let AutoDeviceMapParams::Vision {
2055            max_seq_len,
2056            max_batch_size,
2057            max_image_shape: _,
2058            max_num_images,
2059        } = params
2060        else {
2061            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2062        };
2063
2064        let cfg: Idefics3Config = serde_json::from_str(config)?;
2065
2066        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2067        let img_seq_len = (num_patches + 1) * max_num_images;
2068
2069        let max_text_attn = {
2070            // This model injects the vision information directly into the input embeddings
2071            let max_seq_len = img_seq_len + max_seq_len;
2072            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2073        };
2074
2075        Ok(max_text_attn)
2076    }
2077
2078    fn non_mapped_max_act_size_elems(
2079        &self,
2080        config: &str,
2081        params: &AutoDeviceMapParams,
2082    ) -> Result<usize> {
2083        let AutoDeviceMapParams::Vision {
2084            max_seq_len: _,
2085            max_batch_size,
2086            max_image_shape: _,
2087            max_num_images,
2088        } = params
2089        else {
2090            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2091        };
2092
2093        let cfg: Idefics3Config = serde_json::from_str(config)?;
2094
2095        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2096        let img_seq_len = num_patches + 1;
2097
2098        let max_vision_attn = {
2099            // do_image_splitting = true
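            // Heuristic: splitting produces sub-image crops plus the original
            // global view, so each image is counted roughly five times.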
2100            let images_factor = 5;
2101
2102            (max_batch_size * images_factor * max_num_images)
2103                * cfg.vision_config.num_attention_heads
2104                * img_seq_len
2105                * img_seq_len
2106        };
2107
2108        Ok(max_vision_attn)
2109    }
2110
2111    fn non_mapped_size_in_bytes(
2112        &self,
2113        config: &str,
2114        dtype: DType,
2115        weight_pack_factor: usize,
2116    ) -> Result<usize> {
2117        let cfg: Idefics3Config = serde_json::from_str(config)?;
2118        let text_elems = {
2119            let cfg = &cfg.text_config;
2120
2121            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2122            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2123            let norm = cfg.hidden_size;
2124            embed_tokens + lm_head + norm
2125        };
2126
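        // The connector is a single bias-free projection from the pixel-shuffled
        // vision features (hidden_size * scale_factor^2) into the text hidden size.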
2127        let connector_elems = {
2128            let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
2129            let out_dim = cfg.text_config.hidden_size;
2130
2131            in_dim * out_dim
2132        };
2133
2134        let vision_transformer = {
2135            let cfg = &cfg.vision_config;
2136
2137            let post_layernorm = cfg.hidden_size;
2138
2139            let conv_config = Conv2dConfig {
2140                stride: cfg.patch_size,
2141                ..Default::default()
2142            };
2143            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2144                * cfg.patch_size
2145                * cfg.patch_size;
2146
2147            let num_patches_per_side = cfg.image_size / cfg.patch_size;
2148            let num_patches = num_patches_per_side.pow(2);
2149            let position_embedding = num_patches * cfg.hidden_size;
2150
2151            let layer_elems = {
2152                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2153                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2154
2155                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2156                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2157
2158                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2159                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2160                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2161                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2162
2163                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2164            };
2165
2166            post_layernorm
2167                + patch_embedding
2168                + position_embedding
2169                + layer_elems * cfg.num_hidden_layers
2170        };
2171
2172        let elems = text_elems + connector_elems + vision_transformer;
2173
2174        Ok(elems * dtype.size_in_bytes())
2175    }
2176
2177    fn layer_sizes_in_bytes(
2178        &self,
2179        config: &str,
2180        dtype: DType,
2181        weight_pack_factor: usize,
2182    ) -> Result<Vec<usize>> {
2183        let cfg: Idefics3Config = serde_json::from_str(config)?;
2184        let cfg = cfg.text_config;
2185        let per_layer_elems = {
2186            let input_layernorm = cfg.hidden_size;
2187            let post_attention_layernorm = cfg.hidden_size;
2188
2189            let size_in = cfg.hidden_size;
2190            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2191            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2192            let q_proj = size_in * size_q / weight_pack_factor;
2193            let k_proj = size_in * size_kv / weight_pack_factor;
2194            let v_proj = size_in * size_kv / weight_pack_factor;
2195            let o_proj = size_q * size_in / weight_pack_factor;
2196
2197            let h_size = cfg.hidden_size;
2198            let i_size = cfg.intermediate_size;
2199            let gate_proj = h_size * i_size / weight_pack_factor;
2200            let up_proj = h_size * i_size / weight_pack_factor;
2201            let down_proj = i_size * h_size / weight_pack_factor;
2202
2203            input_layernorm
2204                + post_attention_layernorm
2205                + q_proj
2206                + k_proj
2207                + v_proj
2208                + o_proj
2209                + gate_proj
2210                + up_proj
2211                + down_proj
2212        };
2213        Ok(vec![
2214            per_layer_elems * dtype.size_in_bytes();
2215            cfg.num_hidden_layers
2216        ])
2217    }
2218
2219    fn num_layers(&self, config: &str) -> Result<usize> {
2220        let cfg: Idefics3Config = serde_json::from_str(config)?;
2221        Ok(cfg.text_config.num_hidden_layers)
2222    }
2223    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2224        let cfg: Idefics3Config = serde_json::from_str(config)?;
2225        let cfg = &cfg.text_config;
2226
2227        let cfg = ModelConfigMetadata {
2228            max_seq_len: cfg.max_position_embeddings,
2229            num_layers: cfg.num_hidden_layers,
2230            hidden_size: cfg.hidden_size,
2231            num_kv_heads: cfg.num_key_value_heads,
2232            num_attn_heads: cfg.num_attention_heads,
2233            sliding_window: None,
2234            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2235            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2236        };
2237
2238        Ok(Box::new(cfg))
2239    }
2240
2241    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2242        Some(vec![NonMappedSubModel::Vision])
2243    }
2244}
2245
2246// ======================== MiniCpm-O loader
2247
2248/// [`VisionLoader`] for a MiniCpm-O model.
2249///
2250/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2251pub struct MiniCpmOLoader;
2252
2253pub struct MiniCpmOPrefixer;
2254
2255impl VisionPromptPrefixer for MiniCpmOPrefixer {
2256    fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
2257        format!("(<image>./</image>){prompt}")
2258    }
2259}
2260
2261impl VisionModelLoader for MiniCpmOLoader {
2262    fn load(
2263        &self,
2264        config: &str,
2265        use_flash_attn: bool,
2266        vb: ShardedVarBuilder,
2267        normal_loading_metadata: NormalLoadingMetadata,
2268        attention_mechanism: AttentionImplementation,
2269    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2270        let mut config: MiniCpmOConfig = serde_json::from_str(config)?;
2271        config.text_config.use_flash_attn = use_flash_attn;
2272        Ok(Box::new(MiniCpmOModel::new(
2273            &config,
2274            vb,
2275            self.is_gptx(),
2276            normal_loading_metadata,
2277            attention_mechanism,
2278        )?))
2279    }
2280    fn is_gptx(&self) -> bool {
2281        true
2282    }
2283    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
2284        let mut config: MiniCpmOConfig = serde_json::from_str(config)?;
2285        config.text_config.use_flash_attn = use_flash_attn;
2286        Ok(Box::new(config))
2287    }
2288    fn get_processor(
2289        &self,
2290        _model_config: &str,
2291        processor_config: Option<ProcessorConfig>,
2292        preprocessor_config: PreProcessorConfig,
2293        max_edge: Option<u32>,
2294    ) -> Arc<dyn Processor + Send + Sync> {
2295        Arc::new(MiniCpmOProcessor::new(
2296            processor_config.unwrap_or_default(),
2297            preprocessor_config,
2298            max_edge,
2299        ))
2300    }
2301    fn supports_paged_attention(&self) -> bool {
2302        true
2303    }
2304    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
2305        Arc::new(MiniCpmOPrefixer)
2306    }
2307}
2308
2309impl IsqModelLoader for MiniCpmOLoader {
2310    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2311        Ok(vec![
2312            Regex::new(r"llm.lm_head\.(weight|bias)$")?,
2313            // Attention
2314            Regex::new(r"llm.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2315            Regex::new(r"llm.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2316            Regex::new(r"llm.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2317            Regex::new(r"llm.layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
2318            // MLP
2319            Regex::new(r"llm.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2320            Regex::new(r"llm.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2321            Regex::new(r"llm.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2322        ])
2323    }
2324}
2325
2326impl DeviceMappedModelLoader for MiniCpmOLoader {
2327    fn mapped_max_act_size_elems(
2328        &self,
2329        config: &str,
2330        params: &AutoDeviceMapParams,
2331        _prompt_chunksize: usize,
2332    ) -> Result<usize> {
2333        let AutoDeviceMapParams::Vision {
2334            max_seq_len,
2335            max_batch_size,
2336            max_image_shape: _,
2337            max_num_images,
2338        } = params
2339        else {
2340            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2341        };
2342
2343        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2344
2345        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2346        let img_seq_len = (num_patches + 1) * max_num_images;
2347
2348        let max_text_attn = {
2349            // This model injects the vision information directly into the input embeddings
2350            let max_seq_len = img_seq_len + max_seq_len;
2351            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2352        };
2353
2354        Ok(max_text_attn)
2355    }
2356
2357    fn non_mapped_max_act_size_elems(
2358        &self,
2359        config: &str,
2360        params: &AutoDeviceMapParams,
2361    ) -> Result<usize> {
2362        let AutoDeviceMapParams::Vision {
2363            max_seq_len: _,
2364            max_batch_size,
2365            max_image_shape: _,
2366            max_num_images,
2367        } = params
2368        else {
2369            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2370        };
2371
2372        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2373
2374        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2375        let img_seq_len = num_patches + 1;
2376
2377        let max_vision_attn = {
2378            // do_image_splitting = true
2379            let images_factor = 5;
2380
2381            (max_batch_size * images_factor * max_num_images)
2382                * cfg.vision_config.num_attention_heads
2383                * img_seq_len
2384                * img_seq_len
2385        };
2386
2387        Ok(max_vision_attn)
2388    }
2389
2390    fn non_mapped_size_in_bytes(
2391        &self,
2392        config: &str,
2393        dtype: DType,
2394        weight_pack_factor: usize,
2395    ) -> Result<usize> {
2396        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2397        let text_elems = {
2398            let cfg = &cfg.text_config;
2399
2400            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2401            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2402            let norm = cfg.hidden_size;
2403            embed_tokens + lm_head + norm
2404        };
2405
2406        let vision_transformer = {
2407            let cfg = &cfg.vision_config;
2408
2409            let post_layernorm = cfg.hidden_size;
2410
2411            let conv_config = Conv2dConfig {
2412                stride: cfg.patch_size,
2413                ..Default::default()
2414            };
2415            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2416                * cfg.patch_size
2417                * cfg.patch_size;
2418
2419            let num_patches_per_side = cfg.image_size / cfg.patch_size;
2420            let num_patches = num_patches_per_side.pow(2);
2421            let position_embedding = num_patches * cfg.hidden_size;
2422
2423            let layer_elems = {
2424                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2425                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2426
2427                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2428                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2429
2430                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2431                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2432                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2433                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2434
2435                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2436            };
2437
2438            post_layernorm
2439                + patch_embedding
2440                + position_embedding
2441                + layer_elems * cfg.num_hidden_layers
2442        };
2443
2444        let elems = text_elems + vision_transformer;
2445
2446        Ok(elems * dtype.size_in_bytes())
2447    }
2448
2449    fn layer_sizes_in_bytes(
2450        &self,
2451        config: &str,
2452        dtype: DType,
2453        weight_pack_factor: usize,
2454    ) -> Result<Vec<usize>> {
2455        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2456        let cfg = cfg.text_config;
2457        let per_layer_elems = {
2458            let input_layernorm = cfg.hidden_size;
2459            let post_attention_layernorm = cfg.hidden_size;
2460
2461            let size_in = cfg.hidden_size;
2462            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2463            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2464            let q_proj = size_in * size_q / weight_pack_factor;
2465            let k_proj = size_in * size_kv / weight_pack_factor;
2466            let v_proj = size_in * size_kv / weight_pack_factor;
2467            let o_proj = size_q * size_in / weight_pack_factor;
2468
2469            let h_size = cfg.hidden_size;
2470            let i_size = cfg.intermediate_size;
2471            let gate_proj = h_size * i_size / weight_pack_factor;
2472            let up_proj = h_size * i_size / weight_pack_factor;
2473            let down_proj = i_size * h_size / weight_pack_factor;
2474
2475            input_layernorm
2476                + post_attention_layernorm
2477                + q_proj
2478                + k_proj
2479                + v_proj
2480                + o_proj
2481                + gate_proj
2482                + up_proj
2483                + down_proj
2484        };
2485        Ok(vec![
2486            per_layer_elems * dtype.size_in_bytes();
2487            cfg.num_hidden_layers
2488        ])
2489    }
2490
2491    fn num_layers(&self, config: &str) -> Result<usize> {
2492        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2493        Ok(cfg.text_config.num_hidden_layers)
2494    }
2495    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2496        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2497        let cfg = &cfg.text_config;
2498
2499        let cfg = ModelConfigMetadata {
2500            max_seq_len: cfg.max_position_embeddings,
2501            num_layers: cfg.num_hidden_layers,
2502            hidden_size: cfg.hidden_size,
2503            num_kv_heads: cfg.num_key_value_heads,
2504            num_attn_heads: cfg.num_attention_heads,
2505            sliding_window: None,
2506            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2507            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2508        };
2509
2510        Ok(Box::new(cfg))
2511    }
2512}
2513
2514// ======================== Phi 4MM loader
2515
2516/// [`VisionLoader`] for a Phi 4MM Vision model.
2517///
2518/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2519pub struct Phi4MMLoader;
2520
2521pub struct Phi4MMPrefixer;
2522
2523impl VisionPromptPrefixer for Phi4MMPrefixer {
2524    fn prefix_image(&self, image_index: usize, prompt: &str) -> String {
2525        // The incoming image_index is 0-based, but Phi-4 image tags are 1-based.
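        // e.g. prefix_image(0, "What is shown?") -> "<|image_1|>What is shown?".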
2526        format!("<|image_{}|>{prompt}", image_index + 1)
2527    }
2528}
2529
2530impl VisionModelLoader for Phi4MMLoader {
2531    fn load(
2532        &self,
2533        config: &str,
2534        use_flash_attn: bool,
2535        vb: ShardedVarBuilder,
2536        normal_loading_metadata: NormalLoadingMetadata,
2537        attention_mechanism: AttentionImplementation,
2538    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2539        let mut config: Phi4MMConfig = serde_json::from_str(config)?;
2540        config.use_flash_attn = use_flash_attn;
2541        Ok(Box::new(Phi4MMModel::new(
2542            &config,
2543            vb,
2544            self.is_gptx(),
2545            normal_loading_metadata,
2546            attention_mechanism,
2547        )?))
2548    }
2549    fn is_gptx(&self) -> bool {
2550        true
2551    }
2552    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
2553        let mut config: Phi4MMConfig = serde_json::from_str(config)?;
2554        config.use_flash_attn = use_flash_attn;
2555        Ok(Box::new(config))
2556    }
2557    fn get_processor(
2558        &self,
2559        _model_config: &str,
2560        processor_config: Option<ProcessorConfig>,
2561        preprocessor_config: PreProcessorConfig,
2562        _max_edge: Option<u32>,
2563    ) -> Arc<dyn Processor + Send + Sync> {
2564        Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
2565    }
2566    fn supports_paged_attention(&self) -> bool {
2567        true
2568    }
2569    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
2570        Arc::new(Phi4MMPrefixer)
2571    }
2572}
2573
2574impl IsqModelLoader for Phi4MMLoader {
2575    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2576        Ok(vec![
2577            Regex::new(r"lm_head\.(weight|bias)$")?,
2578            // Attention
2579            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
2580            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2581            // MLP
2582            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
2583            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2584        ])
2585    }
2586}
2587
2588impl DeviceMappedModelLoader for Phi4MMLoader {
2589    fn mapped_max_act_size_elems(
2590        &self,
2591        config: &str,
2592        params: &AutoDeviceMapParams,
2593        _prompt_chunksize: usize,
2594    ) -> Result<usize> {
2595        // NOTE: max_num_images is used as a multiplier below even though Phi-4MM
2596        // prompts typically contain a single image; max_image_shape is ignored.
2596        let AutoDeviceMapParams::Vision {
2597            max_seq_len,
2598            max_batch_size,
2599            max_image_shape: _,
2600            max_num_images,
2601        } = params
2602        else {
2603            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2604        };
2605
2606        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
2607
2608        let vcfg = &PHI4_MM_VISION_CFG;
2609
2610        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
2611        let img_seq_len = (num_patches + 1) * max_num_images;
2612
2613        let max_text_attn = {
2614            // This model injects the vision information directly into the input embeddings
2615            let max_seq_len = img_seq_len + max_seq_len;
2616            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
2617        };
2618
2619        Ok(max_text_attn)
2620    }
2621
2622    fn non_mapped_max_act_size_elems(
2623        &self,
2624        _config: &str,
2625        params: &AutoDeviceMapParams,
2626    ) -> Result<usize> {
2627        let AutoDeviceMapParams::Vision {
2628            max_seq_len: _,
2629            max_batch_size,
2630            max_image_shape,
2631            max_num_images,
2632        } = params
2633        else {
2634            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2635        };
2636
2637        let vcfg = &PHI4_MM_VISION_CFG;
2638
2639        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
2640        let img_seq_len = num_patches + 1;
2641
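        // Dynamic HD transform: the image is split into base-resolution crops
        // along each axis, plus one global view; each crop is a separate batch
        // element for the vision tower.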
2642        let max_batch_size = max_batch_size
2643            * (max_image_shape
2644                .0
2645                .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
2646                * max_image_shape
2647                    .1
2648                    .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
2649                + 1);
2650
2651        let max_vision_attn = (max_batch_size * max_num_images)
2652            * vcfg.num_attention_heads
2653            * img_seq_len
2654            * img_seq_len;
2655        let max_qkv = 3
2656            * (max_batch_size
2657                * vcfg.num_attention_heads
2658                * img_seq_len
2659                * (vcfg.hidden_size / vcfg.num_attention_heads));
2660
2661        Ok(max_vision_attn + max_qkv)
2662    }
2663
2664    fn non_mapped_size_in_bytes(
2665        &self,
2666        config: &str,
2667        dtype: DType,
2668        weight_pack_factor: usize,
2669    ) -> Result<usize> {
2670        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
2671        let elems = {
2672            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2673            let lm_head = if !cfg.tie_word_embeddings {
2674                cfg.hidden_size * cfg.vocab_size
2675            } else {
2676                0
2677            };
2678            let norm = cfg.hidden_size;
2679
2680            let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
2681                let projection_cls = img_embed
2682                    .projection_cls
2683                    .clone()
2684                    .unwrap_or("linear".to_string());
2685                let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
2686                let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
2687                let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;
2688
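                // With use_hd_transform, 2x2 patch merging concatenates four
                // vision features per position, so the projector input width
                // is image_dim_out * 4.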
2689                let proj = match (projection_cls.as_str(), use_hd_transform) {
2690                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
2691                    ("mlp", true) => {
2692                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
2693                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2694                        a + b
2695                    }
2696                    ("mlp", false) => {
2697                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
2698                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2699                        a + b
2700                    }
2701                    _ => {
2702                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
2703                    }
2704                };
2705
2706                let (glb_gn, sub_gn) = if with_learnable_separator {
2707                    let glb_gn = image_dim_out * 4;
2708                    let sub_gn = image_dim_out * 4;
2709                    (glb_gn, sub_gn)
2710                } else {
2711                    (0, 0)
2712                };
2713
2714                let vision_transformer = {
2715                    let cfg = &PHI4_MM_VISION_CFG;
2716
2717                    let post_layernorm = cfg.hidden_size;
2718
2719                    let conv_config = Conv2dConfig {
2720                        stride: cfg.patch_size,
2721                        ..Default::default()
2722                    };
2723                    let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2724                        * cfg.patch_size
2725                        * cfg.patch_size;
2726
2727                    let num_patches_per_side = cfg.image_size / cfg.patch_size;
2728                    let num_patches = num_patches_per_side.pow(2);
2729                    let position_embedding = num_patches * cfg.hidden_size;
2730
2731                    let layer_elems = {
2732                        let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2733                        let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2734
2735                        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2736                        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2737
2738                        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2739                        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2740                        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2741                        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2742
2743                        layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2744                    };
2745
2746                    post_layernorm
2747                        + patch_embedding
2748                        + position_embedding
2749                        + layer_elems * cfg.num_hidden_layers
2750                };
2751
2752                proj + glb_gn + sub_gn + vision_transformer
2753            } else {
2754                0
2755            };
2756
2757            embed_tokens + lm_head + norm + image_embed
2758        };
2759
2760        Ok(elems * dtype.size_in_bytes())
2761    }
2762
2763    fn layer_sizes_in_bytes(
2764        &self,
2765        config: &str,
2766        dtype: DType,
2767        weight_pack_factor: usize,
2768    ) -> Result<Vec<usize>> {
2769        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
2770        let per_layer_elems = {
2771            let input_layernorm = cfg.hidden_size;
2772            let post_attention_layernorm = cfg.hidden_size;
2773
2774            let size_in = cfg.hidden_size;
2775            let head_dim = cfg.head_dim();
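            // Fused QKV projection: output width is
            // (num_attention_heads + 2 * num_key_value_heads) * head_dim.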
2776            let op_size =
2777                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
2777            let qkv_proj = size_in * op_size / weight_pack_factor;
2778            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;
2779
2780            let h_size = cfg.hidden_size;
2781            let i_size = cfg.intermediate_size;
2782            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
2783            let down_proj = h_size * i_size / weight_pack_factor;
2784
2785            input_layernorm
2786                + post_attention_layernorm
2787                + qkv_proj
2788                + o_proj
2789                + gate_up_proj
2790                + down_proj
2791        };
2792        Ok(vec![
2793            per_layer_elems * dtype.size_in_bytes();
2794            cfg.num_hidden_layers
2795        ])
2796    }
2797
2798    fn num_layers(&self, config: &str) -> Result<usize> {
2799        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
2800        Ok(cfg.num_hidden_layers)
2801    }
2802
2803    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2804        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
2805
2806        let cfg = ModelConfigMetadata {
2807            max_seq_len: cfg.max_position_embeddings,
2808            num_layers: cfg.num_hidden_layers,
2809            hidden_size: cfg.hidden_size,
2810            num_kv_heads: cfg.num_key_value_heads(),
2811            num_attn_heads: cfg.num_attention_heads,
2812            sliding_window: cfg.sliding_window,
2813            k_head_dim: cfg.head_dim(),
2814            v_head_dim: cfg.head_dim(),
2815        };
2816
2817        Ok(Box::new(cfg))
2818    }
2819
2820    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2821        Some(vec![NonMappedSubModel::Vision])
2822    }
2823}
2824
2825// ======================== Qwen2_5VL Loader
2826
2827/// [`VisionLoader`] for a Qwen2_5VL model.
2828///
2829/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
2830pub struct Qwen2_5VLLoader;
2831
2832pub struct Qwen2_5VLPrefixer;
2833
2834impl VisionPromptPrefixer for Qwen2_5VLPrefixer {
2835    fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
2836        format!(
2837            "{}{}{}{prompt}",
2838            Qwen2_5VLProcessor::VISION_START,
2839            Qwen2_5VLProcessor::IMAGE_PAD,
2840            Qwen2_5VLProcessor::VISION_END
2841        )
2842    }
2843}
2844
2845impl VisionModelLoader for Qwen2_5VLLoader {
2846    fn load(
2847        &self,
2848        config: &str,
2849        _use_flash_attn: bool,
2850        vb: ShardedVarBuilder,
2851        normal_loading_metadata: NormalLoadingMetadata,
2852        attention_mechanism: AttentionImplementation,
2853    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2854        let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
2855        Ok(Box::new(Qwen2_5VLModel::new(
2856            &config,
2857            vb,
2858            self.is_gptx(),
2859            normal_loading_metadata,
2860            attention_mechanism,
2861        )?))
2862    }
2863    fn is_gptx(&self) -> bool {
2864        true
2865    }
2866    fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
2867        let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
2868        Ok(Box::new(config))
2869    }
2870    fn get_processor(
2871        &self,
2872        _model_config: &str,
2873        _processor_config: Option<ProcessorConfig>,
2874        _preprocessor_config: PreProcessorConfig,
2875        max_edge: Option<u32>,
2876    ) -> Arc<dyn Processor + Send + Sync> {
2877        Arc::new(Qwen2_5VLProcessor::new(max_edge))
2878    }
2879    fn supports_paged_attention(&self) -> bool {
2880        false
2881    }
2882    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
2883        Arc::new(Qwen2_5VLPrefixer)
2884    }
2885}
2886
2887impl IsqModelLoader for Qwen2_5VLLoader {
2888    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2889        Ok(vec![
2890            Regex::new(r"lm_head\.(weight|bias)$")?,
2891            // Attention
2892            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2893            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2894            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2895            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Qwen2_5VLLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };
        let img_seq_len = img_seq_len * max_num_images;
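        // Worked example (illustrative numbers, not from a real config): with
        // patch_size = 14, temporal_patch_size = 2, a 1008x1008 max image shape, and
        // max_num_images = 2, grid_t = 2 / 2 = 1 and grid_h = grid_w = 1008 / 14 = 72,
        // giving 72 * 72 = 5184 patches, times max_num_images = 10368. Since grid_t
        // already counted the images, multiplying by max_num_images again makes this a
        // conservative upper bound rather than an exact count.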

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        let text_elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
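            // The merger concatenates a spatial_merge_size x spatial_merge_size block of
            // patch embeddings into one vector, so the first MLP layer sees
            // hidden_size * spatial_merge_size^2 input features.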

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
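            // Conv3d weight element count: out_channels * (in_channels / groups) * kT * kH * kW;
            // the temporal kernel dimension comes first in `kernel_sizes`.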
            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;

            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
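            // Qwen2-style attention carries biases on the q/k/v projections but not on
            // the output projection, hence the trailing `+ size_*` terms below.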
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Gemma 3 Loader

/// [`VisionLoader`] for a Gemma 3 model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Gemma3Loader;

pub struct Gemma3Prefixer;

impl VisionPromptPrefixer for Gemma3Prefixer {
    fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
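        // The prompt is returned unchanged: image-token handling for Gemma 3 is left to
        // `Gemma3Processor` (constructed in `get_processor` below), not to the prefixer.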
        prompt.to_string()
    }
}

impl VisionModelLoader for Gemma3Loader {
    fn load(
        &self,
        config: &str,
        _use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let config: Gemma3Config = serde_json::from_str(config)?;
        Ok(Box::new(Gemma3Model::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let config: Gemma3Config = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        let config: Gemma3Config = serde_json::from_str(config).unwrap();
        // Text-only Gemma 3 variants (e.g. the 1b model) have no vision tower, so tell
        // the processor which case this is.
        Arc::new(Gemma3Processor::new(
            processor_config.unwrap_or_default(),
            matches!(config, Gemma3Config::WithVision { .. }),
        ))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Gemma3Prefixer)
    }
}

impl IsqModelLoader for Gemma3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Gemma3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3Config = serde_json::from_str(config)?;

        match cfg {
            Gemma3Config::Text(text_config) => Ok(max_batch_size
                * text_config.num_attention_heads
                * prompt_chunksize
                * prompt_chunksize),
            Gemma3Config::WithVision {
                text_config,
                vision_config,
                ..
            } => {
                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
                let img_seq_len = (num_patches + 1) * max_num_images;
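                // Worked example (typical Gemma 3 SigLIP settings, shown for
                // illustration only): image_size = 896 and patch_size = 14 give
                // (896 / 14)^2 = 64^2 = 4096 patches, so 4097 tokens per image
                // counting the one extra slot this bound reserves per image.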

                let max_text_attn = {
                    // This model injects the vision information directly into the input embeddings
                    let max_seq_len = img_seq_len + *max_seq_len;
                    max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
                };
                Ok(max_text_attn)
            }
        }
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3Config = serde_json::from_str(config)?;

        match cfg {
            Gemma3Config::WithVision { vision_config, .. } => {
                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
                let img_seq_len = num_patches + 1;

                let max_vision_attn = {
                    (max_batch_size * max_num_images)
                        * vision_config.num_attention_heads
                        * img_seq_len
                        * img_seq_len
                };

                Ok(max_vision_attn)
            }
            Gemma3Config::Text(_) => Ok(0),
        }
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let text_elems = {
            let cfg = match &cfg {
                Gemma3Config::Text(cfg) => cfg,
                Gemma3Config::WithVision { text_config, .. } => text_config,
            };
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_transformer = if let Gemma3Config::WithVision {
            vision_config: cfg, ..
        } = &cfg
        {
            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;
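            // Conv2d weight element count: out_channels * (in_channels / groups) * kH * kW,
            // with kH = kW = patch_size for a patchifying convolution.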

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        } else {
            0
        };

        let elems = text_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let txt_cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };
        let per_layer_elems = {
            let cfg = txt_cfg;

            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
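            // Gemma sets head_dim explicitly in the config; it need not equal
            // hidden_size / num_attention_heads, so q/k/v/o sizes are derived from it.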
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            txt_cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let txt_cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };

        Ok(txt_cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None, // Use None to be more forgiving: some configs do not set it
            k_head_dim: cfg.head_dim,
            v_head_dim: cfg.head_dim,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

// ======================== Mistral 3 Loader

/// [`VisionLoader`] for a Mistral 3 model.
///
/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
pub struct Mistral3Loader;

pub struct Mistral3Prefixer;

impl VisionPromptPrefixer for Mistral3Prefixer {
    fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Mistral3Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: Mistral3Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(Mistral3Model::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let config: Mistral3Config = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Mistral3Prefixer)
    }
}

impl IsqModelLoader for Mistral3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
impl DeviceMappedModelLoader for Mistral3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let vcfg = &cfg.vision_config;
        let tcfg = &cfg.text_config;

        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: (mut height, mut width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let img_seq_len = {
            // Reshaping algorithm

            // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/preprocessor_config.json#L29
            let (max_height, max_width) = (1540, 1540);
            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
            if ratio > 1. {
                height = (height as f64 / ratio).floor() as usize;
                width = (width as f64 / ratio).floor() as usize;
            }

            let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
            let num_width_tokens = (width - 1) / vcfg.patch_size + 1;

            height = num_height_tokens * vcfg.patch_size;
            width = num_width_tokens * vcfg.patch_size;

            let num_height_tokens = height / vcfg.patch_size;
            let num_width_tokens = width / vcfg.patch_size;

            (num_width_tokens + 1) * num_height_tokens
        };
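        // Worked example (illustrative; a Pixtral-style patch_size = 16 is assumed): a
        // 3080x3080 input has ratio 2.0 and is resized to 1540x1540; then
        // ceil(1540 / 16) = 97 tokens per side, and each row carries one extra
        // (row-break) slot, giving (97 + 1) * 97 = 9506 image tokens.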

        // This model injects the vision information directly into the input embeddings
        let max_seq_len = img_seq_len * max_num_images + *max_seq_len;
        Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.vision_config;

        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: (mut height, mut width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let img_seq_len = {
            // Reshaping algorithm

            // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/preprocessor_config.json#L29
            let (max_height, max_width) = (1540, 1540);
            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
            if ratio > 1. {
                height = (height as f64 / ratio).floor() as usize;
                width = (width as f64 / ratio).floor() as usize;
            }

            let num_height_tokens = (height - 1) / cfg.patch_size + 1;
            let num_width_tokens = (width - 1) / cfg.patch_size + 1;

            height = num_height_tokens * cfg.patch_size;
            width = num_width_tokens * cfg.patch_size;

            let num_height_tokens = height / cfg.patch_size;
            let num_width_tokens = width / cfg.patch_size;

            (num_width_tokens + 1) * num_height_tokens
        };

        Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;

        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &cfg.vision_config;

            let patch_embed = {
                let conv_cfg = Conv2dConfig {
                    stride: cfg.patch_size,
                    ..Default::default()
                };
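                // Conv2d weight element count: out_channels * (in_channels / groups) * kH * kW,
                // with kH = kW = patch_size here.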
                cfg.num_channels * cfg.hidden_size / conv_cfg.groups
                    * cfg.patch_size
                    * cfg.patch_size
            };
            let ln_pre = cfg.hidden_size;
            let vision_layer = {
                let attn_norm = cfg.hidden_size;
                let ffn_norm = cfg.hidden_size;

                let gate = cfg.hidden_size * cfg.intermediate_size;
                let up = cfg.hidden_size * cfg.intermediate_size;
                let down = cfg.hidden_size * cfg.intermediate_size;

                let q = cfg.hidden_size * cfg.hidden_size;
                let k = cfg.hidden_size * cfg.hidden_size;
                let v = cfg.hidden_size * cfg.hidden_size;
                let o = cfg.hidden_size * cfg.hidden_size;

                attn_norm + ffn_norm + gate + up + down + q + k + v + o
            };

            patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_elems;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}