use std::any::Any;
use std::sync::Arc;
use std::{fmt::Debug, str::FromStr};

use anyhow::Result;
use candle_core::{DType, Device, Tensor, D};
use candle_nn::Conv2dConfig;
use image::{ColorType, DynamicImage};
use itertools::Itertools;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::ShardedVarBuilder;

#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};

use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
use crate::amoe::AnyMoeBaseModelMixin;
use crate::device_map::DeviceMapper;
use crate::layers::Conv3dConfig;
use crate::matformer::MatformerSliceConfig;
use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
use crate::pipeline::isq::IsqModelLoader;
use crate::pipeline::loaders::AutoDeviceMapParams;
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::{
    EitherCache, IsqModel, Modalities, MultimodalPromptPrefixer, Processor, ProcessorCreator,
    SupportedModality,
};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::vision_models::clip::ClipConfig;
use crate::vision_models::gemma3::config::Gemma3Config;
use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
use crate::vision_models::gemma3n::config::Gemma3nConfig;
use crate::vision_models::gemma3n::{Gemma3nModel, Gemma3nProcessor};
use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
use crate::vision_models::idefics2_input_processor::Idefics2Processor;
use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
use crate::vision_models::image_processor::ImagePreProcessor;
use crate::vision_models::inputs_processor::Phi4MMProcessor;
use crate::vision_models::llama4::{
    self, Llama4Config, Llama4ImageProcessor, Llama4Model, Llama4Processor,
};
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::llava15::Model as LLaVA;
use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
use crate::vision_models::llava_next::Model as LLaVANext;
use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
use crate::vision_models::phi3_inputs_processor::Phi3Processor;
use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::qwen2_5_vl::{
    Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
};
use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
use crate::vision_models::{minicpmo, phi4};

pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
}

pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> bool;
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_processor(
        &self,
        model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync>;
    fn supports_paged_attention(&self, config: &str) -> bool;
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        false
    }
    fn modalities(&self, config: &str) -> Result<Modalities>;
    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}

#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub enum VisionLoaderType {
    #[serde(rename = "phi3v")]
    Phi3V,
    #[serde(rename = "idefics2")]
    Idefics2,
    #[serde(rename = "llava_next")]
    LLaVANext,
    #[serde(rename = "llava")]
    LLaVA,
    #[serde(rename = "vllama")]
    VLlama,
    #[serde(rename = "qwen2vl")]
    Qwen2VL,
    #[serde(rename = "idefics3")]
    Idefics3,
    #[serde(rename = "minicpmo")]
    MiniCpmO,
    #[serde(rename = "phi4mm")]
    Phi4MM,
    #[serde(rename = "qwen2_5vl")]
    Qwen2_5VL,
    #[serde(rename = "gemma3")]
    Gemma3,
    #[serde(rename = "mistral3")]
    Mistral3,
    #[serde(rename = "llama4")]
    Llama4,
    #[serde(rename = "gemma3n")]
    Gemma3n,
}

impl VisionLoaderType {
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "Phi3VForCausalLM" => Ok(Self::Phi3V),
            "Idefics2ForConditionalGeneration" => Ok(Self::Idefics2),
            "LlavaNextForConditionalGeneration" => Ok(Self::LLaVANext),
            "LlavaForConditionalGeneration" => Ok(Self::LLaVA),
            "MllamaForConditionalGeneration" => Ok(Self::VLlama),
            "Qwen2VLForConditionalGeneration" => Ok(Self::Qwen2VL),
            "Idefics3ForConditionalGeneration" => Ok(Self::Idefics3),
            "MiniCPMO" => Ok(Self::MiniCpmO),
            "Phi4MMForCausalLM" => Ok(Self::Phi4MM),
            "Qwen2_5_VLForConditionalGeneration" => Ok(Self::Qwen2_5VL),
            "Gemma3ForConditionalGeneration" | "Gemma3ForCausalLM" => Ok(Self::Gemma3),
            "Mistral3ForConditionalGeneration" => Ok(Self::Mistral3),
            "Llama4ForConditionalGeneration" => Ok(Self::Llama4),
            "Gemma3nForConditionalGeneration" => Ok(Self::Gemma3n),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers -CausalLM model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for VisionLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "phi3v" => Ok(Self::Phi3V),
            "idefics2" => Ok(Self::Idefics2),
            "llava_next" => Ok(Self::LLaVANext),
            "llava" => Ok(Self::LLaVA),
            "vllama" => Ok(Self::VLlama),
            "qwen2vl" => Ok(Self::Qwen2VL),
            "idefics3" => Ok(Self::Idefics3),
            "minicpmo" => Ok(Self::MiniCpmO),
            "phi4mm" => Ok(Self::Phi4MM),
            "qwen2_5vl" => Ok(Self::Qwen2_5VL),
            "gemma3" => Ok(Self::Gemma3),
            "mistral3" => Ok(Self::Mistral3),
            "llama4" => Ok(Self::Llama4),
            "gemma3n" => Ok(Self::Gemma3n),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`, `llama4`, `gemma3n`.")),
        }
    }
}

impl std::fmt::Display for VisionLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            VisionLoaderType::Phi3V => "phi3v",
            VisionLoaderType::Idefics2 => "idefics2",
            VisionLoaderType::LLaVANext => "llava_next",
            VisionLoaderType::LLaVA => "llava",
            VisionLoaderType::VLlama => "vllama",
            VisionLoaderType::Qwen2VL => "qwen2vl",
            VisionLoaderType::Idefics3 => "idefics3",
            VisionLoaderType::MiniCpmO => "minicpmo",
            VisionLoaderType::Phi4MM => "phi4mm",
            VisionLoaderType::Qwen2_5VL => "qwen2_5vl",
            VisionLoaderType::Gemma3 => "gemma3",
            VisionLoaderType::Mistral3 => "mistral3",
            VisionLoaderType::Llama4 => "llama4",
            VisionLoaderType::Gemma3n => "gemma3n",
        };
        write!(f, "{name}")
    }
}
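
// A minimal round-trip sketch (not part of the original file) showing how the names accepted
// by `FromStr` line up with the names produced by `Display` for `VisionLoaderType`.
#[cfg(test)]
mod vision_loader_type_name_tests {
    use super::*;

    #[test]
    fn parses_and_displays_the_same_name() {
        let tp = VisionLoaderType::from_str("qwen2_5vl").expect("known architecture name");
        assert_eq!(tp, VisionLoaderType::Qwen2_5VL);
        assert_eq!(tp.to_string(), "qwen2_5vl");
        assert!(VisionLoaderType::from_str("not_a_vision_model").is_err());
    }
}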

#[derive(Deserialize)]
struct AutoVisionLoaderConfig {
    architectures: Vec<String>,
}

pub struct AutoVisionLoader;

impl AutoVisionLoader {
    fn get_loader(config: &str) -> Result<Box<dyn VisionModelLoader>> {
        let auto_cfg: AutoVisionLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one architecture in config");
        }

        let name = &auto_cfg.architectures[0];
        let tp = VisionLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        Ok(match tp {
            VisionLoaderType::Phi3V => Box::new(Phi3VLoader),
            VisionLoaderType::Idefics2 => Box::new(Idefics2Loader),
            VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader),
            VisionLoaderType::LLaVA => Box::new(LLaVALoader),
            VisionLoaderType::VLlama => Box::new(VLlamaLoader),
            VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
            VisionLoaderType::Idefics3 => Box::new(Idefics3Loader),
            VisionLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
            VisionLoaderType::Phi4MM => Box::new(Phi4MMLoader),
            VisionLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
            VisionLoaderType::Gemma3 => Box::new(Gemma3Loader),
            VisionLoaderType::Mistral3 => Box::new(Mistral3Loader),
            VisionLoaderType::Llama4 => Box::new(VLlama4Loader),
            VisionLoaderType::Gemma3n => Box::new(Gemma3nLoader),
        })
    }
}
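
// A small sketch (not in the original source) of how `AutoVisionLoader` dispatches on the
// `architectures` field of a Hugging Face `config.json`; only that field is consulted here.
#[cfg(test)]
mod auto_vision_loader_dispatch_tests {
    use super::*;

    #[test]
    fn selects_a_loader_from_the_architectures_field() {
        // A known architecture name resolves to a concrete loader.
        let cfg = r#"{ "architectures": ["Qwen2VLForConditionalGeneration"] }"#;
        assert!(AutoVisionLoader::get_loader(cfg).is_ok());

        // Exactly one architecture entry is required.
        let ambiguous = r#"{ "architectures": ["A", "B"] }"#;
        assert!(AutoVisionLoader::get_loader(ambiguous).is_err());
    }
}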

impl VisionModelLoader for AutoVisionLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }

    fn is_gptx(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader get_loader")
            .is_gptx(config)
    }

    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }

    fn get_processor(
        &self,
        model_config: &str,
        proc_cfg: Option<ProcessorConfig>,
        preproc_cfg: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Self::get_loader(model_config)
            .expect("AutoVisionLoader get_loader")
            .get_processor(model_config, proc_cfg, preproc_cfg, max_edge)
    }

    fn supports_paged_attention(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_paged_attention(config)
    }

    fn modalities(&self, config: &str) -> Result<Modalities> {
        Self::get_loader(config)?.modalities(config)
    }

    fn supports_prefix_cacher(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_prefix_cacher(config)
    }

    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .prefixer(config)
    }

    fn get_device_for_tensor(
        &self,
        config: &str,
        mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        Self::get_loader(config)?.get_device_for_tensor(config, mapper, loading_isq)
    }
}

impl IsqModelLoader for AutoVisionLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
}

impl DeviceMappedModelLoader for AutoVisionLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params, prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}
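
// For example (hypothetical flag name), `bias_if!(cfg.use_bias, cfg.hidden_size)` contributes
// `cfg.hidden_size` elements when the layer carries a bias and `0` otherwise.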

fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
    let pre_layer_norm = cfg.hidden_size;
    let final_layer_norm = cfg.hidden_size;

    let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
    let num_positions = num_patches + 1;

    let class_embedding = cfg.hidden_size;

    let position_ids = num_positions;
    let position_embedding = num_positions * cfg.hidden_size;

    let conv2dconfig = Conv2dConfig {
        stride: cfg.patch_size,
        ..Default::default()
    };
    let patch_embedding =
        cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;

    let encoder_layer_elems = {
        let layer_norm1 = cfg.hidden_size;
        let layer_norm2 = cfg.hidden_size;

        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

        layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
    };

    pre_layer_norm
        + final_layer_norm
        + class_embedding
        + position_ids
        + position_embedding
        + patch_embedding
        + cfg.num_hidden_layers * encoder_layer_elems
}
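
// Summarizing the count above (a restatement of the code, not new logic):
//     2*H            (pre + final LayerNorm)
//   + H              (class embedding)
//   + N + N*H        (position ids + position embedding), N = (image_size / patch_size)^2 + 1
//   + C*H*P^2        (patch Conv2d, groups = 1), P = patch_size, C = num_channels
//   + L * (2*H + 4*(H^2 + H) + 2*H*I + I + H)
// with H = hidden_size, I = intermediate_size, L = num_hidden_layers.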

pub struct Phi3VLoader;

pub struct Phi3VPrefixer;

impl MultimodalPromptPrefixer for Phi3VPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            image_indexes
                .into_iter()
                .map(|image_index| format!("<|image_{}|>", image_index + 1))
                .join("")
        )
    }
}
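
// A tiny behavioural sketch (not from the original source): the Phi-3 vision prefixer inserts
// 1-based `<|image_N|>` tags, one per image index, ahead of the prompt.
#[cfg(test)]
mod phi3v_prefixer_tests {
    use super::*;

    #[test]
    fn prepends_one_based_image_tags() {
        let prefixer = Phi3VPrefixer;
        assert_eq!(
            prefixer.prefix_image(vec![0, 1], "Describe the images."),
            "<|image_1|><|image_2|>Describe the images."
        );
    }
}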

impl VisionModelLoader for Phi3VLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(Phi3::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi3Processor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Phi3VPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Phi3VLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi3VLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = {
                let projection_cls = cfg
                    .embd_layer
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator =
                    cfg.embd_layer.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
                let image_dim_out = cfg.img_processor.image_dim_out;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);

                proj + glb_gn + sub_gn + clip_vit
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

pub struct Idefics2Loader;

pub struct Idefics2Prefixer;

impl MultimodalPromptPrefixer for Idefics2Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(Idefics2::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics2Processor::new(
            processor_config.unwrap(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Idefics2Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Idefics2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Idefics2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let text_elems = {
            let tie_word_embeddings = cfg.tie_word_embeddings;
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let tcfg = &cfg.text_config;
            let vcfg = &cfg.vision_config;
            let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let down_proj = tcfg.intermediate_size * tcfg.hidden_size;

            let perceiver_elems = {
                let tcfg = &cfg.text_config;
                let pcfg = &cfg.perceiver_config;

                let n_latents = pcfg.resampler_n_latents;
                let hidden_size = tcfg.hidden_size;
                let depth = pcfg.resampler_depth;

                let norm = tcfg.hidden_size;
                let latents = n_latents * hidden_size;

                let layer_elems = {
                    let input_latents_norm = hidden_size;
                    let input_context_norm = hidden_size;
                    let post_attn_norm = hidden_size;

                    let num_heads = pcfg.resampler_n_heads;
                    let head_dim = pcfg.resampler_head_dim;
                    let num_key_value_heads = pcfg.num_key_value_heads;

                    let q_proj = hidden_size * num_heads * head_dim;
                    let k_proj = hidden_size * num_key_value_heads * head_dim;
                    let v_proj = hidden_size * num_key_value_heads * head_dim;
                    let o_proj = num_heads * head_dim * hidden_size;

                    let gate_proj = hidden_size * hidden_size * 4;
                    let up_proj = hidden_size * hidden_size * 4;
                    let down_proj = hidden_size * 4 * hidden_size;

                    input_latents_norm
                        + input_context_norm
                        + post_attn_norm
                        + q_proj
                        + k_proj
                        + v_proj
                        + o_proj
                        + gate_proj
                        + up_proj
                        + down_proj
                };

                norm + latents + layer_elems * depth
            };

            gate_proj + up_proj + down_proj + perceiver_elems
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm + patch_embedding + position_embedding + layer_elems
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

pub struct LLaVANextLoader;

pub struct LLaVANextPrefixer;

impl MultimodalPromptPrefixer for LLaVANextPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVANextLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(LLaVANext::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVANextProcessor::new(model_config))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(LLaVANextPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for LLaVANextLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVANextLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            let max_seq_len = img_seq_len + max_seq_len;

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

pub struct LLaVALoader;

pub struct LLaVAPrefixer;

impl MultimodalPromptPrefixer for LLaVAPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVALoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(LLaVA::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVAProcessor::new(model_config))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(LLaVAPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for LLaVALoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVALoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        let img_seq_len =
            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            let max_seq_len = img_seq_len + max_seq_len;

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        let img_seq_len =
            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

pub struct VLlamaLoader;

pub struct VLlamaPrefixer;

impl MultimodalPromptPrefixer for VLlamaPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<|image|>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for VLlamaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
        Ok(Box::new(MLlamaModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(MLlamaProcessor::new())
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        false
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(VLlamaPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for VLlamaLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let cross_attn_layers = &config.text_config.cross_attention_layers;
        let transformer_layers =
            (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
        let mut text_regexes = Vec::new();
        for layer in transformer_layers {
            text_regexes.extend(vec![
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
                ))?,
            ]);
        }
        let vision_regexes = vec![
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
            )?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
        ];

        Ok([text_regexes, vision_regexes].concat())
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}
1767
1768impl DeviceMappedModelLoader for VLlamaLoader {
1769 fn mapped_max_act_size_elems(
1770 &self,
1771 config: &str,
1772 params: &AutoDeviceMapParams,
1773 _prompt_chunksize: usize,
1774 ) -> Result<usize> {
1775 let AutoDeviceMapParams::Vision {
1776 max_seq_len,
1777 max_batch_size,
1778 max_image_shape: _,
1779 max_num_images,
1780 } = params
1781 else {
1782 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1783 };
1784
1785 let config: MLlamaConfig = serde_json::from_str(config)?;
1786
1787 let img_seq_len = {
1788 let cfg = &config.vision_config;
1789 let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1790 let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1791 cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1792 };
1793 let img_seq_len = img_seq_len * max_num_images;
1794
1795 let max_cross_text_attn = {
1796 let cfg = &config.text_config;
1797 max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
1798 };
1799
1800 let max_self_text_attn = {
1801 let cfg = &config.text_config;
1802 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1803 };
1804
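// The device-mapped activation peak is whichever attention score matrix is larger:
// text self-attention over the prompt, or text-to-image cross-attention.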
1805 Ok(max_self_text_attn.max(max_cross_text_attn))
1806 }
1807
1808 fn non_mapped_max_act_size_elems(
1809 &self,
1810 config: &str,
1811 params: &AutoDeviceMapParams,
1812 ) -> Result<usize> {
1813 let AutoDeviceMapParams::Vision {
1814 max_seq_len: _,
1815 max_batch_size,
1816 max_image_shape: _,
1817 max_num_images,
1818 } = params
1819 else {
1820 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1821 };
1822
1823 let config: MLlamaConfig = serde_json::from_str(config)?;
1824
1825 let img_seq_len = {
1826 let cfg = &config.vision_config;
1827 let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1828 let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1829 cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1830 };
1831 let max_vision_attn = {
1832 let cfg = &config.vision_config;
1833 (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
1834 };
1835
1836 Ok(max_vision_attn)
1837 }
1838
1839 fn non_mapped_size_in_bytes(
1840 &self,
1841 config: &str,
1842 dtype: DType,
1843 weight_pack_factor: usize,
1844 _matformer_config: Option<&MatformerSliceConfig>,
1845 ) -> Result<usize> {
1846 let config: MLlamaConfig = serde_json::from_str(config)?;
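// Everything that is not device-mapped per layer: token embeddings, lm_head, the final
// norm, and the entire vision tower (counted in elements, converted to bytes at the end).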
1847 let text_elems = {
1848 let cfg = &config.text_config;
1849 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1850 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1852 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1853 } else {
1854 0
1855 };
1856 let norm = cfg.hidden_size;
1857 embed_tokens + lm_head + norm
1858 };
1859
1860 let vision_elems = {
1861 let cfg = &config.vision_config;
1862
1863 let conv_cfg = Conv2dConfig {
1864 stride: cfg.patch_size,
1865 ..Default::default()
1866 };
1867 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
1868 * cfg.patch_size
1869 * cfg.patch_size;
1870
1871 let class_embedding = cfg.hidden_size;
1872
1873 let gated_positional_embedding = {
1874 let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1875 let embedding = num_patches * cfg.hidden_size;
1876 let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
1877 * (cfg.max_num_tiles * num_patches * cfg.hidden_size);
1878
1879 embedding + tile_embedding
1880 };
1881
1882 let pre_tile_positional_embedding =
1883 (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1884 let post_tile_positional_embedding =
1885 (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1886
1887 let layernorm_pre = cfg.hidden_size;
1888 let layernorm_post = cfg.hidden_size;
1889
1890 let encoder_layer = {
1891 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1892 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
1893
1894 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1895 let q_proj =
1896 cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1897 let k_proj =
1898 cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1899 let v_proj =
1900 cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1901 let o_proj =
1902 cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1903
1904 let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
1905 + cfg.intermediate_size;
1906 let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
1907 + cfg.hidden_size;
1908
1909 input_layernorm
1910 + post_attention_layernorm
1911 + q_proj
1912 + k_proj
1913 + v_proj
1914 + o_proj
1915 + fc1
1916 + fc2
1917 };
1918
1919 patch_embedding
1920 + class_embedding
1921 + gated_positional_embedding
1922 + pre_tile_positional_embedding
1923 + post_tile_positional_embedding
1924 + layernorm_pre
1925 + layernorm_post
1926 + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
1927 };
1928
1929 let elems = text_elems + vision_elems;
1930 Ok(elems * dtype.size_in_bytes())
1931 }
1932
1933 fn layer_sizes_in_bytes(
1934 &self,
1935 config: &str,
1936 dtype: DType,
1937 weight_pack_factor: usize,
1938 _matformer_config: Option<&MatformerSliceConfig>,
1939 ) -> Result<Vec<usize>> {
1940 let config: MLlamaConfig = serde_json::from_str(config)?;
1941 let cfg = &config.text_config;
1942
1943 let mut layer_sizes = Vec::new();
1944
1945 for i in 0..cfg.num_hidden_layers {
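// Cross-attention layers are excluded from ISQ, so their weights are never packed;
// use a pack factor of 1 for them when estimating per-layer sizes.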
1946 let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
1947 1
1949 } else {
1950 weight_pack_factor
1951 };
1952
1953 let per_layer_elems = {
1954 let input_layernorm = cfg.hidden_size;
1955 let post_attention_layernorm = cfg.hidden_size;
1956
1957 let size_in = cfg.hidden_size;
1958 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1959 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1960 let q_proj = size_in * size_q / weight_pack_factor;
1961 let k_proj = size_in * size_kv / weight_pack_factor;
1962 let v_proj = size_in * size_kv / weight_pack_factor;
1963 let o_proj = size_q * size_in / weight_pack_factor;
1964
1965 let h_size = cfg.hidden_size;
1966 let i_size = cfg.intermediate_size;
1967 let gate_proj = h_size * i_size / weight_pack_factor;
1968 let up_proj = h_size * i_size / weight_pack_factor;
1969 let down_proj = i_size * h_size / weight_pack_factor;
1970
1971 input_layernorm
1972 + post_attention_layernorm
1973 + q_proj
1974 + k_proj
1975 + v_proj
1976 + o_proj
1977 + gate_proj
1978 + up_proj
1979 + down_proj
1980 };
1981
1982 layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
1983 }
1984
1985 Ok(layer_sizes)
1986 }
1987
1988 fn num_layers(&self, config: &str) -> Result<usize> {
1989 let config: MLlamaConfig = serde_json::from_str(config)?;
1990 Ok(config.text_config.num_hidden_layers)
1991 }
1992
1993 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1994 let cfg: MLlamaConfig = serde_json::from_str(config)?;
1995 let cfg = &cfg.text_config;
1996
1997 let cfg = ModelConfigMetadata {
1998 max_seq_len: cfg.max_position_embeddings,
1999 num_layers: cfg.num_hidden_layers,
2000 hidden_size: cfg.hidden_size,
2001 num_kv_heads: cfg.num_key_value_heads,
2002 num_attn_heads: cfg.num_attention_heads,
2003 sliding_window: None,
2004 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2005 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2006 };
2007
2008 Ok(Box::new(cfg))
2009 }
2010
2011 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2012 Some(vec![NonMappedSubModel::Vision])
2013 }
2014}
2015
2016pub struct Qwen2VLLoader;
2022
2023pub struct Qwen2VLPrefixer;
2024
2025impl MultimodalPromptPrefixer for Qwen2VLPrefixer {
2026 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2027 format!(
2028 "{}{prompt}",
2029 format!(
2030 "{}{}{}",
2031 Qwen2VLProcessor::VISION_START,
2032 Qwen2VLProcessor::IMAGE_PAD,
2033 Qwen2VLProcessor::VISION_END
2034 )
2035 .repeat(image_indexes.len())
2036 )
2037 }
2038}
2039
2040impl VisionModelLoader for Qwen2VLLoader {
2041 fn load(
2042 &self,
2043 config: &str,
2044 vb: ShardedVarBuilder,
2045 normal_loading_metadata: NormalLoadingMetadata,
2046 attention_mechanism: AttentionImplementation,
2047 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2048 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2049 Ok(Box::new(Qwen2VLModel::new(
2050 &cfg,
2051 vb,
2052 self.is_gptx(config),
2053 normal_loading_metadata,
2054 attention_mechanism,
2055 )?))
2056 }
2057 fn is_gptx(&self, _config: &str) -> bool {
2058 true
2059 }
2060 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2061 let config: Qwen2VLConfig = serde_json::from_str(config)?;
2062 Ok(Box::new(config))
2063 }
2064 fn get_processor(
2065 &self,
2066 _model_config: &str,
2067 _processor_config: Option<ProcessorConfig>,
2068 _preprocessor_config: PreProcessorConfig,
2069 max_edge: Option<u32>,
2070 ) -> Arc<dyn Processor + Send + Sync> {
2071 Arc::new(Qwen2VLProcessor::new(max_edge))
2072 }
2073 fn supports_paged_attention(&self, _config: &str) -> bool {
2074 false
2075 }
2076 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2077 Arc::new(Qwen2VLPrefixer)
2078 }
2079 fn modalities(&self, _config: &str) -> Result<Modalities> {
2080 Ok(Modalities {
2081 input: vec![SupportedModality::Text, SupportedModality::Vision],
2082 output: vec![SupportedModality::Text],
2083 })
2084 }
2085}
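// A minimal usage sketch. The surrounding plumbing (a serialized `config_json: &str`, a
// `ShardedVarBuilder`, `NormalLoadingMetadata`, and the chosen attention implementation)
// is assumed to be supplied by the caller, as elsewhere in this crate:
//     let loader = Qwen2VLLoader;
//     let _modalities = loader.modalities(config_json)?; // Text + Vision in, Text out
//     let processor = loader.get_processor(config_json, None, preprocessor_config, None);
//     let model = loader.load(config_json, vb, normal_loading_metadata, attention_mechanism)?;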
2086
2087impl IsqModelLoader for Qwen2VLLoader {
2088 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2089 Ok(vec![
2090 Regex::new(r"lm_head\.(weight|bias)$")?,
2091 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2093 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2094 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2095 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2096 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2098 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2099 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2100 ])
2101 }
2102 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2103 self.isq_layer_regexes(config)
2104 }
2105}
2106
2107impl DeviceMappedModelLoader for Qwen2VLLoader {
2108 fn mapped_max_act_size_elems(
2109 &self,
2110 config: &str,
2111 params: &AutoDeviceMapParams,
2112 _prompt_chunksize: usize,
2113 ) -> Result<usize> {
2114 let AutoDeviceMapParams::Vision {
2115 max_seq_len,
2116 max_batch_size,
2117 max_image_shape,
2118 max_num_images,
2119 } = params
2120 else {
2121 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2122 };
2123
2124 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2125
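// Spatio-temporal patch grid for the vision tower: images/frames are grouped by
// `temporal_patch_size` along the temporal axis, and height/width are divided by
// `patch_size`.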
2126 let img_seq_len = {
2127 let cfg = &cfg.vision_config;
2128 let grid_t = max_num_images / cfg.temporal_patch_size;
2129 let grid_h = max_image_shape.0 / cfg.patch_size;
2130 let grid_w = max_image_shape.1 / cfg.patch_size;
2131 grid_t * grid_h * grid_w
2132 };
2133 let img_seq_len = img_seq_len * max_num_images;
2134
2135 let max_text_attn = {
2136 let max_seq_len = img_seq_len + max_seq_len;
2138 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
2139 };
2140
2141 Ok(max_text_attn)
2142 }
2143
2144 fn non_mapped_max_act_size_elems(
2145 &self,
2146 config: &str,
2147 params: &AutoDeviceMapParams,
2148 ) -> Result<usize> {
2149 let AutoDeviceMapParams::Vision {
2150 max_seq_len: _,
2151 max_batch_size,
2152 max_image_shape,
2153 max_num_images,
2154 } = params
2155 else {
2156 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2157 };
2158
2159 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2160
2161 let img_seq_len = {
2162 let cfg = &cfg.vision_config;
2163 let grid_t = max_num_images / cfg.temporal_patch_size;
2164 let grid_h = max_image_shape.0 / cfg.patch_size;
2165 let grid_w = max_image_shape.1 / cfg.patch_size;
2166 grid_t * grid_h * grid_w
2167 };
2168
2169 let max_vision_attn = {
2170 let cfg = &cfg.vision_config;
2171 (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
2172 };
2173
2174 Ok(max_vision_attn)
2175 }
2176
2177 fn non_mapped_size_in_bytes(
2178 &self,
2179 config: &str,
2180 dtype: DType,
2181 weight_pack_factor: usize,
2182 _matformer_config: Option<&MatformerSliceConfig>,
2183 ) -> Result<usize> {
2184 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2185 let text_elems = {
2186 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2187 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2189 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2190 } else {
2191 0
2192 };
2193 let norm = cfg.hidden_size;
2194 embed_tokens + lm_head + norm
2195 };
2196
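// Patch merger: `spatial_merge_size^2` neighboring vision tokens (`embed_dim` each) are
// concatenated, passed through a square MLP layer, then projected to `hidden_size`;
// `ln_q` normalizes the `embed_dim` inputs.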
2197 let patch_merger = {
2198 let cfg = &cfg.vision_config;
2199 let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);
2200
2201 let mlp0 = hidden_size * hidden_size + hidden_size;
2202 let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
2203
2204 let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2205
2206 mlp0 + mlp2 + ln_q
2207 };
2208
2209 let patch_embed = {
2210 let cfg = &cfg.vision_config;
2211 let conv_cfg = Conv3dConfig {
2212 stride: cfg.patch_size,
2213 ..Default::default()
2214 };
2215 let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
2216 cfg.in_channels * cfg.embed_dim / conv_cfg.groups
2217 * kernel_sizes[0]
2218 * kernel_sizes[1]
2219 * kernel_sizes[2]
2220 };
2221
2222 let encoder_layer = {
2223 let cfg = &cfg.vision_config;
2224 let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2225 let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2226
2227 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2228 let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
2229 let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
2230 let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;
2231
2232 let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
2233 let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;
2234
2235 norm1 + norm2 + fc1 + fc2 + qkv + out
2236 };
2237
2238 let elems =
2239 text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
2240
2241 Ok(elems * dtype.size_in_bytes())
2242 }
2243
2244 fn layer_sizes_in_bytes(
2245 &self,
2246 config: &str,
2247 dtype: DType,
2248 weight_pack_factor: usize,
2249 _matformer_config: Option<&MatformerSliceConfig>,
2250 ) -> Result<Vec<usize>> {
2251 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2252 let per_layer_elems = {
2253 let input_layernorm = cfg.hidden_size;
2254 let post_attention_layernorm = cfg.hidden_size;
2255
2256 let size_in = cfg.hidden_size;
2257 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2258 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2259 let q_proj = size_in * size_q / weight_pack_factor + size_q;
2260 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
2261 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
2262 let o_proj = size_q * size_in / weight_pack_factor;
2263
2264 let h_size = cfg.hidden_size;
2265 let i_size = cfg.intermediate_size;
2266 let gate_proj = h_size * i_size / weight_pack_factor;
2267 let up_proj = h_size * i_size / weight_pack_factor;
2268 let down_proj = i_size * h_size / weight_pack_factor;
2269
2270 input_layernorm
2271 + post_attention_layernorm
2272 + q_proj
2273 + k_proj
2274 + v_proj
2275 + o_proj
2276 + gate_proj
2277 + up_proj
2278 + down_proj
2279 };
2280 Ok(vec![
2281 per_layer_elems * dtype.size_in_bytes();
2282 cfg.num_hidden_layers
2283 ])
2284 }
2285
2286 fn num_layers(&self, config: &str) -> Result<usize> {
2287 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2288 Ok(cfg.num_hidden_layers)
2289 }
2290
2291 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2292 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2293
2294 let cfg = ModelConfigMetadata {
2295 max_seq_len: cfg.max_position_embeddings,
2296 num_layers: cfg.num_hidden_layers,
2297 hidden_size: cfg.hidden_size,
2298 num_kv_heads: cfg.num_key_value_heads,
2299 num_attn_heads: cfg.num_attention_heads,
2300 sliding_window: cfg.sliding_window,
2301 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2302 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2303 };
2304
2305 Ok(Box::new(cfg))
2306 }
2307
2308 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2309 Some(vec![NonMappedSubModel::Vision])
2310 }
2311}
2312
2313pub struct Idefics3Loader;
2319
2320pub struct Idefics3Prefixer;
2321
2322impl MultimodalPromptPrefixer for Idefics3Prefixer {
2323 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
2324 prompt.to_string()
2326 }
2327}
2328
2329impl VisionModelLoader for Idefics3Loader {
2330 fn load(
2331 &self,
2332 config: &str,
2333 vb: ShardedVarBuilder,
2334 normal_loading_metadata: NormalLoadingMetadata,
2335 attention_mechanism: AttentionImplementation,
2336 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2337 let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2338 Ok(Box::new(Idefics3Model::new(
2339 &cfg,
2340 vb,
2341 self.is_gptx(config),
2342 normal_loading_metadata,
2343 attention_mechanism,
2344 )?))
2345 }
2346 fn is_gptx(&self, _config: &str) -> bool {
2347 true
2348 }
2349 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2350 let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2351 Ok(Box::new(cfg))
2352 }
2353 fn get_processor(
2354 &self,
2355 _model_config: &str,
2356 processor_config: Option<ProcessorConfig>,
2357 preprocessor_config: PreProcessorConfig,
2358 max_edge: Option<u32>,
2359 ) -> Arc<dyn Processor + Send + Sync> {
2360 Arc::new(Idefics3Processor::new(
2361 processor_config.unwrap_or_default(),
2362 preprocessor_config,
2363 max_edge,
2364 ))
2365 }
2366 fn supports_paged_attention(&self, _config: &str) -> bool {
2367 true
2368 }
2369 fn supports_prefix_cacher(&self, _config: &str) -> bool {
2370 true
2371 }
2372 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2373 Arc::new(Idefics3Prefixer)
2374 }
2375 fn modalities(&self, _config: &str) -> Result<Modalities> {
2376 Ok(Modalities {
2377 input: vec![SupportedModality::Text, SupportedModality::Vision],
2378 output: vec![SupportedModality::Text],
2379 })
2380 }
2381}
2382
2383impl IsqModelLoader for Idefics3Loader {
2384 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2385 Ok(vec![
2386 Regex::new(r"lm_head\.(weight|bias)$")?,
2387 Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2389 Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2390 Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2391 Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2392 Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2394 Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2395 Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2396 ])
2397 }
2398 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
2399 Ok(vec![
2400 Regex::new(r"lm_head\.(weight|bias)$")?,
2401 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2403 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2404 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2405 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2406 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2408 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2409 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2410 ])
2427 }
2428}
2429
2430impl DeviceMappedModelLoader for Idefics3Loader {
2431 fn mapped_max_act_size_elems(
2432 &self,
2433 config: &str,
2434 params: &AutoDeviceMapParams,
2435 _prompt_chunksize: usize,
2436 ) -> Result<usize> {
2437 let AutoDeviceMapParams::Vision {
2438 max_seq_len,
2439 max_batch_size,
2440 max_image_shape: _,
2441 max_num_images,
2442 } = params
2443 else {
2444 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2445 };
2446
2447 let cfg: Idefics3Config = serde_json::from_str(config)?;
2448
2449 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2450 let img_seq_len = (num_patches + 1) * max_num_images;
2451
2452 let max_text_attn = {
2453 let max_seq_len = img_seq_len + max_seq_len;
2455 max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2456 };
2457
2458 Ok(max_text_attn)
2459 }
2460
2461 fn non_mapped_max_act_size_elems(
2462 &self,
2463 config: &str,
2464 params: &AutoDeviceMapParams,
2465 ) -> Result<usize> {
2466 let AutoDeviceMapParams::Vision {
2467 max_seq_len: _,
2468 max_batch_size,
2469 max_image_shape: _,
2470 max_num_images,
2471 } = params
2472 else {
2473 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2474 };
2475
2476 let cfg: Idefics3Config = serde_json::from_str(config)?;
2477
2478 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2479 let img_seq_len = num_patches + 1;
2480
2481 let max_vision_attn = {
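// images_factor is headroom for image splitting in the Idefics 3 processor (each
// input image may be expanded into several sub-images plus a downscaled global
// view); 5 is an assumed upper bound here, not a config value.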
2482 let images_factor = 5;
2484
2485 (max_batch_size * images_factor * max_num_images)
2486 * cfg.vision_config.num_attention_heads
2487 * img_seq_len
2488 * img_seq_len
2489 };
2490
2491 Ok(max_vision_attn)
2492 }
2493
2494 fn non_mapped_size_in_bytes(
2495 &self,
2496 config: &str,
2497 dtype: DType,
2498 weight_pack_factor: usize,
2499 _matformer_config: Option<&MatformerSliceConfig>,
2500 ) -> Result<usize> {
2501 let cfg: Idefics3Config = serde_json::from_str(config)?;
2502 let text_elems = {
2503 let cfg = &cfg.text_config;
2504
2505 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2506 let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2507 let norm = cfg.hidden_size;
2508 embed_tokens + lm_head + norm
2509 };
2510
2511 let connector_elems = {
2512 let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
2513 let out_dim = cfg.text_config.hidden_size;
2514
2515 in_dim * out_dim
2516 };
2517
2518 let vision_transformer = {
2519 let cfg = &cfg.vision_config;
2520
2521 let post_layernorm = cfg.hidden_size;
2522
2523 let conv_config = Conv2dConfig {
2524 stride: cfg.patch_size,
2525 ..Default::default()
2526 };
2527 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2528 * cfg.patch_size
2529 * cfg.patch_size;
2530
2531 let num_patches_per_side = cfg.image_size / cfg.patch_size;
2532 let num_patches = num_patches_per_side.pow(2);
2533 let position_embedding = num_patches * cfg.hidden_size;
2534
2535 let layer_elems = {
2536 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2537 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2538
2539 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2540 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2541
2542 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2543 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2544 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2545 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2546
2547 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2548 };
2549
2550 post_layernorm
2551 + patch_embedding
2552 + position_embedding
2553 + layer_elems * cfg.num_hidden_layers
2554 };
2555
2556 let elems = text_elems + connector_elems + vision_transformer;
2557
2558 Ok(elems * dtype.size_in_bytes())
2559 }
2560
2561 fn layer_sizes_in_bytes(
2562 &self,
2563 config: &str,
2564 dtype: DType,
2565 weight_pack_factor: usize,
2566 _matformer_config: Option<&MatformerSliceConfig>,
2567 ) -> Result<Vec<usize>> {
2568 let cfg: Idefics3Config = serde_json::from_str(config)?;
2569 let cfg = cfg.text_config;
2570 let per_layer_elems = {
2571 let input_layernorm = cfg.hidden_size;
2572 let post_attention_layernorm = cfg.hidden_size;
2573
2574 let size_in = cfg.hidden_size;
2575 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2576 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2577 let q_proj = size_in * size_q / weight_pack_factor;
2578 let k_proj = size_in * size_kv / weight_pack_factor;
2579 let v_proj = size_in * size_kv / weight_pack_factor;
2580 let o_proj = size_q * size_in / weight_pack_factor;
2581
2582 let h_size = cfg.hidden_size;
2583 let i_size = cfg.intermediate_size;
2584 let gate_proj = h_size * i_size / weight_pack_factor;
2585 let up_proj = h_size * i_size / weight_pack_factor;
2586 let down_proj = i_size * h_size / weight_pack_factor;
2587
2588 input_layernorm
2589 + post_attention_layernorm
2590 + q_proj
2591 + k_proj
2592 + v_proj
2593 + o_proj
2594 + gate_proj
2595 + up_proj
2596 + down_proj
2597 };
2598 Ok(vec![
2599 per_layer_elems * dtype.size_in_bytes();
2600 cfg.num_hidden_layers
2601 ])
2602 }
2603
2604 fn num_layers(&self, config: &str) -> Result<usize> {
2605 let cfg: Idefics3Config = serde_json::from_str(config)?;
2606 Ok(cfg.text_config.num_hidden_layers)
2607 }
2608 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2609 let cfg: Idefics3Config = serde_json::from_str(config)?;
2610 let cfg = &cfg.text_config;
2611
2612 let cfg = ModelConfigMetadata {
2613 max_seq_len: cfg.max_position_embeddings,
2614 num_layers: cfg.num_hidden_layers,
2615 hidden_size: cfg.hidden_size,
2616 num_kv_heads: cfg.num_key_value_heads,
2617 num_attn_heads: cfg.num_attention_heads,
2618 sliding_window: None,
2619 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2620 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2621 };
2622
2623 Ok(Box::new(cfg))
2624 }
2625
2626 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2627 Some(vec![NonMappedSubModel::Vision])
2628 }
2629}
2630
2631pub struct MiniCpmOLoader;
2637
2638pub struct MiniCpmOPrefixer;
2639
2640impl MultimodalPromptPrefixer for MiniCpmOPrefixer {
2641 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2642 format!(
2643 "{}{prompt}",
2644 "(<image>./</image>)".repeat(image_indexes.len())
2645 )
2646 }
2647}
2648
2649impl VisionModelLoader for MiniCpmOLoader {
2650 fn load(
2651 &self,
2652 config: &str,
2653 vb: ShardedVarBuilder,
2654 normal_loading_metadata: NormalLoadingMetadata,
2655 attention_mechanism: AttentionImplementation,
2656 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2657 let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2658 Ok(Box::new(MiniCpmOModel::new(
2659 &cfg,
2660 vb,
2661 self.is_gptx(config),
2662 normal_loading_metadata,
2663 attention_mechanism,
2664 )?))
2665 }
2666 fn is_gptx(&self, _config: &str) -> bool {
2667 true
2668 }
2669 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2670 let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2671 Ok(Box::new(cfg))
2672 }
2673 fn get_processor(
2674 &self,
2675 _model_config: &str,
2676 processor_config: Option<ProcessorConfig>,
2677 preprocessor_config: PreProcessorConfig,
2678 max_edge: Option<u32>,
2679 ) -> Arc<dyn Processor + Send + Sync> {
2680 Arc::new(MiniCpmOProcessor::new(
2681 processor_config.unwrap_or_default(),
2682 preprocessor_config,
2683 max_edge,
2684 ))
2685 }
2686 fn supports_paged_attention(&self, _config: &str) -> bool {
2687 true
2688 }
2689 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2690 Arc::new(MiniCpmOPrefixer)
2691 }
2692 fn modalities(&self, _config: &str) -> Result<Modalities> {
2693 Ok(Modalities {
2694 input: vec![SupportedModality::Text, SupportedModality::Vision],
2695 output: vec![SupportedModality::Text],
2696 })
2697 }
2698}
2699
2700impl IsqModelLoader for MiniCpmOLoader {
2701 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2702 Ok(vec![
2703 Regex::new(r"llm.lm_head\.(weight|bias)$")?,
2704 Regex::new(r"llm.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2706 Regex::new(r"llm.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2707 Regex::new(r"llm.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2708 Regex::new(r"llm.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2709 Regex::new(r"llm.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2711 Regex::new(r"llm.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2712 Regex::new(r"llm.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2713 ])
2714 }
2715 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2716 self.isq_layer_regexes(config)
2717 }
2718}
2719
2720impl DeviceMappedModelLoader for MiniCpmOLoader {
2721 fn mapped_max_act_size_elems(
2722 &self,
2723 config: &str,
2724 params: &AutoDeviceMapParams,
2725 _prompt_chunksize: usize,
2726 ) -> Result<usize> {
2727 let AutoDeviceMapParams::Vision {
2728 max_seq_len,
2729 max_batch_size,
2730 max_image_shape: _,
2731 max_num_images,
2732 } = params
2733 else {
2734 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2735 };
2736
2737 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2738
2739 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2740 let img_seq_len = (num_patches + 1) * max_num_images;
2741
2742 let max_text_attn = {
2743 let max_seq_len = img_seq_len + max_seq_len;
2745 max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2746 };
2747
2748 Ok(max_text_attn)
2749 }
2750
2751 fn non_mapped_max_act_size_elems(
2752 &self,
2753 config: &str,
2754 params: &AutoDeviceMapParams,
2755 ) -> Result<usize> {
2756 let AutoDeviceMapParams::Vision {
2757 max_seq_len: _,
2758 max_batch_size,
2759 max_image_shape: _,
2760 max_num_images,
2761 } = params
2762 else {
2763 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2764 };
2765
2766 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2767
2768 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2769 let img_seq_len = num_patches + 1;
2770
2771 let max_vision_attn = {
2772 let images_factor = 5;
2774
2775 (max_batch_size * images_factor * max_num_images)
2776 * cfg.vision_config.num_attention_heads
2777 * img_seq_len
2778 * img_seq_len
2779 };
2780
2781 Ok(max_vision_attn)
2782 }
2783
2784 fn non_mapped_size_in_bytes(
2785 &self,
2786 config: &str,
2787 dtype: DType,
2788 weight_pack_factor: usize,
2789 _matformer_config: Option<&MatformerSliceConfig>,
2790 ) -> Result<usize> {
2791 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2792 let text_elems = {
2793 let cfg = &cfg.text_config;
2794
2795 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2796 let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2797 let norm = cfg.hidden_size;
2798 embed_tokens + lm_head + norm
2799 };
2800
2801 let vision_transformer = {
2802 let cfg = &cfg.vision_config;
2803
2804 let post_layernorm = cfg.hidden_size;
2805
2806 let conv_config = Conv2dConfig {
2807 stride: cfg.patch_size,
2808 ..Default::default()
2809 };
2810 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2811 * cfg.patch_size
2812 * cfg.patch_size;
2813
2814 let num_patches_per_side = cfg.image_size / cfg.patch_size;
2815 let num_patches = num_patches_per_side.pow(2);
2816 let position_embedding = num_patches * cfg.hidden_size;
2817
2818 let layer_elems = {
2819 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2820 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2821
2822 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2823 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2824
2825 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2826 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2827 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2828 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2829
2830 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2831 };
2832
2833 post_layernorm
2834 + patch_embedding
2835 + position_embedding
2836 + layer_elems * cfg.num_hidden_layers
2837 };
2838
2839 let elems = text_elems + vision_transformer;
2840
2841 Ok(elems * dtype.size_in_bytes())
2842 }
2843
2844 fn layer_sizes_in_bytes(
2845 &self,
2846 config: &str,
2847 dtype: DType,
2848 weight_pack_factor: usize,
2849 _matformer_config: Option<&MatformerSliceConfig>,
2850 ) -> Result<Vec<usize>> {
2851 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2852 let cfg = cfg.text_config;
2853 let per_layer_elems = {
2854 let input_layernorm = cfg.hidden_size;
2855 let post_attention_layernorm = cfg.hidden_size;
2856
2857 let size_in = cfg.hidden_size;
2858 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2859 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2860 let q_proj = size_in * size_q / weight_pack_factor;
2861 let k_proj = size_in * size_kv / weight_pack_factor;
2862 let v_proj = size_in * size_kv / weight_pack_factor;
2863 let o_proj = size_q * size_in / weight_pack_factor;
2864
2865 let h_size = cfg.hidden_size;
2866 let i_size = cfg.intermediate_size;
2867 let gate_proj = h_size * i_size / weight_pack_factor;
2868 let up_proj = h_size * i_size / weight_pack_factor;
2869 let down_proj = i_size * h_size / weight_pack_factor;
2870
2871 input_layernorm
2872 + post_attention_layernorm
2873 + q_proj
2874 + k_proj
2875 + v_proj
2876 + o_proj
2877 + gate_proj
2878 + up_proj
2879 + down_proj
2880 };
2881 Ok(vec![
2882 per_layer_elems * dtype.size_in_bytes();
2883 cfg.num_hidden_layers
2884 ])
2885 }
2886
2887 fn num_layers(&self, config: &str) -> Result<usize> {
2888 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2889 Ok(cfg.text_config.num_hidden_layers)
2890 }
2891 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2892 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2893 let cfg = &cfg.text_config;
2894
2895 let cfg = ModelConfigMetadata {
2896 max_seq_len: cfg.max_position_embeddings,
2897 num_layers: cfg.num_hidden_layers,
2898 hidden_size: cfg.hidden_size,
2899 num_kv_heads: cfg.num_key_value_heads,
2900 num_attn_heads: cfg.num_attention_heads,
2901 sliding_window: None,
2902 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2903 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2904 };
2905
2906 Ok(Box::new(cfg))
2907 }
2908}
2909
2910pub struct Phi4MMLoader;
2916
2917pub struct Phi4MMPrefixer;
2918
2919impl MultimodalPromptPrefixer for Phi4MMPrefixer {
2920 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2921 format!(
2924 "{}{prompt}",
2925 image_indexes
2926 .into_iter()
2927 .map(|image_index| format!("<|image_{}|>", image_index + 1))
2928 .join("")
2929 )
2930 }
2931 fn prefix_audio(&self, audio_indexes: Vec<usize>, prompt: &str) -> String {
2932 format!(
2935 "{}{prompt}",
2936 audio_indexes
2937 .into_iter()
2938 .map(|audio_index| format!("<|audio_{}|>", audio_index + 1))
2939 .join("")
2940 )
2941 }
2942}
2943
2944impl VisionModelLoader for Phi4MMLoader {
2945 fn load(
2946 &self,
2947 config: &str,
2948 vb: ShardedVarBuilder,
2949 normal_loading_metadata: NormalLoadingMetadata,
2950 attention_mechanism: AttentionImplementation,
2951 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2952 let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
2953 Ok(Box::new(Phi4MMModel::new(
2954 &cfg,
2955 vb,
2956 self.is_gptx(config),
2957 normal_loading_metadata,
2958 attention_mechanism,
2959 )?))
2960 }
2961 fn is_gptx(&self, _config: &str) -> bool {
2962 true
2963 }
2964 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2965 let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
2966 Ok(Box::new(cfg))
2967 }
2968 fn get_processor(
2969 &self,
2970 _model_config: &str,
2971 processor_config: Option<ProcessorConfig>,
2972 preprocessor_config: PreProcessorConfig,
2973 _max_edge: Option<u32>,
2974 ) -> Arc<dyn Processor + Send + Sync> {
2975 Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
2976 }
2977 fn supports_paged_attention(&self, _config: &str) -> bool {
2978 true
2979 }
2980 fn supports_prefix_cacher(&self, _config: &str) -> bool {
2981 true
2982 }
2983 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2984 Arc::new(Phi4MMPrefixer)
2985 }
2986 fn modalities(&self, _config: &str) -> Result<Modalities> {
2987 Ok(Modalities {
2988 input: vec![
2989 SupportedModality::Text,
2990 SupportedModality::Vision,
2991 SupportedModality::Audio,
2992 ],
2993 output: vec![SupportedModality::Text],
2994 })
2995 }
2996}
2997
2998impl IsqModelLoader for Phi4MMLoader {
2999 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3000 Ok(vec![
3001 Regex::new(r"lm_head\.(weight|bias)$")?,
3002 Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
3004 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3005 Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
3007 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3008 ])
3009 }
3010 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3011 self.isq_layer_regexes(config)
3012 }
3013}
3014
3015impl DeviceMappedModelLoader for Phi4MMLoader {
3016 fn mapped_max_act_size_elems(
3017 &self,
3018 config: &str,
3019 params: &AutoDeviceMapParams,
3020 _prompt_chunksize: usize,
3021 ) -> Result<usize> {
3022 let AutoDeviceMapParams::Vision {
3024 max_seq_len,
3025 max_batch_size,
3026 max_image_shape: _,
3027 max_num_images,
3028 } = params
3029 else {
3030 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3031 };
3032
3033 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3034
3035 let vcfg = &PHI4_MM_VISION_CFG;
3036
3037 let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3038 let img_seq_len = (num_patches + 1) * max_num_images;
3039
3040 let max_text_attn = {
3041 let max_seq_len = img_seq_len + max_seq_len;
3043 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3044 };
3045
3046 Ok(max_text_attn)
3047 }
3048
3049 fn non_mapped_max_act_size_elems(
3050 &self,
3051 _config: &str,
3052 params: &AutoDeviceMapParams,
3053 ) -> Result<usize> {
3054 let AutoDeviceMapParams::Vision {
3055 max_seq_len: _,
3056 max_batch_size,
3057 max_image_shape,
3058 max_num_images,
3059 } = params
3060 else {
3061 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3062 };
3063
3064 let vcfg = &PHI4_MM_VISION_CFG;
3065
3066 let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3067 let img_seq_len = num_patches + 1;
3068
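// Phi-4 multimodal dynamic-HD preprocessing tiles each image into base-resolution crops
// plus one global crop, so scale the effective batch size by that per-image tile count.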
3069 let max_batch_size = max_batch_size
3070 * (max_image_shape
3071 .0
3072 .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3073 * max_image_shape
3074 .1
3075 .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3076 + 1);
3077
3078 let max_vision_attn = (max_batch_size * max_num_images)
3079 * vcfg.num_attention_heads
3080 * img_seq_len
3081 * img_seq_len;
3082 let max_qkv = 3
3083 * (max_batch_size
3084 * vcfg.num_attention_heads
3085 * img_seq_len
3086 * (vcfg.hidden_size / vcfg.num_attention_heads));
3087
3088 Ok(max_vision_attn + max_qkv)
3089 }
3090
3091 fn non_mapped_size_in_bytes(
3092 &self,
3093 config: &str,
3094 dtype: DType,
3095 weight_pack_factor: usize,
3096 _matformer_config: Option<&MatformerSliceConfig>,
3097 ) -> Result<usize> {
3098 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3099 let elems = {
3100 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3101 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3103 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3104 } else {
3105 0
3106 };
3107 let norm = cfg.hidden_size;
3108
3109 let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
3110 let projection_cls = img_embed
3111 .projection_cls
3112 .clone()
3113 .unwrap_or("linear".to_string());
3114 let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
3115 let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
3116 let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;
3117
3118 let proj = match (projection_cls.as_str(), use_hd_transform) {
3119 ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
3120 ("mlp", true) => {
3121 let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
3122 let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3123 a + b
3124 }
3125 ("mlp", false) => {
3126 let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
3127 let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3128 a + b
3129 }
3130 _ => {
3131 anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
3132 }
3133 };
3134
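// Learnable separator embeddings used by the HD transform (glb_GN / sub_GN); each holds
// 4 * image_dim_out values, matching the 2x2 patch-concatenation channel expansion.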
3135 let (glb_gn, sub_gn) = if with_learnable_separator {
3136 let glb_gn = image_dim_out * 4;
3137 let sub_gn = image_dim_out * 4;
3138 (glb_gn, sub_gn)
3139 } else {
3140 (0, 0)
3141 };
3142
3143 let vision_transformer = {
3144 let cfg = &PHI4_MM_VISION_CFG;
3145
3146 let post_layernorm = cfg.hidden_size;
3147
3148 let conv_config = Conv2dConfig {
3149 stride: cfg.patch_size,
3150 ..Default::default()
3151 };
3152 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3153 * cfg.patch_size
3154 * cfg.patch_size;
3155
3156 let num_patches_per_side = cfg.image_size / cfg.patch_size;
3157 let num_patches = num_patches_per_side.pow(2);
3158 let position_embedding = num_patches * cfg.hidden_size;
3159
3160 let layer_elems = {
3161 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3162 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3163
3164 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3165 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3166
3167 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3168 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3169 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3170 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3171
3172 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3173 };
3174
3175 post_layernorm
3176 + patch_embedding
3177 + position_embedding
3178 + layer_elems * cfg.num_hidden_layers
3179 };
3180
3181 proj + glb_gn + sub_gn + vision_transformer
3182 } else {
3183 0
3184 };
3185
3186 embed_tokens + lm_head + norm + image_embed
3187 };
3188
3189 Ok(elems * dtype.size_in_bytes())
3190 }
3191
3192 fn layer_sizes_in_bytes(
3193 &self,
3194 config: &str,
3195 dtype: DType,
3196 weight_pack_factor: usize,
3197 _matformer_config: Option<&MatformerSliceConfig>,
3198 ) -> Result<Vec<usize>> {
3199 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3200 let per_layer_elems = {
3201 let input_layernorm = cfg.hidden_size;
3202 let post_attention_layernorm = cfg.hidden_size;
3203
3204 let size_in = cfg.hidden_size;
3205 let head_dim = cfg.head_dim();
3206 let op_size =
3207 cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
3208 let qkv_proj = size_in * op_size / weight_pack_factor;
3209 let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;
3210
3211 let h_size = cfg.hidden_size;
3212 let i_size = cfg.intermediate_size;
3213 let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
3214 let down_proj = h_size * i_size / weight_pack_factor;
3215
3216 input_layernorm
3217 + post_attention_layernorm
3218 + qkv_proj
3219 + o_proj
3220 + gate_up_proj
3221 + down_proj
3222 };
3223 Ok(vec![
3224 per_layer_elems * dtype.size_in_bytes();
3225 cfg.num_hidden_layers
3226 ])
3227 }
3228
3229 fn num_layers(&self, config: &str) -> Result<usize> {
3230 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3231 Ok(cfg.num_hidden_layers)
3232 }
3233
3234 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3235 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3236
3237 let cfg = ModelConfigMetadata {
3238 max_seq_len: cfg.max_position_embeddings,
3239 num_layers: cfg.num_hidden_layers,
3240 hidden_size: cfg.hidden_size,
3241 num_kv_heads: cfg.num_key_value_heads(),
3242 num_attn_heads: cfg.num_attention_heads,
3243 sliding_window: cfg.sliding_window,
3244 k_head_dim: cfg.head_dim(),
3245 v_head_dim: cfg.head_dim(),
3246 };
3247
3248 Ok(Box::new(cfg))
3249 }
3250
3251 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3252 Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
3253 }
3254}
3255
3256pub struct Qwen2_5VLLoader;
3262
3263pub struct Qwen2_5VLPrefixer;
3264
3265impl MultimodalPromptPrefixer for Qwen2_5VLPrefixer {
3266 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
3267 format!(
3268 "{}{prompt}",
3269 format!(
3270 "{}{}{}",
3271 Qwen2_5VLProcessor::VISION_START,
3272 Qwen2_5VLProcessor::IMAGE_PAD,
3273 Qwen2_5VLProcessor::VISION_END
3274 )
3275 .repeat(image_indexes.len())
3276 )
3277 }
3278}
3279
3280impl VisionModelLoader for Qwen2_5VLLoader {
3281 fn load(
3282 &self,
3283 config: &str,
3284 vb: ShardedVarBuilder,
3285 normal_loading_metadata: NormalLoadingMetadata,
3286 attention_mechanism: AttentionImplementation,
3287 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3288 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3289 Ok(Box::new(Qwen2_5VLModel::new(
3290 &cfg,
3291 vb,
3292 self.is_gptx(config),
3293 normal_loading_metadata,
3294 attention_mechanism,
3295 )?))
3296 }
3297 fn is_gptx(&self, _config: &str) -> bool {
3298 true
3299 }
3300 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3301 let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
3302 Ok(Box::new(config))
3303 }
3304 fn get_processor(
3305 &self,
3306 _model_config: &str,
3307 _processor_config: Option<ProcessorConfig>,
3308 _preprocessor_config: PreProcessorConfig,
3309 max_edge: Option<u32>,
3310 ) -> Arc<dyn Processor + Send + Sync> {
3311 Arc::new(Qwen2_5VLProcessor::new(max_edge))
3312 }
3313 fn supports_paged_attention(&self, _config: &str) -> bool {
3314 false
3315 }
3316 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3317 Arc::new(Qwen2_5VLPrefixer)
3318 }
3319 fn modalities(&self, _config: &str) -> Result<Modalities> {
3320 Ok(Modalities {
3321 input: vec![SupportedModality::Text, SupportedModality::Vision],
3322 output: vec![SupportedModality::Text],
3323 })
3324 }
3325}
3326
3327impl IsqModelLoader for Qwen2_5VLLoader {
3328 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3329 Ok(vec![
3330 Regex::new(r"lm_head\.(weight|bias)$")?,
3331 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3333 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3334 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3335 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3336 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3338 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3339 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3340 ])
3341 }
3342 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3343 self.isq_layer_regexes(config)
3344 }
3345}
3346
3347impl DeviceMappedModelLoader for Qwen2_5VLLoader {
3348 fn mapped_max_act_size_elems(
3349 &self,
3350 config: &str,
3351 params: &AutoDeviceMapParams,
3352 _prompt_chunksize: usize,
3353 ) -> Result<usize> {
3354 let AutoDeviceMapParams::Vision {
3355 max_seq_len,
3356 max_batch_size,
3357 max_image_shape,
3358 max_num_images,
3359 } = params
3360 else {
3361 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3362 };
3363
3364 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3365
3366 let img_seq_len = {
3367 let cfg = &cfg.vision_config;
3368 let grid_t = max_num_images / cfg.temporal_patch_size;
3369 let grid_h = max_image_shape.0 / cfg.patch_size;
3370 let grid_w = max_image_shape.1 / cfg.patch_size;
3371 grid_t * grid_h * grid_w
3372 };
3373 let img_seq_len = img_seq_len * max_num_images;
3374
3375 let max_text_attn = {
3376 let max_seq_len = img_seq_len + max_seq_len;
3378 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3379 };
3380
3381 Ok(max_text_attn)
3382 }
3383
3384 fn non_mapped_max_act_size_elems(
3385 &self,
3386 config: &str,
3387 params: &AutoDeviceMapParams,
3388 ) -> Result<usize> {
3389 let AutoDeviceMapParams::Vision {
3390 max_seq_len: _,
3391 max_batch_size,
3392 max_image_shape,
3393 max_num_images,
3394 } = params
3395 else {
3396 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3397 };
3398
3399 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3400
3401 let img_seq_len = {
3402 let cfg = &cfg.vision_config;
3403 let grid_t = max_num_images / cfg.temporal_patch_size;
3404 let grid_h = max_image_shape.0 / cfg.patch_size;
3405 let grid_w = max_image_shape.1 / cfg.patch_size;
3406 grid_t * grid_h * grid_w
3407 };
3408
3409 let max_vision_attn = {
3410 let cfg = &cfg.vision_config;
3411 (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
3412 };
3413
3414 Ok(max_vision_attn)
3415 }
3416
3417 fn non_mapped_size_in_bytes(
3418 &self,
3419 config: &str,
3420 dtype: DType,
3421 weight_pack_factor: usize,
3422 _matformer_config: Option<&MatformerSliceConfig>,
3423 ) -> Result<usize> {
3424 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3425 let text_elems = {
3426 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3427 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3429 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3430 } else {
3431 0
3432 };
3433 let norm = cfg.hidden_size;
3434 embed_tokens + lm_head + norm
3435 };
3436
3437 let patch_merger = {
3438 let cfg = &cfg.vision_config;
3439 let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
3440
3441 let mlp0 = hidden_size * hidden_size + hidden_size;
3442 let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
3443
3444 let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3445
3446 mlp0 + mlp2 + ln_q
3447 };
3448
3449 let patch_embed = {
3450 let cfg = &cfg.vision_config;
3451 let conv_cfg = Conv3dConfig {
3452 stride: cfg.patch_size,
3453 ..Default::default()
3454 };
3455 let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
3456 cfg.in_chans * cfg.hidden_size / conv_cfg.groups
3457 * kernel_sizes[0]
3458 * kernel_sizes[1]
3459 * kernel_sizes[2]
3460 };
3461
3462 let encoder_layer = {
3463 let cfg = &cfg.vision_config;
3464 let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3465 let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3466
3468 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3469 let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
3470
3471 let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
3472 let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3473
3474 norm1 + norm2 + fc1 + fc2 + qkv + out
3475 };
3476
3477 let elems =
3478 text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
3479
3480 Ok(elems * dtype.size_in_bytes())
3481 }
3482
3483 fn layer_sizes_in_bytes(
3484 &self,
3485 config: &str,
3486 dtype: DType,
3487 weight_pack_factor: usize,
3488 _matformer_config: Option<&MatformerSliceConfig>,
3489 ) -> Result<Vec<usize>> {
3490 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3491 let per_layer_elems = {
3492 let input_layernorm = cfg.hidden_size;
3493 let post_attention_layernorm = cfg.hidden_size;
3494
3495 let size_in = cfg.hidden_size;
3496 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3497 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3498 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3499 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3500 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3501 let o_proj = size_q * size_in / weight_pack_factor;
3502
3503 let h_size = cfg.hidden_size;
3504 let i_size = cfg.intermediate_size;
3505 let gate_proj = h_size * i_size / weight_pack_factor;
3506 let up_proj = h_size * i_size / weight_pack_factor;
3507 let down_proj = i_size * h_size / weight_pack_factor;
3508
3509 input_layernorm
3510 + post_attention_layernorm
3511 + q_proj
3512 + k_proj
3513 + v_proj
3514 + o_proj
3515 + gate_proj
3516 + up_proj
3517 + down_proj
3518 };
3519 Ok(vec![
3520 per_layer_elems * dtype.size_in_bytes();
3521 cfg.num_hidden_layers
3522 ])
3523 }
3524
3525 fn num_layers(&self, config: &str) -> Result<usize> {
3526 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3527 Ok(cfg.num_hidden_layers)
3528 }
3529
3530 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3531 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3532
3533 let cfg = ModelConfigMetadata {
3534 max_seq_len: cfg.max_position_embeddings,
3535 num_layers: cfg.num_hidden_layers,
3536 hidden_size: cfg.hidden_size,
3537 num_kv_heads: cfg.num_key_value_heads,
3538 num_attn_heads: cfg.num_attention_heads,
3539 sliding_window: cfg.sliding_window,
3540 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3541 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3542 };
3543
3544 Ok(Box::new(cfg))
3545 }
3546
3547 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3548 Some(vec![NonMappedSubModel::Vision])
3549 }
3550}
3551
3552pub struct Gemma3Loader;
3558
3559pub struct Gemma3Prefixer;
3560
3561impl MultimodalPromptPrefixer for Gemma3Prefixer {
3562 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
3563 prompt.to_string()
3564 }
3565}
3566
3567impl VisionModelLoader for Gemma3Loader {
3568 fn load(
3569 &self,
3570 config: &str,
3571 vb: ShardedVarBuilder,
3572 normal_loading_metadata: NormalLoadingMetadata,
3573 attention_mechanism: AttentionImplementation,
3574 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3575 let cfg: Gemma3Config = serde_json::from_str(config)?;
3576 Ok(Box::new(Gemma3Model::new(
3577 &cfg,
3578 vb,
3579 self.is_gptx(config),
3580 normal_loading_metadata,
3581 attention_mechanism,
3582 )?))
3583 }
3584 fn is_gptx(&self, _config: &str) -> bool {
3585 true
3586 }
3587 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3588 let config: Gemma3Config = serde_json::from_str(config)?;
3589 Ok(Box::new(config))
3590 }
3591 fn get_processor(
3592 &self,
3593 config: &str,
3594 processor_config: Option<ProcessorConfig>,
3595 _preprocessor_config: PreProcessorConfig,
3596 _max_edge: Option<u32>,
3597 ) -> Arc<dyn Processor + Send + Sync> {
3598 let config: Gemma3Config = serde_json::from_str(config).unwrap();
3599 Arc::new(Gemma3Processor::new(
3601 processor_config.unwrap_or_default(),
3602 matches!(config, Gemma3Config::WithVision { .. }),
3603 ))
3604 }
3605 fn supports_paged_attention(&self, _config: &str) -> bool {
3606 true
3607 }
3608 fn supports_prefix_cacher(&self, _config: &str) -> bool {
3609 true
3610 }
3611 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3612 Arc::new(Gemma3Prefixer)
3613 }
3614 fn modalities(&self, _config: &str) -> Result<Modalities> {
3615 Ok(Modalities {
3616 input: vec![SupportedModality::Text, SupportedModality::Vision],
3617 output: vec![SupportedModality::Text],
3618 })
3619 }
3620}
3621
3622impl IsqModelLoader for Gemma3Loader {
3623 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3624 Ok(vec![
3625 Regex::new(r"lm_head\.(weight|bias)$")?,
3626 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3628 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3629 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3630 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3631 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3633 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3634 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3635 ])
3636 }
3637 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
3638 Ok(vec![
3639 Regex::new(r"lm_head\.(weight|bias)$")?,
3640 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3642 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3643 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3644 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3645 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3647 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3648 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3649 ])
3650 }
3651}
3652
3653impl DeviceMappedModelLoader for Gemma3Loader {
3654 fn mapped_max_act_size_elems(
3655 &self,
3656 config: &str,
3657 params: &AutoDeviceMapParams,
3658 prompt_chunksize: usize,
3659 ) -> Result<usize> {
3660 let AutoDeviceMapParams::Vision {
3661 max_seq_len,
3662 max_batch_size,
3663 max_image_shape: _,
3664 max_num_images,
3665 } = params
3666 else {
3667 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3668 };
3669
3670 let cfg: Gemma3Config = serde_json::from_str(config)?;
3671
3672 match cfg {
3673 Gemma3Config::Text(text_config) => Ok(max_batch_size
3674 * text_config.num_attention_heads
3675 * prompt_chunksize
3676 * prompt_chunksize),
3677 Gemma3Config::WithVision {
3678 text_config,
3679 vision_config,
3680 ..
3681 } => {
3682 let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3683 let img_seq_len = (num_patches + 1) * max_num_images;
3684
3685 let max_text_attn = {
3686 let max_seq_len = img_seq_len + *max_seq_len;
3688 max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
3689 };
3690 Ok(max_text_attn)
3691 }
3692 }
3693 }
3694
3695 fn non_mapped_max_act_size_elems(
3696 &self,
3697 config: &str,
3698 params: &AutoDeviceMapParams,
3699 ) -> Result<usize> {
3700 let AutoDeviceMapParams::Vision {
3701 max_seq_len: _,
3702 max_batch_size,
3703 max_image_shape: _,
3704 max_num_images,
3705 } = params
3706 else {
3707 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3708 };
3709
3710 let cfg: Gemma3Config = serde_json::from_str(config)?;
3711
3712 match cfg {
3713 Gemma3Config::WithVision { vision_config, .. } => {
3714 let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3715 let img_seq_len = num_patches + 1;
3716
3717 let max_vision_attn = {
3718 (max_batch_size * max_num_images)
3719 * vision_config.num_attention_heads
3720 * img_seq_len
3721 * img_seq_len
3722 };
3723
3724 Ok(max_vision_attn)
3725 }
3726 Gemma3Config::Text(_) => Ok(0),
3727 }
3728 }
3729
3730 fn non_mapped_size_in_bytes(
3731 &self,
3732 config: &str,
3733 dtype: DType,
3734 weight_pack_factor: usize,
3735 _matformer_config: Option<&MatformerSliceConfig>,
3736 ) -> Result<usize> {
3737 let cfg: Gemma3Config = serde_json::from_str(config)?;
3738
3739 let text_elems = {
3740 let cfg = match &cfg {
3741 Gemma3Config::Text(cfg) => cfg,
3742 Gemma3Config::WithVision { text_config, .. } => text_config,
3743 };
3744 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3745 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3747 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3748 } else {
3749 0
3750 };
3751 let norm = cfg.hidden_size;
3752 embed_tokens + lm_head + norm
3753 };
3754
3755 let vision_transformer = if let Gemma3Config::WithVision {
3756 vision_config: cfg, ..
3757 } = &cfg
3758 {
3759 let post_layernorm = cfg.hidden_size;
3760
3761 let conv_config = Conv2dConfig {
3762 stride: cfg.patch_size,
3763 ..Default::default()
3764 };
3765 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3766 * cfg.patch_size
3767 * cfg.patch_size;
3768
3769 let num_patches_per_side = cfg.image_size / cfg.patch_size;
3770 let num_patches = num_patches_per_side.pow(2);
3771 let position_embedding = num_patches * cfg.hidden_size;
3772
3773 let layer_elems = {
3774 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3775 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3776
3777 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3778 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3779
3780 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3781 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3782 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3783 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3784
3785 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3786 };
3787
3788 post_layernorm
3789 + patch_embedding
3790 + position_embedding
3791 + layer_elems * cfg.num_hidden_layers
3792 } else {
3793 0
3794 };
3795
3796 let elems = text_elems + vision_transformer;
3797
3798 Ok(elems * dtype.size_in_bytes())
3799 }
3800
3801 fn layer_sizes_in_bytes(
3802 &self,
3803 config: &str,
3804 dtype: DType,
3805 weight_pack_factor: usize,
3806 _matformer_config: Option<&MatformerSliceConfig>,
3807 ) -> Result<Vec<usize>> {
3808 let cfg: Gemma3Config = serde_json::from_str(config)?;
3809
3810 let txt_cfg = match &cfg {
3811 Gemma3Config::Text(cfg) => cfg,
3812 Gemma3Config::WithVision { text_config, .. } => text_config,
3813 };
3814 let per_layer_elems = {
3815 let cfg = txt_cfg;
3816
3817 let input_layernorm = cfg.hidden_size;
3818 let post_attention_layernorm = cfg.hidden_size;
3819
3820 let size_in = cfg.hidden_size;
3821 let size_q = cfg.head_dim * cfg.num_attention_heads;
3822 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
3823 let q_proj =
3824 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
3825 let k_proj =
3826 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3827 let v_proj =
3828 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3829 let o_proj =
3830 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
3831
3832 let h_size = cfg.hidden_size;
3833 let i_size = cfg.intermediate_size;
3834 let gate_proj = h_size * i_size / weight_pack_factor;
3835 let up_proj = h_size * i_size / weight_pack_factor;
3836 let down_proj = i_size * h_size / weight_pack_factor;
3837
3838 input_layernorm
3839 + post_attention_layernorm
3840 + q_proj
3841 + k_proj
3842 + v_proj
3843 + o_proj
3844 + gate_proj
3845 + up_proj
3846 + down_proj
3847 };
3848 Ok(vec![
3849 per_layer_elems * dtype.size_in_bytes();
3850 txt_cfg.num_hidden_layers
3851 ])
3852 }
3853
3854 fn num_layers(&self, config: &str) -> Result<usize> {
3855 let cfg: Gemma3Config = serde_json::from_str(config)?;
3856
3857 let txt_cfg = match &cfg {
3858 Gemma3Config::Text(cfg) => cfg,
3859 Gemma3Config::WithVision { text_config, .. } => text_config,
3860 };
3861
3862 Ok(txt_cfg.num_hidden_layers)
3863 }
3864
3865 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3866 let cfg: Gemma3Config = serde_json::from_str(config)?;
3867
3868 let cfg = match &cfg {
3869 Gemma3Config::Text(cfg) => cfg,
3870 Gemma3Config::WithVision { text_config, .. } => text_config,
3871 };
3872
3873 let cfg = ModelConfigMetadata {
3874 max_seq_len: cfg.max_position_embeddings,
3875 num_layers: cfg.num_hidden_layers,
3876 hidden_size: cfg.hidden_size,
3877 num_kv_heads: cfg.num_key_value_heads,
3878 num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3881 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3882 };
3883
3884 Ok(Box::new(cfg))
3885 }
3886
3887 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3888 Some(vec![NonMappedSubModel::Vision])
3889 }
3890}
3891
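/// Vision model loader for Mistral 3 multimodal models.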
pub struct Mistral3Loader;
3898
3899pub struct Mistral3Prefixer;
3900
3901impl MultimodalPromptPrefixer for Mistral3Prefixer {
3902 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
3903 prompt.to_string()
3904 }
3905}
3906
3907impl VisionModelLoader for Mistral3Loader {
3908 fn load(
3909 &self,
3910 config: &str,
3911 vb: ShardedVarBuilder,
3912 normal_loading_metadata: NormalLoadingMetadata,
3913 attention_mechanism: AttentionImplementation,
3914 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3915 let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
3916 Ok(Box::new(Mistral3Model::new(
3917 &cfg,
3918 vb,
3919 self.is_gptx(config),
3920 normal_loading_metadata,
3921 attention_mechanism,
3922 )?))
3923 }
3924 fn is_gptx(&self, _config: &str) -> bool {
3925 true
3926 }
3927 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3928 let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
3929 Ok(Box::new(cfg))
3930 }
3931 fn get_processor(
3932 &self,
3933 _model_config: &str,
3934 processor_config: Option<ProcessorConfig>,
3935 _preprocessor_config: PreProcessorConfig,
3936 _max_edge: Option<u32>,
3937 ) -> Arc<dyn Processor + Send + Sync> {
3938 Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
3939 }
3940 fn supports_paged_attention(&self, _config: &str) -> bool {
3941 true
3942 }
3943 fn supports_prefix_cacher(&self, _config: &str) -> bool {
3944 true
3945 }
3946 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3947 Arc::new(Mistral3Prefixer)
3948 }
3949 fn modalities(&self, _config: &str) -> Result<Modalities> {
3950 Ok(Modalities {
3951 input: vec![SupportedModality::Text, SupportedModality::Vision],
3952 output: vec![SupportedModality::Text],
3953 })
3954 }
3955}
3956
3957impl IsqModelLoader for Mistral3Loader {
3958 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3959 Ok(vec![
3960 Regex::new(r"lm_head\.(weight|bias)$")?,
3961 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3963 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3964 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3965 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3966 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3968 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3969 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3970 ])
3971 }
3972 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
3973 Ok(vec![
3974 Regex::new(r"lm_head\.(weight|bias)$")?,
3975 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3977 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3978 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3979 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3980 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3982 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3983 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3984 ])
3985 }
3986}
3987
3988#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
3989impl DeviceMappedModelLoader for Mistral3Loader {
3990 fn mapped_max_act_size_elems(
3991 &self,
3992 config: &str,
3993 params: &AutoDeviceMapParams,
3994 _prompt_chunksize: usize,
3995 ) -> Result<usize> {
3996 let cfg: Mistral3Config = serde_json::from_str(config)?;
3997 let vcfg = &cfg.vision_config;
3998 let tcfg = &cfg.text_config;
3999
4000 let AutoDeviceMapParams::Vision {
4001 max_seq_len,
4002 max_batch_size,
4003 max_image_shape: (mut height, mut width),
4004 max_num_images,
4005 } = params
4006 else {
4007 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4008 };
4009
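        // Estimate the per-image token count: fit the image within 1540x1540, round up to
        // whole patches, and add one extra token per patch row.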
4010 let img_seq_len = {
4011 let (max_height, max_width) = (1540, 1540);
4015 let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4016 if ratio > 1. {
4017 height = (height as f64 / ratio).floor() as usize;
4018 width = (width as f64 / ratio).floor() as usize;
4019 }
4020
4021 let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
4022 let num_width_tokens = (width - 1) / vcfg.patch_size + 1;
4023
4024 height = num_height_tokens * vcfg.patch_size;
4025 width = num_width_tokens * vcfg.patch_size;
4026
4027 let num_height_tokens = height / vcfg.patch_size;
4028 let num_width_tokens = width / vcfg.patch_size;
4029
4030 (num_width_tokens + 1) * num_height_tokens
4031 };
4032
4033 let max_seq_len = img_seq_len * max_num_images + *max_seq_len;
4035 Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
4036 }
4037
4038 fn non_mapped_max_act_size_elems(
4039 &self,
4040 config: &str,
4041 params: &AutoDeviceMapParams,
4042 ) -> Result<usize> {
4043 let cfg: Mistral3Config = serde_json::from_str(config)?;
4044 let cfg = &cfg.vision_config;
4045
4046 let AutoDeviceMapParams::Vision {
4047 max_seq_len: _,
4048 max_batch_size,
4049 max_image_shape: (mut height, mut width),
4050 max_num_images,
4051 } = params
4052 else {
4053 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4054 };
4055
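        // Same per-image token estimate as in `mapped_max_act_size_elems`.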
4056 let img_seq_len = {
4057 let (max_height, max_width) = (1540, 1540);
4061 let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4062 if ratio > 1. {
4063 height = (height as f64 / ratio).floor() as usize;
4064 width = (width as f64 / ratio).floor() as usize;
4065 }
4066
4067 let num_height_tokens = (height - 1) / cfg.patch_size + 1;
4068 let num_width_tokens = (width - 1) / cfg.patch_size + 1;
4069
4070 height = num_height_tokens * cfg.patch_size;
4071 width = num_width_tokens * cfg.patch_size;
4072
4073 let num_height_tokens = height / cfg.patch_size;
4074 let num_width_tokens = width / cfg.patch_size;
4075
4076 (num_width_tokens + 1) * num_height_tokens
4077 };
4078
4079 Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
4080 }
4081
4082 fn non_mapped_size_in_bytes(
4083 &self,
4084 config: &str,
4085 dtype: DType,
4086 weight_pack_factor: usize,
4087 _matformer_config: Option<&MatformerSliceConfig>,
4088 ) -> Result<usize> {
4089 let cfg: Mistral3Config = serde_json::from_str(config)?;
4090
4091 let text_elems = {
4092 let cfg = &cfg.text_config;
4093
4094 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4095 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4097 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4098 } else {
4099 0
4100 };
4101 let norm = cfg.hidden_size;
4102 embed_tokens + lm_head + norm
4103 };
4104
4105 let vision_elems = {
4106 let cfg = &cfg.vision_config;
4107
4108 let patch_embed = {
4109 let conv_cfg = Conv2dConfig {
4110 stride: cfg.patch_size,
4111 ..Default::default()
4112 };
                // Conv2d patch embedding: hidden_size x (num_channels / groups) x patch_size^2.
                cfg.num_channels * cfg.hidden_size / conv_cfg.groups
                    * cfg.patch_size
                    * cfg.patch_size
            };
4118 let ln_pre = cfg.hidden_size;
4119 let vision_layer = {
4120 let attn_norm = cfg.hidden_size;
4121 let ffn_norm = cfg.hidden_size;
4122
4123 let gate = cfg.hidden_size * cfg.intermediate_size;
4124 let up = cfg.hidden_size * cfg.intermediate_size;
4125 let down = cfg.hidden_size * cfg.intermediate_size;
4126
4127 let q = cfg.hidden_size * cfg.hidden_size;
4128 let k = cfg.hidden_size * cfg.hidden_size;
4129 let v = cfg.hidden_size * cfg.hidden_size;
4130 let o = cfg.hidden_size * cfg.hidden_size;
4131
4132 attn_norm + ffn_norm + gate + up + down + q + k + v + o
4133 };
4134
4135 patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
4136 };
4137
4138 let elems = text_elems + vision_elems;
4139
4140 Ok(elems * dtype.size_in_bytes())
4141 }
4142
4143 fn layer_sizes_in_bytes(
4144 &self,
4145 config: &str,
4146 dtype: DType,
4147 weight_pack_factor: usize,
4148 _matformer_config: Option<&MatformerSliceConfig>,
4149 ) -> Result<Vec<usize>> {
4150 let cfg: Mistral3Config = serde_json::from_str(config)?;
4151 let cfg = &cfg.text_config;
4152
4153 let per_layer_elems = {
4154 let input_layernorm = cfg.hidden_size;
4155 let post_attention_layernorm = cfg.hidden_size;
4156
4157 let size_in = cfg.hidden_size;
4158 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
4159 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
4160 let q_proj = size_in * size_q / weight_pack_factor;
4161 let k_proj = size_in * size_kv / weight_pack_factor;
4162 let v_proj = size_in * size_kv / weight_pack_factor;
4163 let o_proj = size_q * size_in / weight_pack_factor;
4164
4165 let h_size = cfg.hidden_size;
4166 let i_size = cfg.intermediate_size;
4167 let gate_proj = h_size * i_size / weight_pack_factor;
4168 let up_proj = h_size * i_size / weight_pack_factor;
4169 let down_proj = i_size * h_size / weight_pack_factor;
4170
4171 input_layernorm
4172 + post_attention_layernorm
4173 + q_proj
4174 + k_proj
4175 + v_proj
4176 + o_proj
4177 + gate_proj
4178 + up_proj
4179 + down_proj
4180 };
4181 Ok(vec![
4182 per_layer_elems * dtype.size_in_bytes();
4183 cfg.num_hidden_layers
4184 ])
4185 }
4186
4187 fn num_layers(&self, config: &str) -> Result<usize> {
4188 let cfg: Mistral3Config = serde_json::from_str(config)?;
4189 let cfg = &cfg.text_config;
4190 Ok(cfg.num_hidden_layers)
4191 }
4192
4193 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4194 let cfg: Mistral3Config = serde_json::from_str(config)?;
4195 let cfg = &cfg.text_config;
4196
4197 let cfg = ModelConfigMetadata {
4198 max_seq_len: cfg.max_position_embeddings,
4199 num_layers: cfg.num_hidden_layers,
4200 hidden_size: cfg.hidden_size,
4201 num_kv_heads: cfg.num_key_value_heads,
4202 num_attn_heads: cfg.num_attention_heads,
4203 sliding_window: cfg.sliding_window,
4204 k_head_dim: cfg.head_dim(),
4205 v_head_dim: cfg.head_dim(),
4206 };
4207
4208 Ok(Box::new(cfg))
4209 }
4210
4211 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4212 Some(vec![NonMappedSubModel::Vision])
4213 }
4214}
4215
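/// Vision model loader for Llama 4 multimodal models.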
pub struct VLlama4Loader;
4222
4223pub struct VLlama4Prefixer;
4224
4225impl MultimodalPromptPrefixer for VLlama4Prefixer {
4226 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
4227 format!(
4228 "{}{prompt}",
4229 llama4::IMAGE_TOKEN.repeat(image_indexes.len())
4230 )
4231 }
4232}
4233
4234impl VisionModelLoader for VLlama4Loader {
4235 fn load(
4236 &self,
4237 config: &str,
4238 vb: ShardedVarBuilder,
4239 normal_loading_metadata: NormalLoadingMetadata,
4240 attention_mechanism: AttentionImplementation,
4241 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
4242 let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4243 Ok(Box::new(Llama4Model::new(
4244 &cfg,
4245 vb,
4246 self.is_gptx(config),
4247 normal_loading_metadata,
4248 attention_mechanism,
4249 )?))
4250 }
4251 fn is_gptx(&self, _config: &str) -> bool {
4252 false
4253 }
4254 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4255 let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4256 Ok(Box::new(cfg))
4257 }
4258 fn get_processor(
4259 &self,
4260 _model_config: &str,
4261 processor_config: Option<ProcessorConfig>,
4262 _preprocessor_config: PreProcessorConfig,
4263 _max_edge: Option<u32>,
4264 ) -> Arc<dyn Processor + Send + Sync> {
4265 Arc::new(Llama4Processor::new(&processor_config.unwrap()))
4266 }
4267 fn supports_paged_attention(&self, _config: &str) -> bool {
4268 true
4269 }
4270 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4271 Arc::new(VLlama4Prefixer)
4272 }
4273 fn modalities(&self, _config: &str) -> Result<Modalities> {
4274 Ok(Modalities {
4275 input: vec![SupportedModality::Text, SupportedModality::Vision],
4276 output: vec![SupportedModality::Text],
4277 })
4278 }
4279}
4280
4281impl IsqModelLoader for VLlama4Loader {
4282 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4283 Ok(vec![
4284 Regex::new(r"lm_head\.(weight|bias)$")?,
4285 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4287 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4288 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4289 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4290 Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$")?,
4292 Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$")?,
4293 Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$")?,
4294 Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$")?,
4295 Regex::new(r"layers\.(\d+)\.feed_forward\.router\.(weight|bias)$")?,
4296 Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4297 Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4298 Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4299 Regex::new(r"layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$")?,
4301 Regex::new(r"layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$")?,
4302 Regex::new(r"layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$")?,
4303 ])
4304 }
4305 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4306 Ok(vec![
4307 Regex::new(r"lm_head\.(weight|bias)$")?,
4308 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4310 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4311 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4312 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4313 Regex::new(
4315 r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_up_proj\.(weight|bias)$",
4316 )?,
4317 Regex::new(
4318 r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
4319 )?,
4320 Regex::new(
4321 r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.up_proj\.(weight|bias)$",
4322 )?,
4323 Regex::new(
4324 r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.down_proj\.(weight|bias)$",
4325 )?,
4326 Regex::new(
4327 r"language_model\.model\.layers\.(\d+)\.feed_forward\.router\.(weight|bias)$",
4328 )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.gate_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.down_proj\.(weight|bias)$",
            )?,
4338 Regex::new(
4340 r"language_model\.model\.layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$",
4341 )?,
4342 Regex::new(
4343 r"language_model\.model\.layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$",
4344 )?,
4345 Regex::new(
4346 r"language_model\.model\.layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$",
4347 )?,
4348 ])
4349 }
4350}
4351
4352impl VLlama4Loader {
4353 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
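    // Run the Llama 4 image processor on dummy images of the requested shape to determine
    // how many pixel-value chunks and image text tokens that shape produces.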
4356 fn run_dummy_processing(
4357 &self,
4358 cfg: &Llama4Config,
4359 height: usize,
4360 width: usize,
4361 max_num_images: usize,
4362 max_batch_size: usize,
4363 ) -> Result<(usize, usize)> {
4364 let cfg = &cfg.vision_config;
4365
4366 let img_processor =
4367 Llama4ImageProcessor::new(Some(cfg.patch_size), Some(cfg.pixel_shuffle_ratio));
4368 let image = DynamicImage::new(width as u32, height as u32, ColorType::Rgb8);
4369 let res = img_processor.preprocess(
4370 vec![image; max_num_images],
4371 vec![],
4372 &PreProcessorConfig::default(),
4373 &Device::Cpu,
4374 (max_batch_size, max_num_images),
4375 )?;
4376
4377 let pixels_batch_size = res.pixel_values.dim(0)?;
4378 let pixels_max_batch_size = pixels_batch_size * max_batch_size;
4379
4380 let (image_h, image_w) = (
4381 res.pixel_values.dim(D::Minus2).unwrap(),
4382 res.pixel_values.dim(D::Minus1).unwrap(),
4383 );
4384 let num_patches_per_chunk = (image_h / img_processor.patch_size)
4385 * (image_w / img_processor.patch_size)
4386 / img_processor.downsample_ratio;
4387
4388 Ok((
4389 pixels_max_batch_size,
4390 num_patches_per_chunk * pixels_max_batch_size,
4391 ))
4392 }
4393}
4394
4395impl DeviceMappedModelLoader for VLlama4Loader {
4396 fn mapped_max_act_size_elems(
4397 &self,
4398 config: &str,
4399 params: &AutoDeviceMapParams,
4400 _prompt_chunksize: usize,
4401 ) -> Result<usize> {
4402 let AutoDeviceMapParams::Vision {
4403 max_seq_len,
4404 max_batch_size,
4405 max_image_shape: (height, width),
4406 max_num_images,
4407 } = params
4408 else {
4409 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4410 };
4411
4412 let cfg: Llama4Config = serde_json::from_str(config)?;
4413
4414 let (_pixels_batch_size, num_text_image_toks) =
4415 self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4416
4417 let max_seq_len = max_seq_len + num_text_image_toks;
4418
4419 Ok(max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len)
4420 }
4421 fn non_mapped_max_act_size_elems(
4422 &self,
4423 config: &str,
4424 params: &AutoDeviceMapParams,
4425 ) -> Result<usize> {
4426 let AutoDeviceMapParams::Vision {
4427 max_seq_len: _,
4428 max_batch_size,
4429 max_image_shape: (height, width),
4430 max_num_images,
4431 } = params
4432 else {
4433 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4434 };
4435
4436 let cfg: Llama4Config = serde_json::from_str(config)?;
4437
4438 let (pixels_batch_size, _num_text_image_toks) =
4439 self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4440 let max_seq_len = cfg.vision_config.num_patches();
4441
4442 Ok((max_batch_size * pixels_batch_size)
4443 * cfg.vision_config.num_attention_heads
4444 * max_seq_len
4445 * max_seq_len)
4446 }
4447
4448 fn non_mapped_size_in_bytes(
4449 &self,
4450 config: &str,
4451 dtype: DType,
4452 weight_pack_factor: usize,
4453 _matformer_config: Option<&MatformerSliceConfig>,
4454 ) -> Result<usize> {
4455 let cfg: Llama4Config = serde_json::from_str(config)?;
4456 let tcfg = &cfg.text_config;
4457
4458 let text_elems = {
4459 let embed_tokens = tcfg.hidden_size * tcfg.vocab_size / weight_pack_factor;
4460 let lm_head = if !tcfg.tie_word_embeddings {
4461 tcfg.hidden_size * tcfg.vocab_size
4462 } else {
4463 0
4464 };
4465 let norm = tcfg.hidden_size;
4466 embed_tokens + lm_head + norm
4467 };
4468
4469 let vision_elems = {
4470 let cfg = &cfg.vision_config;
4471
4472 let num_patches = cfg.num_patches();
4473
4474 let unfold_elems =
4475 (cfg.num_channels * cfg.patch_size * cfg.patch_size) * cfg.hidden_size;
            let class_embedding_elems = cfg.hidden_size;
4477 let positional_embedding_vlm_elems = num_patches * cfg.hidden_size;
4478 let layernorm_pre_elems = cfg.hidden_size;
4479 let layernorm_post_elems = cfg.hidden_size;
4480
4481 let pixel_shuffle_elems = cfg.intermediate_size * cfg.projector_input_dim
4482 / weight_pack_factor
4483 + cfg.projector_input_dim * cfg.projector_output_dim / weight_pack_factor;
4484
4485 let encoder_layer = {
4486 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
4487 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
4488
4489 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
4490 let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4491 / weight_pack_factor
4492 + cfg.num_attention_heads * head_dim;
4493 let k_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4494 / weight_pack_factor
4495 + cfg.num_attention_heads * head_dim;
4496 let v_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4497 / weight_pack_factor
4498 + cfg.num_attention_heads * head_dim;
4499 let o_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4500 / weight_pack_factor
4501 + cfg.num_attention_heads * head_dim;
4502
4503 let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
4504 + cfg.intermediate_size;
4505 let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
4506 + cfg.hidden_size;
4507
4508 input_layernorm
4509 + post_attention_layernorm
4510 + q_proj
4511 + k_proj
4512 + v_proj
4513 + o_proj
4514 + fc1
4515 + fc2
4516 };
4517
4518 unfold_elems
                + class_embedding_elems
4520 + positional_embedding_vlm_elems
4521 + layernorm_post_elems
4522 + layernorm_pre_elems
4523 + pixel_shuffle_elems
4524 + encoder_layer * cfg.num_hidden_layers
4525 };
4526
4527 let elems = text_elems + vision_elems;
4528
4529 Ok(elems * dtype.size_in_bytes())
4530 }
4531
4532 fn layer_sizes_in_bytes(
4533 &self,
4534 config: &str,
4535 dtype: DType,
4536 weight_pack_factor: usize,
4537 _matformer_config: Option<&MatformerSliceConfig>,
4538 ) -> Result<Vec<usize>> {
4539 let cfg: Llama4Config = serde_json::from_str(config)?;
4540 let tcfg = &cfg.text_config;
4541
4542 let mut per_layer_elems = Vec::new();
4543
4544 for layer_idx in 0..tcfg.num_hidden_layers {
4545 let input_layernorm = tcfg.hidden_size;
4546 let post_attention_layernorm = tcfg.hidden_size;
4547
4548 let size_in = tcfg.hidden_size;
4549 let size_q = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_attention_heads;
4550 let size_kv = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_key_value_heads;
4551 let q_proj = size_in * size_q / weight_pack_factor;
4552 let k_proj = size_in * size_kv / weight_pack_factor;
4553 let v_proj = size_in * size_kv / weight_pack_factor;
4554 let o_proj = size_q * size_in / weight_pack_factor;
4555
4556 let use_moe = tcfg.moe_layers().contains(&layer_idx);
4557 let moe_block = if use_moe {
4558 let h_size = tcfg.hidden_size;
4559 let i_size = tcfg.intermediate_size;
4560 let gate_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4561 let up_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4562 let down_proj = tcfg.num_local_experts * i_size * h_size / weight_pack_factor;
4563
4564 gate_proj + up_proj + down_proj
4565 } else {
4566 let h_size = tcfg.hidden_size;
4567 let i_size = tcfg.intermediate_size_mlp;
4568 let gate_proj = h_size * i_size / weight_pack_factor;
4569 let up_proj = h_size * i_size / weight_pack_factor;
4570 let down_proj = i_size * h_size / weight_pack_factor;
4571
4572 gate_proj + up_proj + down_proj
4573 };
4574
4575 per_layer_elems.push(
4576 input_layernorm
4577 + post_attention_layernorm
4578 + q_proj
4579 + k_proj
4580 + v_proj
4581 + o_proj
4582 + moe_block,
4583 );
4584 }
4585
4586 Ok(per_layer_elems
4587 .into_iter()
4588 .map(|x| x * dtype.size_in_bytes())
4589 .collect())
4590 }
4591
4592 fn num_layers(&self, config: &str) -> Result<usize> {
4593 let cfg: Llama4Config = serde_json::from_str(config)?;
4594 Ok(cfg.text_config.num_hidden_layers)
4595 }
4596
4597 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4598 let cfg: Llama4Config = serde_json::from_str(config)?;
4599 let cfg = &cfg.text_config;
4600
4601 let cfg = ModelConfigMetadata {
4602 max_seq_len: cfg.max_position_embeddings,
4603 num_layers: cfg.num_hidden_layers,
4604 hidden_size: cfg.hidden_size,
4605 num_kv_heads: cfg.num_attention_heads,
4606 num_attn_heads: cfg.num_attention_heads,
4607 sliding_window: None,
4608 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4609 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4610 };
4611
4612 Ok(Box::new(cfg))
4613 }
4614
4615 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4616 Some(vec![NonMappedSubModel::Vision])
4617 }
4618}
4619
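/// Vision model loader for Gemma 3n models, which accept text, vision, and audio inputs.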
pub struct Gemma3nLoader;
4626
4627pub struct Gemma3nPrefixer;
4628
4629impl MultimodalPromptPrefixer for Gemma3nPrefixer {
4630 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
4631 prompt.to_string()
4632 }
4633}
4634
4635impl VisionModelLoader for Gemma3nLoader {
4636 fn load(
4637 &self,
4638 config: &str,
4639 vb: ShardedVarBuilder,
4640 normal_loading_metadata: NormalLoadingMetadata,
4641 attention_mechanism: AttentionImplementation,
4642 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
4643 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4644 Ok(Box::new(Gemma3nModel::new(
4645 &cfg,
4646 vb,
4647 self.is_gptx(config),
4648 normal_loading_metadata,
4649 attention_mechanism,
4650 )?))
4651 }
4652 fn is_gptx(&self, _config: &str) -> bool {
4653 true
4654 }
4655 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4656 let config: Gemma3nConfig = serde_json::from_str(config)?;
4657 Ok(Box::new(config))
4658 }
4659 fn get_processor(
4660 &self,
4661 _config: &str,
4662 processor_config: Option<ProcessorConfig>,
4663 _preprocessor_config: PreProcessorConfig,
4664 _max_edge: Option<u32>,
4665 ) -> Arc<dyn Processor + Send + Sync> {
4666 Arc::new(Gemma3nProcessor::new(
4668 processor_config.unwrap_or_default(),
4669 true,
4670 ))
4671 }
4672 fn supports_paged_attention(&self, _config: &str) -> bool {
4673 false
4674 }
4675 fn supports_prefix_cacher(&self, _config: &str) -> bool {
4676 true
4677 }
4678 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4679 Arc::new(Gemma3Prefixer)
4680 }
4681 fn modalities(&self, _config: &str) -> Result<Modalities> {
4682 Ok(Modalities {
4683 input: vec![
4684 SupportedModality::Text,
4685 SupportedModality::Vision,
4686 SupportedModality::Audio,
4687 ],
4688 output: vec![SupportedModality::Text],
4689 })
4690 }
4691}
4692
4693impl IsqModelLoader for Gemma3nLoader {
4694 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4695 Ok(vec![
4696 Regex::new(r"lm_head\.(weight|bias)$")?,
4697 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4699 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4700 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4701 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4702 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4704 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4705 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4706 Regex::new(r"conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$")?,
4708 Regex::new(r"conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$")?,
4709 Regex::new(r"conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$")?,
4710 Regex::new(
4711 r"conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4712 )?,
4713 Regex::new(r"conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4714 Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$")?,
4716 Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$")?,
4717 Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$")?,
4718 Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$")?,
4719 Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$")?,
4721 Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$")?,
4722 Regex::new(r"subsample_conv_projection\.input_proj_linear\.(weight|bias)$")?,
4724 Regex::new(r"embed_vision\.embedding_projection\.(weight|bias)$")?,
4726 Regex::new(r"embed_audio\.embedding_projection\.(weight|bias)$")?,
4727 ])
4728 }
4729 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4730 Ok(vec![
4731 Regex::new(r"lm_head\.(weight|bias)$")?,
4732 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4734 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4735 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4736 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4737 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4739 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4740 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4741 Regex::new(r"model\.language_model\.per_layer_model_projection\.(weight|bias)$")?,
4743 Regex::new(r"model\.language_model\.altup_projections\.(\d+)\.(weight|bias)$")?,
4744 Regex::new(r"model\.language_model\.altup_unembed_projections\.(\d+)\.(weight|bias)$")?,
4745 Regex::new(
4747 r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$",
4748 )?,
4749 Regex::new(
4750 r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$",
4751 )?,
4752 Regex::new(
4753 r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$",
4754 )?,
4755 Regex::new(
4756 r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4757 )?,
4758 Regex::new(r"model\.audio_tower\.conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4759 Regex::new(
4761 r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$",
4762 )?,
4763 Regex::new(
4764 r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$",
4765 )?,
4766 Regex::new(
4767 r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$",
4768 )?,
4769 Regex::new(
4770 r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$",
4771 )?,
4772 Regex::new(
4774 r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$",
4775 )?,
4776 Regex::new(
4777 r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$",
4778 )?,
4779 Regex::new(
4781 r"model\.audio_tower\.subsample_conv_projection\.input_proj_linear\.(weight|bias)$",
4782 )?,
4783 Regex::new(r"model\.embed_vision\.embedding_projection\.(weight|bias)$")?,
4785 Regex::new(r"model\.embed_audio\.embedding_projection\.(weight|bias)$")?,
4786 ])
4787 }
4788}
4789
4790impl DeviceMappedModelLoader for Gemma3nLoader {
4791 fn mapped_max_act_size_elems(
4792 &self,
4793 config: &str,
4794 params: &AutoDeviceMapParams,
4795 _prompt_chunksize: usize,
4796 ) -> Result<usize> {
4797 let AutoDeviceMapParams::Vision {
4798 max_seq_len,
4799 max_batch_size,
4800 max_image_shape: _,
4801 max_num_images,
4802 } = params
4803 else {
4804 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4805 };
4806
4807 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4808 let text_cfg = &cfg.text_config;
4809
4810 let mut total_seq_len = *max_seq_len;
4814
4815 {
            // Vision soft tokens: the MSFA output is a 16x16 grid of soft tokens per image.
            let msfa_spatial_size = 16;
            let vision_tokens_per_image = msfa_spatial_size * msfa_spatial_size;
            total_seq_len += vision_tokens_per_image * max_num_images;
4822 }
4823
4824 {
            // Audio soft tokens contributed by the audio input.
            let audio_tokens = cfg.audio_soft_tokens_per_image;
4829 total_seq_len += audio_tokens;
4830 }
4831
4832 let max_text_attn =
4834 max_batch_size * text_cfg.num_attention_heads * total_seq_len * total_seq_len;
4835
4836 Ok(max_text_attn)
4837 }
4838
4839 fn non_mapped_max_act_size_elems(
4840 &self,
4841 config: &str,
4842 params: &AutoDeviceMapParams,
4843 ) -> Result<usize> {
4844 let AutoDeviceMapParams::Vision {
4845 max_seq_len: _,
4846 max_batch_size,
4847 max_image_shape: _,
4848 max_num_images,
4849 } = params
4850 else {
4851 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4852 };
4853
4854 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4855
4856 let mut max_activation = 0;
4858
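        // Vision path: worst-case activations in the vision tower attention and the vision
        // embedder projection.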
4859 {
4861 let vision_tower_act = {
                // Rough worst case for the tower's attention stages: assume 16 heads over a
                // 24x24 spatial grid.
                let num_heads = 16;
                let spatial_size = 24;
                let seq_len = spatial_size * spatial_size;
4877
4878 max_batch_size * max_num_images * num_heads * seq_len * seq_len
4880 };
4881
4882 let vision_embed_act = {
                // MSFA output: 2048 channels over a 16x16 spatial grid per image.
                let msfa_channels = 2048;
                let spatial_size = 16;
                let vision_features =
4888 max_batch_size * max_num_images * msfa_channels * spatial_size * spatial_size;
4889
4890 let projected = max_batch_size
4892 * max_num_images
4893 * spatial_size
4894 * spatial_size
4895 * cfg.text_config.hidden_size;
4896
4897 vision_features.max(projected)
4898 };
4899
4900 max_activation = max_activation.max(vision_tower_act).max(vision_embed_act);
4901 }
4902
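        // Audio path: worst-case activations in the conformer feed-forward and local attention.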
4903 {
4905 let audio_cfg = &cfg.audio_config;
4906
            // Assumed upper bound on input audio frames, used only for this sizing estimate.
            let max_audio_frames = 1280;
4912
4913 let subsample_factor: usize = audio_cfg
4914 .sscp_conv_stride_size
4915 .iter()
                .map(|stride| stride[0])
                .product();
4918 let audio_seq_after_subsample = max_audio_frames / subsample_factor;
4919
4920 let audio_encoder_act = {
                // Assume the conventional 4x feed-forward expansion inside the conformer blocks.
                let intermediate_size = audio_cfg.hidden_size * 4;
                max_batch_size * audio_seq_after_subsample * intermediate_size
4927 };
4928
4929 let audio_attn_act = {
4931 let chunk_size = audio_cfg.conf_attention_chunk_size;
4933 let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
4934 + audio_cfg.conf_attention_context_right;
4935
4936 let num_chunks = audio_seq_after_subsample.div_ceil(chunk_size);
4938
4939 max_batch_size
4940 * audio_cfg.conf_num_attention_heads
4941 * num_chunks
4942 * chunk_size
4943 * context_size
4944 };
4945
4946 max_activation = max_activation.max(audio_encoder_act).max(audio_attn_act);
4947 }
4948
4949 Ok(max_activation)
4950 }
4951
4952 fn non_mapped_size_in_bytes(
4953 &self,
4954 config: &str,
4955 dtype: DType,
4956 weight_pack_factor: usize,
4957 matformer_config: Option<&MatformerSliceConfig>,
4958 ) -> Result<usize> {
4959 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4960
4961 let text_cfg = if let Some(matformer_cfg) = matformer_config {
4963 use crate::device_map::DummyDeviceMapper;
4964 use crate::vision_models::gemma3n::text::handle_matformer_slicing;
4965
4966 let dummy_mapper = DummyDeviceMapper {
4967 nm_device: Device::Cpu,
4968 };
4969 let (adjusted_cfg, _, _, _, _) = handle_matformer_slicing(
4970 &cfg.text_config,
4971 &Some(matformer_cfg.clone()),
4972 &dummy_mapper,
4973 )?;
4974 adjusted_cfg
4975 } else {
4976 cfg.text_config.clone()
4977 };
4978
4979 let text_cfg = &text_cfg;
4980
4981 let text_elems = {
4983 let embed_tokens = text_cfg.hidden_size * text_cfg.vocab_size;
4985 let embed_tokens_per_layer = text_cfg.num_hidden_layers
4986 * text_cfg.hidden_size_per_layer_input
4987 * text_cfg.vocab_size_per_layer_input;
4988
4989 let lm_head = if !text_cfg.tie_word_embeddings || weight_pack_factor != 1 {
4991 text_cfg.hidden_size * text_cfg.vocab_size / weight_pack_factor
4992 } else {
4993 0
4994 };
4995
4996 let norm = text_cfg.hidden_size;
4998
4999 let altup_projections =
5001 (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
5002 / weight_pack_factor;
5003 let altup_unembed_projections =
5004 (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
5005 / weight_pack_factor;
5006
5007 let per_layer_model_projection = text_cfg.num_hidden_layers
5009 * text_cfg.hidden_size
5010 * text_cfg.hidden_size_per_layer_input
5011 / weight_pack_factor;
5012 let per_layer_projection_norm = text_cfg.hidden_size;
5013
5014 embed_tokens
5015 + embed_tokens_per_layer
5016 + lm_head
5017 + norm
5018 + altup_projections
5019 + altup_unembed_projections
5020 + per_layer_model_projection
5021 + per_layer_projection_norm
5022 };
5023
5024 let vision_elems = {
5026 let vision_cfg = &cfg.vision_config;
5027 let vision_tower_elems = {
5031 use crate::vision_models::gemma3n::vision::{
5032 gemma3n_mobilenet_def, make_divisible, BlockType, INPUT_CHANNELS,
5033 MSFA_EXPANSION_RATIO, MSFA_IN_CHANNELS, MSFA_OUT_CHANNELS, STEM_KERNEL_SIZE,
5034 STEM_OUT_CHANNELS,
5035 };
5036
5037 let stem_conv =
5039 INPUT_CHANNELS * STEM_OUT_CHANNELS * STEM_KERNEL_SIZE * STEM_KERNEL_SIZE;
                let stem_norm = STEM_OUT_CHANNELS;

                let mut in_chs = STEM_OUT_CHANNELS;
5044 let mut total_elems = stem_conv + stem_norm;
5045
5046 let block_defs = gemma3n_mobilenet_def();
5048
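                // Walk the MobileNet-style block definitions and accumulate parameter counts
                // for each block type.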
5049 for stage_blocks in block_defs.iter() {
5050 for block_type in stage_blocks.iter() {
5051 match block_type {
5052 BlockType::EdgeResidual {
5053 out_channels,
5054 kernel_size,
5055 stride: _,
5056 expand_ratio,
5057 ..
5058 } => {
5059 #[allow(clippy::cast_precision_loss)]
5060 let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
                                total_elems += in_chs * mid_chs * kernel_size * kernel_size; // expansion conv
                                total_elems += mid_chs; // norm
                                total_elems += mid_chs * out_channels; // pointwise projection conv
                                total_elems += out_channels; // norm
                                in_chs = *out_channels;
5067 }
5068 BlockType::UniversalInvertedResidual {
5069 out_channels,
5070 start_kernel_size,
5071 mid_kernel_size,
5072 stride: _,
5073 expand_ratio,
5074 ..
5075 } => {
5076 #[allow(clippy::cast_precision_loss)]
5077 let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
                                if *expand_ratio != 1.0 {
                                    total_elems += in_chs * mid_chs; // expansion pointwise conv
                                    total_elems += mid_chs; // norm
                                }
                                if *start_kernel_size > 0 {
                                    total_elems += mid_chs * start_kernel_size * start_kernel_size; // leading depthwise conv
                                    total_elems += mid_chs; // norm
                                }
                                if *mid_kernel_size > 0 {
                                    total_elems += mid_chs * mid_kernel_size * mid_kernel_size; // mid depthwise conv
                                    total_elems += mid_chs; // norm
                                }
                                total_elems += mid_chs * out_channels; // pointwise projection conv
                                total_elems += out_channels; // norm
                                total_elems += out_channels; // per-channel scale
                                in_chs = *out_channels;
5095 }
5096 BlockType::MultiQueryAttention {
5097 num_heads,
5098 kv_dim,
5099 kv_stride: _,
5100 ..
5101 } => {
                                // Rough parameter count for the mobile multi-query attention
                                // block: norms, attention projections, the depthwise kv conv,
                                // and per-channel scales.
                                let dw_kernel_size = 3;
                                total_elems += in_chs;
                                total_elems += in_chs * num_heads * kv_dim;
                                total_elems += in_chs * kv_dim;
                                total_elems += in_chs * dw_kernel_size * dw_kernel_size;
                                total_elems += *kv_dim;
                                total_elems += 1;
                                total_elems += *kv_dim;
                                total_elems += num_heads * kv_dim * in_chs;
                                total_elems += in_chs;
                            }
5114 }
5115 }
5116 }
5117
5118 let msfa_in = MSFA_IN_CHANNELS.iter().sum::<usize>();
5120 let msfa_out = MSFA_OUT_CHANNELS;
5121 #[allow(clippy::cast_precision_loss)]
5122 let msfa_mid = make_divisible(msfa_in as f64 * MSFA_EXPANSION_RATIO, 8);
5123
                total_elems += msfa_in * msfa_mid; // expansion conv
                total_elems += msfa_mid; // norm
                total_elems += msfa_mid * msfa_out; // projection conv
                total_elems += msfa_out; // norm
                total_elems += msfa_out; // final norm
                total_elems
5132 };
5133
5134 let embed_vision_elems = {
5136 let embedding = vision_cfg.vocab_size * vision_cfg.hidden_size;
5138
5139 let hard_norm = vision_cfg.hidden_size;
5141 let soft_norm = vision_cfg.hidden_size;
5142
5143 let projection = vision_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5145
5146 let post_norm = text_cfg.hidden_size;
5148
5149 embedding + hard_norm + soft_norm + projection + post_norm
5150 };
5151
5152 vision_tower_elems + embed_vision_elems
5153 };
5154
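        // Audio encoder parameters: subsample conv projection, conformer stack, and the
        // audio embedder.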
5155 let audio_elems = {
5157 let audio_cfg = &cfg.audio_config;
5158
5159 let subsample_conv_projection_elems = {
5161 let mut conv_elems = 0;
5163
5164 let in_ch_0 = 1;
5166 let out_ch_0 = audio_cfg.sscp_conv_channel_size[0];
5167 let kernel_0 = &audio_cfg.sscp_conv_kernel_size[0];
5168 conv_elems += in_ch_0 * out_ch_0 * kernel_0[0] * kernel_0[1];
5169
5170 let in_ch_1 = out_ch_0;
5172 let out_ch_1 = audio_cfg.sscp_conv_channel_size[1];
5173 let kernel_1 = &audio_cfg.sscp_conv_kernel_size[1];
5174 conv_elems += in_ch_1 * out_ch_1 * kernel_1[0] * kernel_1[1];
5175
                let norm_0 = out_ch_0; // norm after the first conv
                let norm_1 = out_ch_1; // norm after the second conv

                // Track the frequency dimension through both strided convs to size the input
                // projection.
                let mut f_out = audio_cfg.input_feat_size;
5182 for i in 0..2 {
5183 let kernel_w = audio_cfg.sscp_conv_kernel_size[i][1];
5184 let stride_w = audio_cfg.sscp_conv_stride_size[i][1];
5185 let pad_left = 1;
5186 let pad_right = 1;
5187 f_out = (f_out + pad_left + pad_right + stride_w - kernel_w) / stride_w;
5188 }
5189 let input_proj_in_features = out_ch_1 * f_out;
5190 let input_proj_linear =
5191 input_proj_in_features * audio_cfg.hidden_size / weight_pack_factor;
5192
5193 conv_elems + norm_0 + norm_1 + input_proj_linear
5194 };
5195
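            // Each conformer layer contributes attention, start/end feed-forward blocks, an
            // lconv1d block, and a final norm.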
5196 let conformer_elems = {
5198 let mut total = 0;
5199
5200 for _ in 0..audio_cfg.conf_num_hidden_layers {
5201 let attention_elems = {
5203 let pre_attn_norm = audio_cfg.hidden_size;
5205 let post_norm = audio_cfg.hidden_size;
5206
5207 let q_proj =
5209 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5210 let k_proj =
5211 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5212 let v_proj =
5213 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5214 let post =
5215 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5216
5217 let pos_proj =
5219 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5220 let per_dim_scale =
                            audio_cfg.hidden_size / audio_cfg.conf_num_attention_heads;
                        let inv_timescales = audio_cfg.hidden_size / 2;
                        let pos_indices = audio_cfg.conf_attention_context_left
5224 + audio_cfg.conf_attention_context_right
5225 + 1;
5226
5227 let chunk_size = audio_cfg.conf_attention_chunk_size;
5229 let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
5230 + audio_cfg.conf_attention_context_right;
                        let local_causal_valid_mask = chunk_size * context_size;
                        let invalid_logits_tensor = 1;

                        pre_attn_norm
5235 + post_norm
5236 + q_proj
5237 + k_proj
5238 + v_proj
5239 + post
5240 + pos_proj
5241 + per_dim_scale
5242 + inv_timescales
5243 + pos_indices
5244 + local_causal_valid_mask
5245 + invalid_logits_tensor
5246 };
5247
5248 let ffw_elems = {
5250 let intermediate_size = audio_cfg.hidden_size * 4;
5256
5257 let ffw_start = {
5258 let pre_norm = audio_cfg.hidden_size;
5259 let layer_1 =
5260 audio_cfg.hidden_size * intermediate_size / weight_pack_factor;
5261 let layer_2 =
5262 intermediate_size * audio_cfg.hidden_size / weight_pack_factor;
5263 let post_norm = audio_cfg.hidden_size;
5264 pre_norm + layer_1 + layer_2 + post_norm
5265 };
5266
                        // The end feed-forward block mirrors the start block.
                        let ffw_end = ffw_start;

                        ffw_start + ffw_end
5270 };
5271
5272 let lconv1d_elems = {
5274 let pre_layer_norm = audio_cfg.hidden_size;
5276 let conv_norm = audio_cfg.hidden_size;
5277
5278 let linear_start = audio_cfg.hidden_size * (audio_cfg.hidden_size * 2)
5280 / weight_pack_factor;
5281 let linear_end =
5282 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5283
5284 let depthwise = audio_cfg.hidden_size * audio_cfg.conf_conv_kernel_size;
5286
5287 pre_layer_norm + conv_norm + linear_start + linear_end + depthwise
5288 };
5289
5290 let block_norm = audio_cfg.hidden_size;
5292
5293 total += attention_elems + ffw_elems + lconv1d_elems + block_norm;
5294 }
5295
5296 total
5297 };
5298
5299 let embed_audio_elems = {
5301 let embedding = audio_cfg.vocab_size * audio_cfg.hidden_size;
5303
                let hard_embedding_norm = audio_cfg.hidden_size;
                let soft_embedding_norm = audio_cfg.hidden_size;
                let embedding_post_projection_norm = text_cfg.hidden_size;
                let embedding_projection =
5311 audio_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5312
5313 embedding
5314 + hard_embedding_norm
5315 + soft_embedding_norm
5316 + embedding_post_projection_norm
5317 + embedding_projection
5318 };
5319
5320 subsample_conv_projection_elems + conformer_elems + embed_audio_elems
5321 };
5322
        // The vision tower is sized at F32 when F16 is requested.
        let vision_dtype = if dtype == DType::F16 {
            DType::F32
        } else {
            dtype
        };
5329
5330 let total_elems = text_elems * dtype.size_in_bytes()
5331 + vision_elems * vision_dtype.size_in_bytes()
5332 + audio_elems * dtype.size_in_bytes();
5333
5334 Ok(total_elems)
5335 }
5336
5337 fn layer_sizes_in_bytes(
5338 &self,
5339 config: &str,
5340 dtype: DType,
5341 weight_pack_factor: usize,
5342 matformer_config: Option<&MatformerSliceConfig>,
5343 ) -> Result<Vec<usize>> {
5344 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5345
5346 let (text_cfg, _layer_rename_map, _layers_skipped) = if let Some(matformer_cfg) =
5348 matformer_config
5349 {
5350 use crate::device_map::DummyDeviceMapper;
5351 use crate::vision_models::gemma3n::text::handle_matformer_slicing;
5352
5353 let dummy_mapper = DummyDeviceMapper {
5354 nm_device: Device::Cpu,
5355 };
5356 let (adjusted_cfg, _, _, layer_rename_map, layers_skipped) = handle_matformer_slicing(
5357 &cfg.text_config,
5358 &Some(matformer_cfg.clone()),
5359 &dummy_mapper,
5360 )?;
5361 (adjusted_cfg, layer_rename_map, layers_skipped)
5362 } else {
5363 (cfg.text_config.clone(), None, None)
5364 };
5365
5366 let text_cfg = &text_cfg;
5367
5368 let mut layer_sizes = Vec::new();
5370
5371 for layer_idx in 0..text_cfg.num_hidden_layers {
5375 let per_layer_elems = {
5376 let input_layernorm = text_cfg.hidden_size;
5378 let post_attention_layernorm = text_cfg.hidden_size;
5379 let pre_feedforward_layernorm = text_cfg.hidden_size;
5380 let post_feedforward_layernorm = text_cfg.hidden_size;
5381 let post_per_layer_input_norm = text_cfg.hidden_size;
5382
5383 let size_in = text_cfg.hidden_size;
5385 let size_q = text_cfg.num_attention_heads * text_cfg.head_dim;
5386 let size_kv = text_cfg.num_key_value_heads * text_cfg.head_dim;
5387
5388 let q_proj = size_in * size_q / weight_pack_factor;
5389 let k_proj = size_in * size_kv / weight_pack_factor;
5390 let v_proj = size_in * size_kv / weight_pack_factor;
5391 let o_proj = size_q * size_in / weight_pack_factor;
5392
5393 let q_norm = text_cfg.head_dim;
5395 let k_norm = text_cfg.head_dim;
                let v_norm = text_cfg.head_dim;

                // Per-layer MLP width: Gemma 3n allows the intermediate size to vary by layer.
                let intermediate_size = text_cfg
5400 .intermediate_size
5401 .0
5402 .clone()
5403 .map_left(|left| left[layer_idx])
5404 .left_or_else(|(sizes, _)| sizes[layer_idx]);
5405 let gate_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
5406 let up_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
5407 let down_proj = intermediate_size * text_cfg.hidden_size / weight_pack_factor;
5408
5409 let altup_elems = {
5411 let correct_output_scale = text_cfg.hidden_size;
5412 let correction_coefs = text_cfg.altup_num_inputs * text_cfg.altup_num_inputs;
5413 let prediction_coefs =
5414 text_cfg.altup_num_inputs * text_cfg.altup_num_inputs.pow(2);
5415 let modality_router = text_cfg.hidden_size * text_cfg.altup_num_inputs;
5416 let router_norm = text_cfg.hidden_size;
5417
5418 correct_output_scale
5419 + correction_coefs
5420 + prediction_coefs
5421 + modality_router
5422 + router_norm
5423 };
5424
5425 let laurel_elems = {
5427 let left = text_cfg.hidden_size * text_cfg.laurel_rank;
5428 let right = text_cfg.laurel_rank * text_cfg.hidden_size;
5429 let post_norm = text_cfg.hidden_size;
5430
5431 left + right + post_norm
5432 };
5433
5434 let per_layer_input_gate =
5436 text_cfg.hidden_size * text_cfg.hidden_size_per_layer_input;
5437 let per_layer_projection =
5438 text_cfg.hidden_size_per_layer_input * text_cfg.hidden_size;
5439
5440 input_layernorm
5441 + post_attention_layernorm
5442 + pre_feedforward_layernorm
5443 + post_feedforward_layernorm
5444 + post_per_layer_input_norm
5445 + q_proj
5446 + k_proj
5447 + v_proj
5448 + o_proj
5449 + q_norm
5450 + k_norm
5451 + v_norm
5452 + gate_proj
5453 + up_proj
5454 + down_proj
5455 + altup_elems
5456 + laurel_elems
5457 + per_layer_input_gate
5458 + per_layer_projection
5459 };
5460
5461 layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
5462 }
5463
5464 Ok(layer_sizes)
5465 }
5466
5467 fn num_layers(&self, config: &str) -> Result<usize> {
5468 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5469 Ok(cfg.text_config.num_hidden_layers)
5470 }
5471
5472 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
5473 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5474 let cfg = cfg.text_config;
5475
5476 let cfg = ModelConfigMetadata {
5477 max_seq_len: cfg.max_position_embeddings,
5478 num_layers: cfg.num_hidden_layers,
5479 hidden_size: cfg.hidden_size,
5480 num_kv_heads: cfg.num_key_value_heads,
5481 num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5484 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5485 };
5486
5487 Ok(Box::new(cfg))
5488 }
5489
5490 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
5491 Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
5492 }
5493}