use std::any::Any;
use std::sync::Arc;
use std::{fmt::Debug, str::FromStr};

use anyhow::Result;
use candle_core::{DType, Device, Tensor, D};
use candle_nn::Conv2dConfig;
use image::{ColorType, DynamicImage};
use itertools::Itertools;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::ShardedVarBuilder;

#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};

use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
use crate::amoe::AnyMoeBaseModelMixin;
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::DeviceMapper;
use crate::layers::Conv3dConfig;
use crate::matformer::MatformerSliceConfig;
use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
use crate::pipeline::isq::IsqModelLoader;
use crate::pipeline::loaders::AutoDeviceMapParams;
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::{
    EitherCache, IsqModel, Modalities, MultimodalPromptPrefixer, Processor, ProcessorCreator,
    SupportedModality,
};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::vision_models::clip::ClipConfig;
use crate::vision_models::gemma3::config::Gemma3Config;
use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
use crate::vision_models::gemma3n::config::{Gemma3nConfig, IntermediateSize};
use crate::vision_models::gemma3n::{Gemma3nModel, Gemma3nProcessor};
use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
use crate::vision_models::idefics2_input_processor::Idefics2Processor;
use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
use crate::vision_models::image_processor::ImagePreProcessor;
use crate::vision_models::inputs_processor::Phi4MMProcessor;
use crate::vision_models::llama4::{
    self, Llama4Config, Llama4ImageProcessor, Llama4Model, Llama4Processor,
};
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::llava15::Model as LLaVA;
use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
use crate::vision_models::llava_next::Model as LLaVANext;
use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
use crate::vision_models::phi3_inputs_processor::Phi3Processor;
use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::qwen2_5_vl::{
    Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
};
use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
use crate::vision_models::{minicpmo, phi4};

pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
}
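
// Illustrative sketch (not part of this trait's API surface): `model_specific_args`
// is an opaque `Box<dyn Any>`, so a concrete `VisionModel` implementation typically
// recovers its own argument struct by downcasting. `MyModelSpecificArgs` below is a
// hypothetical name used only for illustration:
//
//     fn forward(&self, /* ... */ model_specific_args: Box<dyn Any>, /* ... */) {
//         let args: MyModelSpecificArgs = *model_specific_args
//             .downcast()
//             .expect("Expected the model-specific arguments for this model");
//         // ... use `args` alongside `pixel_values` and the cache metadata ...
//     }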

pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> bool;
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_processor(
        &self,
        model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync>;
    fn supports_paged_attention(&self, config: &str) -> bool;
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        false
    }
    fn modalities(&self, config: &str) -> Result<Modalities>;
    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}
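
// How the default `get_device_for_tensor` closure behaves (illustrative, assuming a
// typical per-layer weight name):
//
//     "model.layers.12.self_attn.q_proj.weight" -> DeviceForLoadTensor::Idx(12)
//     "model.embed_tokens.weight"               -> DeviceForLoadTensor::Base
//
// Layer indices are clamped to `num_layers`, and when ISQ is being applied every
// tensor is loaded onto the base device instead.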

#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub enum VisionLoaderType {
    #[serde(rename = "phi3v")]
    Phi3V,
    #[serde(rename = "idefics2")]
    Idefics2,
    #[serde(rename = "llava_next")]
    LLaVANext,
    #[serde(rename = "llava")]
    LLaVA,
    #[serde(rename = "vllama")]
    VLlama,
    #[serde(rename = "qwen2vl")]
    Qwen2VL,
    #[serde(rename = "idefics3")]
    Idefics3,
    #[serde(rename = "minicpmo")]
    MiniCpmO,
    #[serde(rename = "phi4mm")]
    Phi4MM,
    #[serde(rename = "qwen2_5vl")]
    Qwen2_5VL,
    #[serde(rename = "gemma3")]
    Gemma3,
    #[serde(rename = "mistral3")]
    Mistral3,
    #[serde(rename = "llama4")]
    Llama4,
    #[serde(rename = "gemma3n")]
    Gemma3n,
}

impl VisionLoaderType {
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "Phi3VForCausalLM" => Ok(Self::Phi3V),
            "Idefics2ForConditionalGeneration" => Ok(Self::Idefics2),
            "LlavaNextForConditionalGeneration" => Ok(Self::LLaVANext),
            "LlavaForConditionalGeneration" => Ok(Self::LLaVA),
            "MllamaForConditionalGeneration" => Ok(Self::VLlama),
            "Qwen2VLForConditionalGeneration" => Ok(Self::Qwen2VL),
            "Idefics3ForConditionalGeneration" => Ok(Self::Idefics3),
            "MiniCPMO" => Ok(Self::MiniCpmO),
            "Phi4MMForCausalLM" => Ok(Self::Phi4MM),
            "Qwen2_5_VLForConditionalGeneration" => Ok(Self::Qwen2_5VL),
            "Gemma3ForConditionalGeneration" | "Gemma3ForCausalLM" => Ok(Self::Gemma3),
            "Mistral3ForConditionalGeneration" => Ok(Self::Mistral3),
            "Llama4ForConditionalGeneration" => Ok(Self::Llama4),
            "Gemma3nForConditionalGeneration" => Ok(Self::Gemma3n),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for VisionLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "phi3v" => Ok(Self::Phi3V),
            "idefics2" => Ok(Self::Idefics2),
            "llava_next" => Ok(Self::LLaVANext),
            "llava" => Ok(Self::LLaVA),
            "vllama" => Ok(Self::VLlama),
            "qwen2vl" => Ok(Self::Qwen2VL),
            "idefics3" => Ok(Self::Idefics3),
            "minicpmo" => Ok(Self::MiniCpmO),
            "phi4mm" => Ok(Self::Phi4MM),
            "qwen2_5vl" => Ok(Self::Qwen2_5VL),
            "gemma3" => Ok(Self::Gemma3),
            "mistral3" => Ok(Self::Mistral3),
            "llama4" => Ok(Self::Llama4),
            "gemma3n" => Ok(Self::Gemma3n),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`, `llama4`, `gemma3n`.")),
        }
    }
}

impl std::fmt::Display for VisionLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            VisionLoaderType::Phi3V => "phi3v",
            VisionLoaderType::Idefics2 => "idefics2",
            VisionLoaderType::LLaVANext => "llava_next",
            VisionLoaderType::LLaVA => "llava",
            VisionLoaderType::VLlama => "vllama",
            VisionLoaderType::Qwen2VL => "qwen2vl",
            VisionLoaderType::Idefics3 => "idefics3",
            VisionLoaderType::MiniCpmO => "minicpmo",
            VisionLoaderType::Phi4MM => "phi4mm",
            VisionLoaderType::Qwen2_5VL => "qwen2_5vl",
            VisionLoaderType::Gemma3 => "gemma3",
            VisionLoaderType::Mistral3 => "mistral3",
            VisionLoaderType::Llama4 => "llama4",
            VisionLoaderType::Gemma3n => "gemma3n",
        };
        write!(f, "{name}")
    }
}
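
// Example (illustrative): the CLI-facing names above round-trip through `FromStr`
// and `Display`:
//
//     let tp: VisionLoaderType = "qwen2_5vl".parse().unwrap();
//     assert_eq!(tp, VisionLoaderType::Qwen2_5VL);
//     assert_eq!(tp.to_string(), "qwen2_5vl");
//
// Unknown names return the error listing all supported architectures.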

#[derive(Deserialize)]
struct AutoVisionLoaderConfig {
    architectures: Vec<String>,
}

pub struct AutoVisionLoader;

impl AutoVisionLoader {
    fn get_loader(config: &str) -> Result<Box<dyn VisionModelLoader>> {
        let auto_cfg: AutoVisionLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one architecture in config");
        }

        let name = &auto_cfg.architectures[0];
        let tp = VisionLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        Ok(match tp {
            VisionLoaderType::Phi3V => Box::new(Phi3VLoader),
            VisionLoaderType::Idefics2 => Box::new(Idefics2Loader),
            VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader),
            VisionLoaderType::LLaVA => Box::new(LLaVALoader),
            VisionLoaderType::VLlama => Box::new(VLlamaLoader),
            VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
            VisionLoaderType::Idefics3 => Box::new(Idefics3Loader),
            VisionLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
            VisionLoaderType::Phi4MM => Box::new(Phi4MMLoader),
            VisionLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
            VisionLoaderType::Gemma3 => Box::new(Gemma3Loader),
            VisionLoaderType::Mistral3 => Box::new(Mistral3Loader),
            VisionLoaderType::Llama4 => Box::new(VLlama4Loader),
            VisionLoaderType::Gemma3n => Box::new(Gemma3nLoader),
        })
    }
}
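
// Example (illustrative): `AutoVisionLoader` dispatches on the `architectures` field
// of the model's `config.json` via `VisionLoaderType::from_causal_lm_name`. A config
// containing
//
//     { "architectures": ["Gemma3ForConditionalGeneration"], ... }
//
// resolves to `Gemma3Loader`; configs with zero or multiple architectures are rejected.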

impl VisionModelLoader for AutoVisionLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }

    fn is_gptx(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader get_loader")
            .is_gptx(config)
    }

    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }

    fn get_processor(
        &self,
        model_config: &str,
        proc_cfg: Option<ProcessorConfig>,
        preproc_cfg: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Self::get_loader(model_config)
            .expect("AutoVisionLoader get_loader")
            .get_processor(model_config, proc_cfg, preproc_cfg, max_edge)
    }

    fn supports_paged_attention(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_paged_attention(config)
    }

    fn modalities(&self, config: &str) -> Result<Modalities> {
        Self::get_loader(config)?.modalities(config)
    }

    fn supports_prefix_cacher(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_prefix_cacher(config)
    }

    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .prefixer(config)
    }

    fn get_device_for_tensor(
        &self,
        config: &str,
        mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        Self::get_loader(config)?.get_device_for_tensor(config, mapper, loading_isq)
    }
}

impl IsqModelLoader for AutoVisionLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
}

impl DeviceMappedModelLoader for AutoVisionLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}

fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
    let pre_layer_norm = cfg.hidden_size;
    let final_layer_norm = cfg.hidden_size;

    let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
    let num_positions = num_patches + 1;

    let class_embedding = cfg.hidden_size;

    let position_ids = num_positions;
    let position_embedding = num_positions * cfg.hidden_size;

    let conv2dconfig = Conv2dConfig {
        stride: cfg.patch_size,
        ..Default::default()
    };
    let patch_embedding =
        cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;

    let encoder_layer_elems = {
        let layer_norm1 = cfg.hidden_size;
        let layer_norm2 = cfg.hidden_size;

        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

        layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
    };

    pre_layer_norm
        + final_layer_norm
        + class_embedding
        + position_ids
        + position_embedding
        + patch_embedding
        + cfg.num_hidden_layers * encoder_layer_elems
}
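
// Worked example (illustrative numbers only, roughly CLIP ViT-L-sized: hidden_size =
// 1024, intermediate_size = 4096): each encoder layer counts
//     4 * (1024 * 1024 + 1024)      // q/k/v/o projections + biases = 4_198_400
//   + (1024 * 4096 + 4096)          // fc1                          = 4_198_400
//   + (4096 * 1024 + 1024)          // fc2                          = 4_195_328
//   + 2 * 1024                      // the two layer norms          =     2_048
// elements, and the embeddings and norms outside the layers are added once.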

pub struct Phi3VLoader;

pub struct Phi3VPrefixer;

impl MultimodalPromptPrefixer for Phi3VPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            image_indexes
                .into_iter()
                .map(|image_index| format!("<|image_{}|>", image_index + 1))
                .join("")
        )
    }
}

impl VisionModelLoader for Phi3VLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(Phi3::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi3Processor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Phi3VPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Phi3VLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}
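
// The device-mapping estimates below size the worst-case attention activation as
// (roughly) batch * num_heads * seq^2, where the text sequence length is padded by
// the image token count: img_seq_len = (num_patches + 1) * max_num_images with
// num_patches = (image_size / patch_size)^2, and the text length itself is capped
// at ATTENTION_CHUNK_SIZE. This is a sizing heuristic for automatic device mapping,
// not an exact memory accounting.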

impl DeviceMappedModelLoader for Phi3VLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = {
                let projection_cls = cfg
                    .embd_layer
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator =
                    cfg.embd_layer.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
                let image_dim_out = cfg.img_processor.image_dim_out;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);

                proj + glb_gn + sub_gn + clip_vit
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}
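
// Loader for Idefics 2: a vision transformer whose outputs pass through a
// perceiver-resampler connector into the text model (see the connector and
// `perceiver_config` sizing in the device-mapping estimates below).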
745
746pub struct Idefics2Loader;
752
753pub struct Idefics2Prefixer;
754
755impl MultimodalPromptPrefixer for Idefics2Prefixer {
756 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
757 prompt.to_string()
759 }
760}
761
762impl VisionModelLoader for Idefics2Loader {
763 fn load(
764 &self,
765 config: &str,
766 vb: ShardedVarBuilder,
767 normal_loading_metadata: NormalLoadingMetadata,
768 attention_mechanism: AttentionImplementation,
769 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
770 let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
771 Ok(Box::new(Idefics2::new(
772 &cfg,
773 vb,
774 self.is_gptx(config),
775 normal_loading_metadata,
776 attention_mechanism,
777 )?))
778 }
779 fn is_gptx(&self, _config: &str) -> bool {
780 true
781 }
782 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
783 let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
784 Ok(Box::new(cfg))
785 }
786 fn get_processor(
787 &self,
788 _model_config: &str,
789 processor_config: Option<ProcessorConfig>,
790 preprocessor_config: PreProcessorConfig,
791 max_edge: Option<u32>,
792 ) -> Arc<dyn Processor + Send + Sync> {
793 Arc::new(Idefics2Processor::new(
794 processor_config.unwrap(),
795 preprocessor_config,
796 max_edge,
797 ))
798 }
799 fn supports_paged_attention(&self, _config: &str) -> bool {
800 true
801 }
802 fn supports_prefix_cacher(&self, _config: &str) -> bool {
803 true
804 }
805 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
806 Arc::new(Idefics2Prefixer)
807 }
808 fn modalities(&self, _config: &str) -> Result<Modalities> {
809 Ok(Modalities {
810 input: vec![SupportedModality::Text, SupportedModality::Vision],
811 output: vec![SupportedModality::Text],
812 })
813 }
814}
815
816impl IsqModelLoader for Idefics2Loader {
817 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
818 Ok(vec![
819 Regex::new(r"lm_head\.(weight|bias)$")?,
820 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
822 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
823 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
824 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
825 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
827 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
828 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
829 ])
830 }
831 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
832 Ok(vec![
833 Regex::new(r"lm_head\.(weight|bias)$")?,
834 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
836 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
837 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
838 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
839 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
841 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
842 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
843 ])
844 }
845}
846
847impl DeviceMappedModelLoader for Idefics2Loader {
848 fn mapped_max_act_size_elems(
849 &self,
850 config: &str,
851 params: &AutoDeviceMapParams,
852 ) -> Result<usize> {
853 let AutoDeviceMapParams::Vision {
854 max_seq_len,
855 max_batch_size,
856 max_image_shape: _,
857 max_num_images,
858 } = params
859 else {
860 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
861 };
862
863 let cfg: Idefics2Config = serde_json::from_str(config)?;
864
865 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
866 let img_seq_len = (num_patches + 1) * max_num_images;
867
868 let max_text_attn = {
869 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
871 max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
872 };
873
874 Ok(max_text_attn)
875 }
876
877 fn non_mapped_max_act_size_elems(
878 &self,
879 config: &str,
880 params: &AutoDeviceMapParams,
881 ) -> Result<usize> {
882 let AutoDeviceMapParams::Vision {
883 max_seq_len: _,
884 max_batch_size,
885 max_image_shape: _,
886 max_num_images,
887 } = params
888 else {
889 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
890 };
891
892 let cfg: Idefics2Config = serde_json::from_str(config)?;
893
894 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
895 let img_seq_len = num_patches + 1;
896
897 let max_vision_attn = {
898 let images_factor = 5;
900
901 (max_batch_size * images_factor * max_num_images)
902 * cfg.vision_config.num_attention_heads
903 * img_seq_len
904 * img_seq_len
905 };
906
907 Ok(max_vision_attn)
908 }
909
910 fn non_mapped_size_in_bytes(
911 &self,
912 config: &str,
913 dtype: DType,
914 weight_pack_factor: usize,
915 _matformer_config: Option<&MatformerSliceConfig>,
916 ) -> Result<usize> {
917 let cfg: Idefics2Config = serde_json::from_str(config)?;
918 let text_elems = {
919 let tie_word_embeddings = cfg.tie_word_embeddings;
920 let cfg = &cfg.text_config;
921
922 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
923 let lm_head = if !tie_word_embeddings {
924 cfg.hidden_size * cfg.vocab_size
925 } else {
926 0
927 };
928 let norm = cfg.hidden_size;
929 embed_tokens + lm_head + norm
930 };
931
932 let connector_elems = {
933 let tcfg = &cfg.text_config;
934 let vcfg = &cfg.vision_config;
935 let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
936 let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
937 let down_proj = tcfg.intermediate_size * tcfg.hidden_size;
938
939 let perceiver_elems = {
940 let tcfg = &cfg.text_config;
941 let pcfg = &cfg.perceiver_config;
942
943 let n_latents = pcfg.resampler_n_latents;
944 let hidden_size = tcfg.hidden_size;
945 let depth = pcfg.resampler_depth;
946
947 let norm = tcfg.hidden_size;
948 let latents = n_latents * hidden_size;
949
950 let layer_elems = {
951 let input_latents_norm = hidden_size;
952 let input_context_norm = hidden_size;
953 let post_attn_norm = hidden_size;
954
955 let num_heads = pcfg.resampler_n_heads;
956 let head_dim = pcfg.resampler_head_dim;
957 let num_key_value_heads = pcfg.num_key_value_heads;
958
959 let q_proj = hidden_size * num_heads * head_dim;
960 let k_proj = hidden_size * num_key_value_heads * head_dim;
961 let v_proj = hidden_size * num_key_value_heads * head_dim;
962 let o_proj = num_heads * head_dim * hidden_size;
963
964 let gate_proj = hidden_size * hidden_size * 4;
965 let up_proj = hidden_size * hidden_size * 4;
966 let down_proj = hidden_size * 4 * hidden_size;
967
968 input_latents_norm
969 + input_context_norm
970 + post_attn_norm
971 + q_proj
972 + k_proj
973 + v_proj
974 + o_proj
975 + gate_proj
976 + up_proj
977 + down_proj
978 };
979
980 norm + latents + layer_elems * depth
981 };
982
983 gate_proj + up_proj + down_proj + perceiver_elems
984 };
985
986 let vision_transformer = {
987 let cfg = &cfg.vision_config;
988
989 let post_layernorm = cfg.hidden_size;
990
991 let conv_config = Conv2dConfig {
992 stride: cfg.patch_size,
993 ..Default::default()
994 };
995 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
996 * cfg.patch_size
997 * cfg.patch_size;
998
999 let num_patches_per_side = cfg.image_size / cfg.patch_size;
1000 let num_patches = num_patches_per_side.pow(2);
1001 let position_embedding = num_patches * cfg.hidden_size;
1002
1003 let layer_elems = {
1004 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
1005 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
1006
1007 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
1008 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
1009
1010 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1011 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1012 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1013 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1014
1015 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
1016 };
1017
1018 post_layernorm + patch_embedding + position_embedding + layer_elems
1019 };
1020
1021 let elems = text_elems + connector_elems + vision_transformer;
1022
1023 Ok(elems * dtype.size_in_bytes())
1024 }
1025
1026 fn layer_sizes_in_bytes(
1027 &self,
1028 config: &str,
1029 dtype: DType,
1030 weight_pack_factor: usize,
1031 _matformer_config: Option<&MatformerSliceConfig>,
1032 ) -> Result<Vec<usize>> {
1033 let cfg: Idefics2Config = serde_json::from_str(config)?;
1034 let cfg = cfg.text_config;
1035 let per_layer_elems = {
1036 let input_layernorm = cfg.hidden_size;
1037 let post_attention_layernorm = cfg.hidden_size;
1038
1039 let size_in = cfg.hidden_size;
1040 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1041 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1042 let q_proj = size_in * size_q / weight_pack_factor;
1043 let k_proj = size_in * size_kv / weight_pack_factor;
1044 let v_proj = size_in * size_kv / weight_pack_factor;
1045 let o_proj = size_q * size_in / weight_pack_factor;
1046
1047 let h_size = cfg.hidden_size;
1048 let i_size = cfg.intermediate_size;
1049 let gate_proj = h_size * i_size / weight_pack_factor;
1050 let up_proj = h_size * i_size / weight_pack_factor;
1051 let down_proj = i_size * h_size / weight_pack_factor;
1052
1053 input_layernorm
1054 + post_attention_layernorm
1055 + q_proj
1056 + k_proj
1057 + v_proj
1058 + o_proj
1059 + gate_proj
1060 + up_proj
1061 + down_proj
1062 };
1063 Ok(vec![
1064 per_layer_elems * dtype.size_in_bytes();
1065 cfg.num_hidden_layers
1066 ])
1067 }
1068
1069 fn num_layers(&self, config: &str) -> Result<usize> {
1070 let cfg: Idefics2Config = serde_json::from_str(config)?;
1071 Ok(cfg.text_config.num_hidden_layers)
1072 }
1073 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1074 let cfg: Idefics2Config = serde_json::from_str(config)?;
1075 let cfg = &cfg.text_config;
1076
1077 let cfg = ModelConfigMetadata {
1078 max_seq_len: cfg.max_position_embeddings,
1079 num_layers: cfg.num_hidden_layers,
1080 hidden_size: cfg.hidden_size,
1081 num_kv_heads: cfg.num_key_value_heads,
1082 num_attn_heads: cfg.num_attention_heads,
1083 sliding_window: cfg.sliding_window,
1084 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1085 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1086 };
1087
1088 Ok(Box::new(cfg))
1089 }
1090
1091 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1092 Some(vec![NonMappedSubModel::Vision])
1093 }
1094}
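
// Loader for LLaVA-Next: a CLIP vision tower plus a two-layer MLP projector into the
// language model; the number of image tokens depends on the input resolution (see
// `get_num_image_tokens` in the device-mapping estimates below).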
1095
1096pub struct LLaVANextLoader;
1102
1103pub struct LLaVANextPrefixer;
1104
1105impl MultimodalPromptPrefixer for LLaVANextPrefixer {
1106 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1107 format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
1108 }
1109}
1110
1111impl VisionModelLoader for LLaVANextLoader {
1112 fn load(
1113 &self,
1114 config: &str,
1115 vb: ShardedVarBuilder,
1116 normal_loading_metadata: NormalLoadingMetadata,
1117 attention_mechanism: AttentionImplementation,
1118 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1119 let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1120 Ok(Box::new(LLaVANext::new(
1121 &cfg,
1122 vb,
1123 self.is_gptx(config),
1124 normal_loading_metadata,
1125 attention_mechanism,
1126 )?))
1127 }
1128 fn is_gptx(&self, _config: &str) -> bool {
1129 false
1130 }
1131 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1132 let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1133 Ok(Box::new(cfg))
1134 }
1135 fn get_processor(
1136 &self,
1137 model_config: &str,
1138 _processor_config: Option<ProcessorConfig>,
1139 _preprocessor_config: PreProcessorConfig,
1140 _max_edge: Option<u32>,
1141 ) -> Arc<dyn Processor + Send + Sync> {
1142 Arc::new(LLaVANextProcessor::new(model_config))
1143 }
1144 fn supports_paged_attention(&self, _config: &str) -> bool {
1145 true
1146 }
1147 fn supports_prefix_cacher(&self, _config: &str) -> bool {
1148 true
1149 }
1150 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1151 Arc::new(LLaVANextPrefixer)
1152 }
1153 fn modalities(&self, _config: &str) -> Result<Modalities> {
1154 Ok(Modalities {
1155 input: vec![SupportedModality::Text, SupportedModality::Vision],
1156 output: vec![SupportedModality::Text],
1157 })
1158 }
1159}
1160
1161impl IsqModelLoader for LLaVANextLoader {
1162 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1163 Ok(vec![
1164 Regex::new(r"lm_head\.(weight|bias)$")?,
1165 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1167 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1168 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1169 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1170 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1172 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1173 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1174 ])
1175 }
1176 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1177 Ok(vec![
1178 Regex::new(r"lm_head\.(weight|bias)$")?,
1179 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1181 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1182 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1183 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1184 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1186 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1187 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1188 ])
1189 }
1190}
1191
1192impl DeviceMappedModelLoader for LLaVANextLoader {
1193 fn mapped_max_act_size_elems(
1194 &self,
1195 config: &str,
1196 params: &AutoDeviceMapParams,
1197 ) -> Result<usize> {
1198 let AutoDeviceMapParams::Vision {
1199 max_seq_len,
1200 max_batch_size,
1201 max_image_shape,
1202 max_num_images,
1203 } = params
1204 else {
1205 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1206 };
1207
1208 let config: LLaVAConfig = serde_json::from_str(config)?;
1209
1210 #[allow(clippy::cast_possible_truncation)]
1211 let img_seq_len =
1212 llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
1213 &config,
1214 (max_image_shape.0 as u32, max_image_shape.1 as u32),
1215 );
1216 let img_seq_len = img_seq_len * max_num_images;
1217
1218 let max_text_attn = {
1219 let cfg = &config.text_config;
1220 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
1222
1223 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1224 };
1225
1226 Ok(max_text_attn)
1227 }
1228
1229 fn non_mapped_max_act_size_elems(
1230 &self,
1231 config: &str,
1232 params: &AutoDeviceMapParams,
1233 ) -> Result<usize> {
1234 let AutoDeviceMapParams::Vision {
1235 max_seq_len: _,
1236 max_batch_size,
1237 max_image_shape,
1238 max_num_images,
1239 } = params
1240 else {
1241 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1242 };
1243
1244 let config: LLaVAConfig = serde_json::from_str(config)?;
1245
1246 #[allow(clippy::cast_possible_truncation)]
1247 let img_seq_len =
1248 llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
1249 &config,
1250 (max_image_shape.0 as u32, max_image_shape.1 as u32),
1251 );
1252
1253 let max_vision_attn = {
1254 (max_batch_size * max_num_images)
1255 * config.vision_config.num_attention_heads
1256 * img_seq_len
1257 * img_seq_len
1258 };
1259
1260 Ok(max_vision_attn)
1261 }
1262
1263 fn non_mapped_size_in_bytes(
1264 &self,
1265 config: &str,
1266 dtype: DType,
1267 weight_pack_factor: usize,
1268 _matformer_config: Option<&MatformerSliceConfig>,
1269 ) -> Result<usize> {
1270 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1271 let text_elems = {
1272 let cfg = &cfg.text_config;
1273 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1274 let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1275 let norm = cfg.hidden_size;
1276 embed_tokens + lm_head + norm
1277 };
1278
1279 let image_newline = cfg.text_config.hidden_size;
1280 let mmproj = {
1281 let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
1282 + cfg.text_config.hidden_size;
1283 let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
1284 + cfg.text_config.hidden_size;
1285
1286 linear_1 + linear_2
1287 };
1288 let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());
1289
1290 let elems = text_elems + image_newline + mmproj + vision_tower;
1291 Ok(elems * dtype.size_in_bytes())
1292 }
1293
1294 fn layer_sizes_in_bytes(
1295 &self,
1296 config: &str,
1297 dtype: DType,
1298 weight_pack_factor: usize,
1299 _matformer_config: Option<&MatformerSliceConfig>,
1300 ) -> Result<Vec<usize>> {
1301 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1302 let per_layer_elems = {
1303 let cfg = &cfg.text_config;
1304 let input_layernorm = cfg.hidden_size;
1305 let post_attention_layernorm = cfg.hidden_size;
1306
1307 let size_in = cfg.hidden_size;
1308 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1309 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1310 let q_proj = size_in * size_q / weight_pack_factor;
1311 let k_proj = size_in * size_kv / weight_pack_factor;
1312 let v_proj = size_in * size_kv / weight_pack_factor;
1313 let o_proj = size_q * size_in / weight_pack_factor;
1314
1315 let h_size = cfg.hidden_size;
1316 let i_size = cfg.intermediate_size;
1317 let gate_proj = h_size * i_size / weight_pack_factor;
1318 let up_proj = h_size * i_size / weight_pack_factor;
1319 let down_proj = i_size * h_size / weight_pack_factor;
1320
1321 input_layernorm
1322 + post_attention_layernorm
1323 + q_proj
1324 + k_proj
1325 + v_proj
1326 + o_proj
1327 + gate_proj
1328 + up_proj
1329 + down_proj
1330 };
1331 Ok(vec![
1332 per_layer_elems * dtype.size_in_bytes();
1333 cfg.text_config.num_hidden_layers
1334 ])
1335 }
1336
1337 fn num_layers(&self, config: &str) -> Result<usize> {
1338 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1339 Ok(cfg.text_config.num_hidden_layers)
1340 }
1341
1342 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1343 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1344 let cfg = &cfg.text_config;
1345
1346 let cfg = ModelConfigMetadata {
1347 max_seq_len: cfg.max_position_embeddings,
1348 num_layers: cfg.num_hidden_layers,
1349 hidden_size: cfg.hidden_size,
1350 num_kv_heads: cfg.num_key_value_heads,
1351 num_attn_heads: cfg.num_attention_heads,
1352 sliding_window: cfg.sliding_window,
1353 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1354 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1355 };
1356
1357 Ok(Box::new(cfg))
1358 }
1359
1360 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1361 Some(vec![NonMappedSubModel::Vision])
1362 }
1363}
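
// Loader for the original LLaVA architecture: same CLIP tower and MLP projector as
// above, but with a fixed number of image tokens per image.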
1364
1365pub struct LLaVALoader;
1371
1372pub struct LLaVAPrefixer;
1373
1374impl MultimodalPromptPrefixer for LLaVAPrefixer {
1375 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1376 format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
1377 }
1378}
1379
1380impl VisionModelLoader for LLaVALoader {
1381 fn load(
1382 &self,
1383 config: &str,
1384 vb: ShardedVarBuilder,
1385 normal_loading_metadata: NormalLoadingMetadata,
1386 attention_mechanism: AttentionImplementation,
1387 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1388 let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1389 Ok(Box::new(LLaVA::new(
1390 &cfg,
1391 vb,
1392 self.is_gptx(config),
1393 normal_loading_metadata,
1394 attention_mechanism,
1395 )?))
1396 }
1397 fn is_gptx(&self, _config: &str) -> bool {
1398 false
1399 }
1400 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1401 let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1402 Ok(Box::new(cfg))
1403 }
1404 fn get_processor(
1405 &self,
1406 model_config: &str,
1407 _processor_config: Option<ProcessorConfig>,
1408 _preprocessor_config: PreProcessorConfig,
1409 _max_edge: Option<u32>,
1410 ) -> Arc<dyn Processor + Send + Sync> {
1411 Arc::new(LLaVAProcessor::new(model_config))
1412 }
1413 fn supports_paged_attention(&self, _config: &str) -> bool {
1414 true
1415 }
1416 fn supports_prefix_cacher(&self, _config: &str) -> bool {
1417 true
1418 }
1419 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1420 Arc::new(LLaVAPrefixer)
1421 }
1422 fn modalities(&self, _config: &str) -> Result<Modalities> {
1423 Ok(Modalities {
1424 input: vec![SupportedModality::Text, SupportedModality::Vision],
1425 output: vec![SupportedModality::Text],
1426 })
1427 }
1428}
1429
1430impl IsqModelLoader for LLaVALoader {
1431 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1432 Ok(vec![
1433 Regex::new(r"lm_head\.(weight|bias)$")?,
1434 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1436 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1437 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1438 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1439 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1441 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1442 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1443 ])
1444 }
1445 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1446 Ok(vec![
1447 Regex::new(r"lm_head\.(weight|bias)$")?,
1448 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1450 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1451 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1452 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1453 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1455 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1456 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1457 ])
1458 }
1459}
1460
1461impl DeviceMappedModelLoader for LLaVALoader {
1462 fn mapped_max_act_size_elems(
1463 &self,
1464 config: &str,
1465 params: &AutoDeviceMapParams,
1466 ) -> Result<usize> {
1467 let AutoDeviceMapParams::Vision {
1468 max_seq_len,
1469 max_batch_size,
1470 max_image_shape: _,
1471 max_num_images,
1472 } = params
1473 else {
1474 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1475 };
1476
1477 let config: LLaVAConfig = serde_json::from_str(config)?;
1478
1479 let img_seq_len =
1480 llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1481 let img_seq_len = img_seq_len * max_num_images;
1482
1483 let max_text_attn = {
1484 let cfg = &config.text_config;
1485 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
1487
1488 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1489 };
1490
1491 Ok(max_text_attn)
1492 }
1493
1494 fn non_mapped_max_act_size_elems(
1495 &self,
1496 config: &str,
1497 params: &AutoDeviceMapParams,
1498 ) -> Result<usize> {
1499 let AutoDeviceMapParams::Vision {
1500 max_seq_len: _,
1501 max_batch_size,
1502 max_image_shape: _,
1503 max_num_images,
1504 } = params
1505 else {
1506 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1507 };
1508
1509 let config: LLaVAConfig = serde_json::from_str(config)?;
1510
1511 let img_seq_len =
1512 llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1513
1514 let max_vision_attn = {
1515 (max_batch_size * max_num_images)
1516 * config.vision_config.num_attention_heads
1517 * img_seq_len
1518 * img_seq_len
1519 };
1520
1521 Ok(max_vision_attn)
1522 }
1523
1524 fn non_mapped_size_in_bytes(
1525 &self,
1526 config: &str,
1527 dtype: DType,
1528 weight_pack_factor: usize,
1529 _matformer_config: Option<&MatformerSliceConfig>,
1530 ) -> Result<usize> {
1531 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1532 let text_elems = {
1533 let cfg = &cfg.text_config;
1534 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1535 let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1536 let norm = cfg.hidden_size;
1537 embed_tokens + lm_head + norm
1538 };
1539
1540 let image_newline = cfg.text_config.hidden_size;
1541 let mmproj = {
1542 let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
1543 + cfg.text_config.hidden_size;
1544 let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
1545 + cfg.text_config.hidden_size;
1546
1547 linear_1 + linear_2
1548 };
1549 let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());
1550
1551 let elems = text_elems + image_newline + mmproj + vision_tower;
1552 Ok(elems * dtype.size_in_bytes())
1553 }
1554
1555 fn layer_sizes_in_bytes(
1556 &self,
1557 config: &str,
1558 dtype: DType,
1559 weight_pack_factor: usize,
1560 _matformer_config: Option<&MatformerSliceConfig>,
1561 ) -> Result<Vec<usize>> {
1562 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1563 let per_layer_elems = {
1564 let cfg = &cfg.text_config;
1565 let input_layernorm = cfg.hidden_size;
1566 let post_attention_layernorm = cfg.hidden_size;
1567
1568 let size_in = cfg.hidden_size;
1569 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1570 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1571 let q_proj = size_in * size_q / weight_pack_factor;
1572 let k_proj = size_in * size_kv / weight_pack_factor;
1573 let v_proj = size_in * size_kv / weight_pack_factor;
1574 let o_proj = size_q * size_in / weight_pack_factor;
1575
1576 let h_size = cfg.hidden_size;
1577 let i_size = cfg.intermediate_size;
1578 let gate_proj = h_size * i_size / weight_pack_factor;
1579 let up_proj = h_size * i_size / weight_pack_factor;
1580 let down_proj = i_size * h_size / weight_pack_factor;
1581
1582 input_layernorm
1583 + post_attention_layernorm
1584 + q_proj
1585 + k_proj
1586 + v_proj
1587 + o_proj
1588 + gate_proj
1589 + up_proj
1590 + down_proj
1591 };
1592 Ok(vec![
1593 per_layer_elems * dtype.size_in_bytes();
1594 cfg.text_config.num_hidden_layers
1595 ])
1596 }
1597
1598 fn num_layers(&self, config: &str) -> Result<usize> {
1599 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1600 Ok(cfg.text_config.num_hidden_layers)
1601 }
1602
1603 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1604 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1605 let cfg = &cfg.text_config;
1606
1607 let cfg = ModelConfigMetadata {
1608 max_seq_len: cfg.max_position_embeddings,
1609 num_layers: cfg.num_hidden_layers,
1610 hidden_size: cfg.hidden_size,
1611 num_kv_heads: cfg.num_key_value_heads,
1612 num_attn_heads: cfg.num_attention_heads,
1613 sliding_window: cfg.sliding_window,
1614 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1615 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1616 };
1617
1618 Ok(Box::new(cfg))
1619 }
1620
1621 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1622 Some(vec![NonMappedSubModel::Vision])
1623 }
1624}
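
// Loader for Mllama (`MllamaForConditionalGeneration`, the Llama 3.2 Vision family):
// text self-attention layers are interleaved with cross-attention layers attending to
// vision features, which is why the ISQ regexes below skip `cross_attention_layers`
// and why paged attention is not supported for this model.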
1625
1626pub struct VLlamaLoader;
1632
1633pub struct VLlamaPrefixer;
1634
1635impl MultimodalPromptPrefixer for VLlamaPrefixer {
1636 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1637 format!("{}{prompt}", "<|image|>".repeat(image_indexes.len()))
1638 }
1639}
1640
1641impl VisionModelLoader for VLlamaLoader {
1642 fn load(
1643 &self,
1644 config: &str,
1645 vb: ShardedVarBuilder,
1646 normal_loading_metadata: NormalLoadingMetadata,
1647 attention_mechanism: AttentionImplementation,
1648 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1649 let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1650 Ok(Box::new(MLlamaModel::new(
1651 &cfg,
1652 vb,
1653 self.is_gptx(config),
1654 normal_loading_metadata,
1655 attention_mechanism,
1656 )?))
1657 }
1658 fn is_gptx(&self, _config: &str) -> bool {
1659 true
1660 }
1661 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1662 let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1663 Ok(Box::new(cfg))
1664 }
1665 fn get_processor(
1666 &self,
1667 _model_config: &str,
1668 _processor_config: Option<ProcessorConfig>,
1669 _preprocessor_config: PreProcessorConfig,
1670 _max_edge: Option<u32>,
1671 ) -> Arc<dyn Processor + Send + Sync> {
1672 Arc::new(MLlamaProcessor::new())
1673 }
1674 fn supports_paged_attention(&self, _config: &str) -> bool {
1675 false
1676 }
1677 fn supports_prefix_cacher(&self, _config: &str) -> bool {
1678 true
1679 }
1680 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1681 Arc::new(VLlamaPrefixer)
1682 }
1683 fn modalities(&self, _config: &str) -> Result<Modalities> {
1684 Ok(Modalities {
1685 input: vec![SupportedModality::Text, SupportedModality::Vision],
1686 output: vec![SupportedModality::Text],
1687 })
1688 }
1689}
1690
1691impl IsqModelLoader for VLlamaLoader {
1692 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
1693 let config: MLlamaConfig = serde_json::from_str(config)?;
1694 let cross_attn_layers = &config.text_config.cross_attention_layers;
1695 let transformer_layers =
1696 (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
1697 let mut text_regexes = Vec::new();
1698 for layer in transformer_layers {
1699 text_regexes.extend(vec![
1700 Regex::new(&format!(
1702 r"language_model.model.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
1703 ))?,
1704 Regex::new(&format!(
1705 r"language_model.model.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
1706 ))?,
1707 Regex::new(&format!(
1708 r"language_model.model.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
1709 ))?,
1710 Regex::new(&format!(
1711 r"language_model.model.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
1712 ))?,
1713 Regex::new(&format!(
1715 r"language_model.model.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
1716 ))?,
1717 Regex::new(&format!(
1718 r"language_model.model.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
1719 ))?,
1720 Regex::new(&format!(
1721 r"language_model.model.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
1722 ))?,
1723 ]);
1724 }
1725 let vision_regexes = vec![
1726 Regex::new(
1728 r"vision_model.transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1729 )?,
1730 Regex::new(
1731 r"vision_model.transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1732 )?,
1733 Regex::new(
1734 r"vision_model.transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1735 )?,
1736 Regex::new(
1737 r"vision_model.transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1738 )?,
1739 Regex::new(
1741 r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1742 )?,
1743 Regex::new(
1744 r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1745 )?,
1746 Regex::new(
1747 r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1748 )?,
1749 Regex::new(
1750 r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1751 )?,
1752 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1754 Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1755 ];
1756
1757 Ok([text_regexes, vision_regexes].concat())
1758 }
1759 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1760 self.isq_layer_regexes(config)
1761 }
1762}
1763
1764impl DeviceMappedModelLoader for VLlamaLoader {
1765 fn mapped_max_act_size_elems(
1766 &self,
1767 config: &str,
1768 params: &AutoDeviceMapParams,
1769 ) -> Result<usize> {
1770 let AutoDeviceMapParams::Vision {
1771 max_seq_len,
1772 max_batch_size,
1773 max_image_shape: _,
1774 max_num_images,
1775 } = params
1776 else {
1777 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1778 };
1779
1780 let config: MLlamaConfig = serde_json::from_str(config)?;
1781
1782 let img_seq_len = {
1783 let cfg = &config.vision_config;
1784 let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1785 let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1786 cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1787 };
1788 let img_seq_len = img_seq_len * max_num_images;
1789
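        // Size activations from whichever dominates: cross-attention over all image tokens,
        // or chunked self-attention over the text sequence (compared below).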
1790 let max_cross_text_attn = {
1791 let cfg = &config.text_config;
1792 max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
1793 };
1794
1795 let max_self_text_attn = {
1796 let cfg = &config.text_config;
1797 max_batch_size * cfg.num_attention_heads * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)
1798 };
1799
1800 Ok(max_self_text_attn.max(max_cross_text_attn))
1801 }
1802
1803 fn non_mapped_max_act_size_elems(
1804 &self,
1805 config: &str,
1806 params: &AutoDeviceMapParams,
1807 ) -> Result<usize> {
1808 let AutoDeviceMapParams::Vision {
1809 max_seq_len: _,
1810 max_batch_size,
1811 max_image_shape: _,
1812 max_num_images,
1813 } = params
1814 else {
1815 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1816 };
1817
1818 let config: MLlamaConfig = serde_json::from_str(config)?;
1819
1820 let img_seq_len = {
1821 let cfg = &config.vision_config;
1822 let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1823 let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1824 cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1825 };
1826 let max_vision_attn = {
1827 let cfg = &config.vision_config;
1828 (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
1829 };
1830
1831 Ok(max_vision_attn)
1832 }
1833
1834 fn non_mapped_size_in_bytes(
1835 &self,
1836 config: &str,
1837 dtype: DType,
1838 weight_pack_factor: usize,
1839 _matformer_config: Option<&MatformerSliceConfig>,
1840 ) -> Result<usize> {
1841 let config: MLlamaConfig = serde_json::from_str(config)?;
1842 let text_elems = {
1843 let cfg = &config.text_config;
1844 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
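            // Count a separate lm_head unless the embeddings are tied and the weights are unpacked.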
1845 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1847 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1848 } else {
1849 0
1850 };
1851 let norm = cfg.hidden_size;
1852 embed_tokens + lm_head + norm
1853 };
1854
1855 let vision_elems = {
1856 let cfg = &config.vision_config;
1857
1858 let conv_cfg = Conv2dConfig {
1859 stride: cfg.patch_size,
1860 ..Default::default()
1861 };
1862 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
1863 * cfg.patch_size
1864 * cfg.patch_size;
1865
1866 let class_embedding = cfg.hidden_size;
1867
1868 let gated_positional_embedding = {
1869 let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1870 let embedding = num_patches * cfg.hidden_size;
1871 let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
1872 * (cfg.max_num_tiles * num_patches * cfg.hidden_size);
1873
1874 embedding + tile_embedding
1875 };
1876
1877 let pre_tile_positional_embedding =
1878 (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1879 let post_tile_positional_embedding =
1880 (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1881
1882 let layernorm_pre = cfg.hidden_size;
1883 let layernorm_post = cfg.hidden_size;
1884
1885 let encoder_layer = {
1886 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1887 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
1888
1889 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1890 let q_proj =
1891 cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1892 let k_proj =
1893 cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1894 let v_proj =
1895 cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1896 let o_proj =
1897 cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1898
1899 let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
1900 + cfg.intermediate_size;
1901 let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
1902 + cfg.hidden_size;
1903
1904 input_layernorm
1905 + post_attention_layernorm
1906 + q_proj
1907 + k_proj
1908 + v_proj
1909 + o_proj
1910 + fc1
1911 + fc2
1912 };
1913
1914 patch_embedding
1915 + class_embedding
1916 + gated_positional_embedding
1917 + pre_tile_positional_embedding
1918 + post_tile_positional_embedding
1919 + layernorm_pre
1920 + layernorm_post
1921 + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
1922 };
1923
1924 let elems = text_elems + vision_elems;
1925 Ok(elems * dtype.size_in_bytes())
1926 }
1927
1928 fn layer_sizes_in_bytes(
1929 &self,
1930 config: &str,
1931 dtype: DType,
1932 weight_pack_factor: usize,
1933 _matformer_config: Option<&MatformerSliceConfig>,
1934 ) -> Result<Vec<usize>> {
1935 let config: MLlamaConfig = serde_json::from_str(config)?;
1936 let cfg = &config.text_config;
1937
1938 let mut layer_sizes = Vec::new();
1939
1940 for i in 0..cfg.num_hidden_layers {
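            // Cross-attention layers are excluded from ISQ (see `isq_layer_regexes`),
            // so they are sized without the weight packing factor.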
1941 let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
1942 1
1944 } else {
1945 weight_pack_factor
1946 };
1947
1948 let per_layer_elems = {
1949 let input_layernorm = cfg.hidden_size;
1950 let post_attention_layernorm = cfg.hidden_size;
1951
1952 let size_in = cfg.hidden_size;
1953 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1954 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1955 let q_proj = size_in * size_q / weight_pack_factor;
1956 let k_proj = size_in * size_kv / weight_pack_factor;
1957 let v_proj = size_in * size_kv / weight_pack_factor;
1958 let o_proj = size_q * size_in / weight_pack_factor;
1959
1960 let h_size = cfg.hidden_size;
1961 let i_size = cfg.intermediate_size;
1962 let gate_proj = h_size * i_size / weight_pack_factor;
1963 let up_proj = h_size * i_size / weight_pack_factor;
1964 let down_proj = i_size * h_size / weight_pack_factor;
1965
1966 input_layernorm
1967 + post_attention_layernorm
1968 + q_proj
1969 + k_proj
1970 + v_proj
1971 + o_proj
1972 + gate_proj
1973 + up_proj
1974 + down_proj
1975 };
1976
1977 layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
1978 }
1979
1980 Ok(layer_sizes)
1981 }
1982
1983 fn num_layers(&self, config: &str) -> Result<usize> {
1984 let config: MLlamaConfig = serde_json::from_str(config)?;
1985 Ok(config.text_config.num_hidden_layers)
1986 }
1987
1988 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1989 let cfg: MLlamaConfig = serde_json::from_str(config)?;
1990 let cfg = &cfg.text_config;
1991
1992 let cfg = ModelConfigMetadata {
1993 max_seq_len: cfg.max_position_embeddings,
1994 num_layers: cfg.num_hidden_layers,
1995 hidden_size: cfg.hidden_size,
1996 num_kv_heads: cfg.num_key_value_heads,
1997 num_attn_heads: cfg.num_attention_heads,
1998 sliding_window: None,
1999 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2000 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2001 };
2002
2003 Ok(Box::new(cfg))
2004 }
2005
2006 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2007 Some(vec![NonMappedSubModel::Vision])
2008 }
2009}
2010
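/// Loader for Qwen2-VL vision models: builds a `Qwen2VLModel` and supplies ISQ layer
/// selection and device-mapping size estimates.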
2011pub struct Qwen2VLLoader;
2017
2018pub struct Qwen2VLPrefixer;
2019
2020impl MultimodalPromptPrefixer for Qwen2VLPrefixer {
2021 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2022 format!(
2023 "{}{prompt}",
2024 format!(
2025 "{}{}{}",
2026 Qwen2VLProcessor::VISION_START,
2027 Qwen2VLProcessor::IMAGE_PAD,
2028 Qwen2VLProcessor::VISION_END
2029 )
2030 .repeat(image_indexes.len())
2031 )
2032 }
2033}
2034
2035impl VisionModelLoader for Qwen2VLLoader {
2036 fn load(
2037 &self,
2038 config: &str,
2039 vb: ShardedVarBuilder,
2040 normal_loading_metadata: NormalLoadingMetadata,
2041 attention_mechanism: AttentionImplementation,
2042 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2043 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2044 Ok(Box::new(Qwen2VLModel::new(
2045 &cfg,
2046 vb,
2047 self.is_gptx(config),
2048 normal_loading_metadata,
2049 attention_mechanism,
2050 )?))
2051 }
2052 fn is_gptx(&self, _config: &str) -> bool {
2053 true
2054 }
2055 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2056 let config: Qwen2VLConfig = serde_json::from_str(config)?;
2057 Ok(Box::new(config))
2058 }
2059 fn get_processor(
2060 &self,
2061 _model_config: &str,
2062 _processor_config: Option<ProcessorConfig>,
2063 _preprocessor_config: PreProcessorConfig,
2064 max_edge: Option<u32>,
2065 ) -> Arc<dyn Processor + Send + Sync> {
2066 Arc::new(Qwen2VLProcessor::new(max_edge))
2067 }
2068 fn supports_paged_attention(&self, _config: &str) -> bool {
2069 false
2070 }
2071 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2072 Arc::new(Qwen2VLPrefixer)
2073 }
2074 fn modalities(&self, _config: &str) -> Result<Modalities> {
2075 Ok(Modalities {
2076 input: vec![SupportedModality::Text, SupportedModality::Vision],
2077 output: vec![SupportedModality::Text],
2078 })
2079 }
2080}
2081
2082impl IsqModelLoader for Qwen2VLLoader {
2083 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2084 Ok(vec![
2085 Regex::new(r"lm_head\.(weight|bias)$")?,
2086 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2088 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2089 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2090 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2091 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2093 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2094 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2095 ])
2096 }
2097 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2098 self.isq_layer_regexes(config)
2099 }
2100}
2101
2102impl DeviceMappedModelLoader for Qwen2VLLoader {
2103 fn mapped_max_act_size_elems(
2104 &self,
2105 config: &str,
2106 params: &AutoDeviceMapParams,
2107 ) -> Result<usize> {
2108 let AutoDeviceMapParams::Vision {
2109 max_seq_len,
2110 max_batch_size,
2111 max_image_shape,
2112 max_num_images,
2113 } = params
2114 else {
2115 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2116 };
2117
2118 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2119
2120 let img_seq_len = {
2121 let cfg = &cfg.vision_config;
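            // Worst-case vision token count from the 3D patch grid: temporal, height, and width.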
2122 let grid_t = max_num_images / cfg.temporal_patch_size;
2123 let grid_h = max_image_shape.0 / cfg.patch_size;
2124 let grid_w = max_image_shape.1 / cfg.patch_size;
2125 grid_t * grid_h * grid_w
2126 };
2127 let img_seq_len = img_seq_len * max_num_images;
2128
2129 let max_text_attn = {
2130 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2132 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
2133 };
2134
2135 Ok(max_text_attn)
2136 }
2137
2138 fn non_mapped_max_act_size_elems(
2139 &self,
2140 config: &str,
2141 params: &AutoDeviceMapParams,
2142 ) -> Result<usize> {
2143 let AutoDeviceMapParams::Vision {
2144 max_seq_len: _,
2145 max_batch_size,
2146 max_image_shape,
2147 max_num_images,
2148 } = params
2149 else {
2150 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2151 };
2152
2153 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2154
2155 let img_seq_len = {
2156 let cfg = &cfg.vision_config;
2157 let grid_t = max_num_images / cfg.temporal_patch_size;
2158 let grid_h = max_image_shape.0 / cfg.patch_size;
2159 let grid_w = max_image_shape.1 / cfg.patch_size;
2160 grid_t * grid_h * grid_w
2161 };
2162
2163 let max_vision_attn = {
2164 let cfg = &cfg.vision_config;
2165 (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
2166 };
2167
2168 Ok(max_vision_attn)
2169 }
2170
2171 fn non_mapped_size_in_bytes(
2172 &self,
2173 config: &str,
2174 dtype: DType,
2175 weight_pack_factor: usize,
2176 _matformer_config: Option<&MatformerSliceConfig>,
2177 ) -> Result<usize> {
2178 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2179 let text_elems = {
2180 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2181 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2183 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2184 } else {
2185 0
2186 };
2187 let norm = cfg.hidden_size;
2188 embed_tokens + lm_head + norm
2189 };
2190
2191 let patch_merger = {
2192 let cfg = &cfg.vision_config;
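            // The patch merger concatenates spatial_merge_size^2 neighbouring patch embeddings
            // before the MLP projection to the model hidden size.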
2193 let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);
2194
2195 let mlp0 = hidden_size * hidden_size + hidden_size;
2196 let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
2197
2198 let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2199
2200 mlp0 + mlp2 + ln_q
2201 };
2202
2203 let patch_embed = {
2204 let cfg = &cfg.vision_config;
2205 let conv_cfg = Conv3dConfig {
2206 stride: cfg.patch_size,
2207 ..Default::default()
2208 };
2209 let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
2210 cfg.in_channels * cfg.embed_dim / conv_cfg.groups
2211 * kernel_sizes[0]
2212 * kernel_sizes[1]
2213 * kernel_sizes[2]
2214 };
2215
2216 let encoder_layer = {
2217 let cfg = &cfg.vision_config;
2218 let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2219 let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2220
2221 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2222 let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
2223 let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
2224 let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;
2225
2226 let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
2227 let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;
2228
2229 norm1 + norm2 + fc1 + fc2 + qkv + out
2230 };
2231
2232 let elems =
2233 text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
2234
2235 Ok(elems * dtype.size_in_bytes())
2236 }
2237
2238 fn layer_sizes_in_bytes(
2239 &self,
2240 config: &str,
2241 dtype: DType,
2242 weight_pack_factor: usize,
2243 _matformer_config: Option<&MatformerSliceConfig>,
2244 ) -> Result<Vec<usize>> {
2245 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2246 let per_layer_elems = {
2247 let input_layernorm = cfg.hidden_size;
2248 let post_attention_layernorm = cfg.hidden_size;
2249
2250 let size_in = cfg.hidden_size;
2251 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2252 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2253 let q_proj = size_in * size_q / weight_pack_factor + size_q;
2254 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
2255 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
2256 let o_proj = size_q * size_in / weight_pack_factor;
2257
2258 let h_size = cfg.hidden_size;
2259 let i_size = cfg.intermediate_size;
2260 let gate_proj = h_size * i_size / weight_pack_factor;
2261 let up_proj = h_size * i_size / weight_pack_factor;
2262 let down_proj = i_size * h_size / weight_pack_factor;
2263
2264 input_layernorm
2265 + post_attention_layernorm
2266 + q_proj
2267 + k_proj
2268 + v_proj
2269 + o_proj
2270 + gate_proj
2271 + up_proj
2272 + down_proj
2273 };
2274 Ok(vec![
2275 per_layer_elems * dtype.size_in_bytes();
2276 cfg.num_hidden_layers
2277 ])
2278 }
2279
2280 fn num_layers(&self, config: &str) -> Result<usize> {
2281 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2282 Ok(cfg.num_hidden_layers)
2283 }
2284
2285 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2286 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2287
2288 let cfg = ModelConfigMetadata {
2289 max_seq_len: cfg.max_position_embeddings,
2290 num_layers: cfg.num_hidden_layers,
2291 hidden_size: cfg.hidden_size,
2292 num_kv_heads: cfg.num_key_value_heads,
2293 num_attn_heads: cfg.num_attention_heads,
2294 sliding_window: cfg.sliding_window,
2295 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2296 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2297 };
2298
2299 Ok(Box::new(cfg))
2300 }
2301
2302 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2303 Some(vec![NonMappedSubModel::Vision])
2304 }
2305}
2306
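/// Loader for Idefics 3 vision models (`Idefics3Model`).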
2307pub struct Idefics3Loader;
2313
2314pub struct Idefics3Prefixer;
2315
2316impl MultimodalPromptPrefixer for Idefics3Prefixer {
2317 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
2318 prompt.to_string()
2320 }
2321}
2322
2323impl VisionModelLoader for Idefics3Loader {
2324 fn load(
2325 &self,
2326 config: &str,
2327 vb: ShardedVarBuilder,
2328 normal_loading_metadata: NormalLoadingMetadata,
2329 attention_mechanism: AttentionImplementation,
2330 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2331 let cfg: Idefics3Config = serde_json::from_str(config)?;
2332 Ok(Box::new(Idefics3Model::new(
2333 &cfg,
2334 vb,
2335 self.is_gptx(config),
2336 normal_loading_metadata,
2337 attention_mechanism,
2338 )?))
2339 }
2340 fn is_gptx(&self, _config: &str) -> bool {
2341 true
2342 }
2343 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2344 let cfg: Idefics3Config = serde_json::from_str(config)?;
2345 Ok(Box::new(cfg))
2346 }
2347 fn get_processor(
2348 &self,
2349 _model_config: &str,
2350 processor_config: Option<ProcessorConfig>,
2351 preprocessor_config: PreProcessorConfig,
2352 max_edge: Option<u32>,
2353 ) -> Arc<dyn Processor + Send + Sync> {
2354 Arc::new(Idefics3Processor::new(
2355 processor_config.unwrap_or_default(),
2356 preprocessor_config,
2357 max_edge,
2358 ))
2359 }
2360 fn supports_paged_attention(&self, _config: &str) -> bool {
2361 true
2362 }
2363 fn supports_prefix_cacher(&self, _config: &str) -> bool {
2364 true
2365 }
2366 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2367 Arc::new(Idefics3Prefixer)
2368 }
2369 fn modalities(&self, _config: &str) -> Result<Modalities> {
2370 Ok(Modalities {
2371 input: vec![SupportedModality::Text, SupportedModality::Vision],
2372 output: vec![SupportedModality::Text],
2373 })
2374 }
2375}
2376
2377impl IsqModelLoader for Idefics3Loader {
2378 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2379 Ok(vec![
2380 Regex::new(r"lm_head\.(weight|bias)$")?,
2381 Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2383 Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2384 Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2385 Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2386 Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2388 Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2389 Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2390 ])
2391 }
2392 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
2393 Ok(vec![
2394 Regex::new(r"lm_head\.(weight|bias)$")?,
2395 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2397 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2398 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2399 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2400 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2402 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2403 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2404 ])
2421 }
2422}
2423
2424impl DeviceMappedModelLoader for Idefics3Loader {
2425 fn mapped_max_act_size_elems(
2426 &self,
2427 config: &str,
2428 params: &AutoDeviceMapParams,
2429 ) -> Result<usize> {
2430 let AutoDeviceMapParams::Vision {
2431 max_seq_len,
2432 max_batch_size,
2433 max_image_shape: _,
2434 max_num_images,
2435 } = params
2436 else {
2437 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2438 };
2439
2440 let cfg: Idefics3Config = serde_json::from_str(config)?;
2441
2442 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2443 let img_seq_len = (num_patches + 1) * max_num_images;
2444
2445 let max_text_attn = {
2446 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2448 max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2449 };
2450
2451 Ok(max_text_attn)
2452 }
2453
2454 fn non_mapped_max_act_size_elems(
2455 &self,
2456 config: &str,
2457 params: &AutoDeviceMapParams,
2458 ) -> Result<usize> {
2459 let AutoDeviceMapParams::Vision {
2460 max_seq_len: _,
2461 max_batch_size,
2462 max_image_shape: _,
2463 max_num_images,
2464 } = params
2465 else {
2466 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2467 };
2468
2469 let cfg: Idefics3Config = serde_json::from_str(config)?;
2470
2471 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2472 let img_seq_len = num_patches + 1;
2473
2474 let max_vision_attn = {
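            // `images_factor` adds headroom for image splitting during preprocessing,
            // where one input image can expand into several crops plus a global view.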
2475 let images_factor = 5;
2477
2478 (max_batch_size * images_factor * max_num_images)
2479 * cfg.vision_config.num_attention_heads
2480 * img_seq_len
2481 * img_seq_len
2482 };
2483
2484 Ok(max_vision_attn)
2485 }
2486
2487 fn non_mapped_size_in_bytes(
2488 &self,
2489 config: &str,
2490 dtype: DType,
2491 weight_pack_factor: usize,
2492 _matformer_config: Option<&MatformerSliceConfig>,
2493 ) -> Result<usize> {
2494 let cfg: Idefics3Config = serde_json::from_str(config)?;
2495 let text_elems = {
2496 let cfg = &cfg.text_config;
2497
2498 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2499 let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2500 let norm = cfg.hidden_size;
2501 embed_tokens + lm_head + norm
2502 };
2503
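        // The connector projects pixel-shuffled vision features (hidden_size * scale_factor^2)
        // into the text hidden size with a single, bias-free linear layer.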
2504 let connector_elems = {
2505 let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
2506 let out_dim = cfg.text_config.hidden_size;
2507
2508 in_dim * out_dim
2509 };
2510
2511 let vision_transformer = {
2512 let cfg = &cfg.vision_config;
2513
2514 let post_layernorm = cfg.hidden_size;
2515
2516 let conv_config = Conv2dConfig {
2517 stride: cfg.patch_size,
2518 ..Default::default()
2519 };
2520 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2521 * cfg.patch_size
2522 * cfg.patch_size;
2523
2524 let num_patches_per_side = cfg.image_size / cfg.patch_size;
2525 let num_patches = num_patches_per_side.pow(2);
2526 let position_embedding = num_patches * cfg.hidden_size;
2527
2528 let layer_elems = {
2529 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2530 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2531
2532 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2533 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2534
2535 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2536 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2537 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2538 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2539
2540 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2541 };
2542
2543 post_layernorm
2544 + patch_embedding
2545 + position_embedding
2546 + layer_elems * cfg.num_hidden_layers
2547 };
2548
2549 let elems = text_elems + connector_elems + vision_transformer;
2550
2551 Ok(elems * dtype.size_in_bytes())
2552 }
2553
2554 fn layer_sizes_in_bytes(
2555 &self,
2556 config: &str,
2557 dtype: DType,
2558 weight_pack_factor: usize,
2559 _matformer_config: Option<&MatformerSliceConfig>,
2560 ) -> Result<Vec<usize>> {
2561 let cfg: Idefics3Config = serde_json::from_str(config)?;
2562 let cfg = cfg.text_config;
2563 let per_layer_elems = {
2564 let input_layernorm = cfg.hidden_size;
2565 let post_attention_layernorm = cfg.hidden_size;
2566
2567 let size_in = cfg.hidden_size;
2568 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2569 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2570 let q_proj = size_in * size_q / weight_pack_factor;
2571 let k_proj = size_in * size_kv / weight_pack_factor;
2572 let v_proj = size_in * size_kv / weight_pack_factor;
2573 let o_proj = size_q * size_in / weight_pack_factor;
2574
2575 let h_size = cfg.hidden_size;
2576 let i_size = cfg.intermediate_size;
2577 let gate_proj = h_size * i_size / weight_pack_factor;
2578 let up_proj = h_size * i_size / weight_pack_factor;
2579 let down_proj = i_size * h_size / weight_pack_factor;
2580
2581 input_layernorm
2582 + post_attention_layernorm
2583 + q_proj
2584 + k_proj
2585 + v_proj
2586 + o_proj
2587 + gate_proj
2588 + up_proj
2589 + down_proj
2590 };
2591 Ok(vec![
2592 per_layer_elems * dtype.size_in_bytes();
2593 cfg.num_hidden_layers
2594 ])
2595 }
2596
2597 fn num_layers(&self, config: &str) -> Result<usize> {
2598 let cfg: Idefics3Config = serde_json::from_str(config)?;
2599 Ok(cfg.text_config.num_hidden_layers)
2600 }
2601 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2602 let cfg: Idefics3Config = serde_json::from_str(config)?;
2603 let cfg = &cfg.text_config;
2604
2605 let cfg = ModelConfigMetadata {
2606 max_seq_len: cfg.max_position_embeddings,
2607 num_layers: cfg.num_hidden_layers,
2608 hidden_size: cfg.hidden_size,
2609 num_kv_heads: cfg.num_key_value_heads,
2610 num_attn_heads: cfg.num_attention_heads,
2611 sliding_window: None,
2612 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2613 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2614 };
2615
2616 Ok(Box::new(cfg))
2617 }
2618
2619 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2620 Some(vec![NonMappedSubModel::Vision])
2621 }
2622}
2623
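/// Loader for MiniCPM-o models (`MiniCpmOModel`).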
2624pub struct MiniCpmOLoader;
2630
2631pub struct MiniCpmOPrefixer;
2632
2633impl MultimodalPromptPrefixer for MiniCpmOPrefixer {
2634 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2635 format!(
2636 "{}{prompt}",
2637 "(<image>./</image>)".repeat(image_indexes.len())
2638 )
2639 }
2640}
2641
2642impl VisionModelLoader for MiniCpmOLoader {
2643 fn load(
2644 &self,
2645 config: &str,
2646 vb: ShardedVarBuilder,
2647 normal_loading_metadata: NormalLoadingMetadata,
2648 attention_mechanism: AttentionImplementation,
2649 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2650 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2651 Ok(Box::new(MiniCpmOModel::new(
2652 &cfg,
2653 vb,
2654 self.is_gptx(config),
2655 normal_loading_metadata,
2656 attention_mechanism,
2657 )?))
2658 }
2659 fn is_gptx(&self, _config: &str) -> bool {
2660 true
2661 }
2662 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2663 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2664 Ok(Box::new(cfg))
2665 }
2666 fn get_processor(
2667 &self,
2668 _model_config: &str,
2669 processor_config: Option<ProcessorConfig>,
2670 preprocessor_config: PreProcessorConfig,
2671 max_edge: Option<u32>,
2672 ) -> Arc<dyn Processor + Send + Sync> {
2673 Arc::new(MiniCpmOProcessor::new(
2674 processor_config.unwrap_or_default(),
2675 preprocessor_config,
2676 max_edge,
2677 ))
2678 }
2679 fn supports_paged_attention(&self, _config: &str) -> bool {
2680 true
2681 }
2682 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2683 Arc::new(MiniCpmOPrefixer)
2684 }
2685 fn modalities(&self, _config: &str) -> Result<Modalities> {
2686 Ok(Modalities {
2687 input: vec![SupportedModality::Text, SupportedModality::Vision],
2688 output: vec![SupportedModality::Text],
2689 })
2690 }
2691}
2692
2693impl IsqModelLoader for MiniCpmOLoader {
2694 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2695 Ok(vec![
2696 Regex::new(r"llm.lm_head\.(weight|bias)$")?,
2697 Regex::new(r"llm.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2699 Regex::new(r"llm.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2700 Regex::new(r"llm.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2701 Regex::new(r"llm.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2702 Regex::new(r"llm.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2704 Regex::new(r"llm.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2705 Regex::new(r"llm.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2706 ])
2707 }
2708 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2709 self.isq_layer_regexes(config)
2710 }
2711}
2712
2713impl DeviceMappedModelLoader for MiniCpmOLoader {
2714 fn mapped_max_act_size_elems(
2715 &self,
2716 config: &str,
2717 params: &AutoDeviceMapParams,
2718 ) -> Result<usize> {
2719 let AutoDeviceMapParams::Vision {
2720 max_seq_len,
2721 max_batch_size,
2722 max_image_shape: _,
2723 max_num_images,
2724 } = params
2725 else {
2726 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2727 };
2728
2729 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2730
2731 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2732 let img_seq_len = (num_patches + 1) * max_num_images;
2733
2734 let max_text_attn = {
2735 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2737 max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2738 };
2739
2740 Ok(max_text_attn)
2741 }
2742
2743 fn non_mapped_max_act_size_elems(
2744 &self,
2745 config: &str,
2746 params: &AutoDeviceMapParams,
2747 ) -> Result<usize> {
2748 let AutoDeviceMapParams::Vision {
2749 max_seq_len: _,
2750 max_batch_size,
2751 max_image_shape: _,
2752 max_num_images,
2753 } = params
2754 else {
2755 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2756 };
2757
2758 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2759
2760 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2761 let img_seq_len = num_patches + 1;
2762
2763 let max_vision_attn = {
2764 let images_factor = 5;
2766
2767 (max_batch_size * images_factor * max_num_images)
2768 * cfg.vision_config.num_attention_heads
2769 * img_seq_len
2770 * img_seq_len
2771 };
2772
2773 Ok(max_vision_attn)
2774 }
2775
2776 fn non_mapped_size_in_bytes(
2777 &self,
2778 config: &str,
2779 dtype: DType,
2780 weight_pack_factor: usize,
2781 _matformer_config: Option<&MatformerSliceConfig>,
2782 ) -> Result<usize> {
2783 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2784 let text_elems = {
2785 let cfg = &cfg.text_config;
2786
2787 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2788 let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2789 let norm = cfg.hidden_size;
2790 embed_tokens + lm_head + norm
2791 };
2792
2793 let vision_transformer = {
2794 let cfg = &cfg.vision_config;
2795
2796 let post_layernorm = cfg.hidden_size;
2797
2798 let conv_config = Conv2dConfig {
2799 stride: cfg.patch_size,
2800 ..Default::default()
2801 };
2802 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2803 * cfg.patch_size
2804 * cfg.patch_size;
2805
2806 let num_patches_per_side = cfg.image_size / cfg.patch_size;
2807 let num_patches = num_patches_per_side.pow(2);
2808 let position_embedding = num_patches * cfg.hidden_size;
2809
2810 let layer_elems = {
2811 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2812 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2813
2814 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2815 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2816
2817 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2818 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2819 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2820 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2821
2822 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2823 };
2824
2825 post_layernorm
2826 + patch_embedding
2827 + position_embedding
2828 + layer_elems * cfg.num_hidden_layers
2829 };
2830
2831 let elems = text_elems + vision_transformer;
2832
2833 Ok(elems * dtype.size_in_bytes())
2834 }
2835
2836 fn layer_sizes_in_bytes(
2837 &self,
2838 config: &str,
2839 dtype: DType,
2840 weight_pack_factor: usize,
2841 _matformer_config: Option<&MatformerSliceConfig>,
2842 ) -> Result<Vec<usize>> {
2843 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2844 let cfg = cfg.text_config;
2845 let per_layer_elems = {
2846 let input_layernorm = cfg.hidden_size;
2847 let post_attention_layernorm = cfg.hidden_size;
2848
2849 let size_in = cfg.hidden_size;
2850 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2851 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2852 let q_proj = size_in * size_q / weight_pack_factor;
2853 let k_proj = size_in * size_kv / weight_pack_factor;
2854 let v_proj = size_in * size_kv / weight_pack_factor;
2855 let o_proj = size_q * size_in / weight_pack_factor;
2856
2857 let h_size = cfg.hidden_size;
2858 let i_size = cfg.intermediate_size;
2859 let gate_proj = h_size * i_size / weight_pack_factor;
2860 let up_proj = h_size * i_size / weight_pack_factor;
2861 let down_proj = i_size * h_size / weight_pack_factor;
2862
2863 input_layernorm
2864 + post_attention_layernorm
2865 + q_proj
2866 + k_proj
2867 + v_proj
2868 + o_proj
2869 + gate_proj
2870 + up_proj
2871 + down_proj
2872 };
2873 Ok(vec![
2874 per_layer_elems * dtype.size_in_bytes();
2875 cfg.num_hidden_layers
2876 ])
2877 }
2878
2879 fn num_layers(&self, config: &str) -> Result<usize> {
2880 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2881 Ok(cfg.text_config.num_hidden_layers)
2882 }
2883 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2884 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2885 let cfg = &cfg.text_config;
2886
2887 let cfg = ModelConfigMetadata {
2888 max_seq_len: cfg.max_position_embeddings,
2889 num_layers: cfg.num_hidden_layers,
2890 hidden_size: cfg.hidden_size,
2891 num_kv_heads: cfg.num_key_value_heads,
2892 num_attn_heads: cfg.num_attention_heads,
2893 sliding_window: None,
2894 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2895 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2896 };
2897
2898 Ok(Box::new(cfg))
2899 }
2900}
2901
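/// Loader for Phi-4 multimodal models (`Phi4MMModel`), covering text, vision, and audio inputs.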
2902pub struct Phi4MMLoader;
2908
2909pub struct Phi4MMPrefixer;
2910
2911impl MultimodalPromptPrefixer for Phi4MMPrefixer {
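    // Phi-4 multimodal addresses media with 1-based tags: `<|image_N|>` for images and
    // `<|audio_N|>` for audio clips, prepended to the prompt.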
2912 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2913 format!(
2916 "{}{prompt}",
2917 image_indexes
2918 .into_iter()
2919 .map(|image_index| format!("<|image_{}|>", image_index + 1))
2920 .join("")
2921 )
2922 }
2923 fn prefix_audio(&self, audio_indexes: Vec<usize>, prompt: &str) -> String {
2924 format!(
2927 "{}{prompt}",
2928 audio_indexes
2929 .into_iter()
2930 .map(|audio_index| format!("<|audio_{}|>", audio_index + 1))
2931 .join("")
2932 )
2933 }
2934}
2935
2936impl VisionModelLoader for Phi4MMLoader {
2937 fn load(
2938 &self,
2939 config: &str,
2940 vb: ShardedVarBuilder,
2941 normal_loading_metadata: NormalLoadingMetadata,
2942 attention_mechanism: AttentionImplementation,
2943 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2944 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
2945 Ok(Box::new(Phi4MMModel::new(
2946 &cfg,
2947 vb,
2948 self.is_gptx(config),
2949 normal_loading_metadata,
2950 attention_mechanism,
2951 )?))
2952 }
2953 fn is_gptx(&self, _config: &str) -> bool {
2954 true
2955 }
2956 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2957 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
2958 Ok(Box::new(cfg))
2959 }
2960 fn get_processor(
2961 &self,
2962 _model_config: &str,
2963 processor_config: Option<ProcessorConfig>,
2964 preprocessor_config: PreProcessorConfig,
2965 _max_edge: Option<u32>,
2966 ) -> Arc<dyn Processor + Send + Sync> {
2967 Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
2968 }
2969 fn supports_paged_attention(&self, _config: &str) -> bool {
2970 true
2971 }
2972 fn supports_prefix_cacher(&self, _config: &str) -> bool {
2973 true
2974 }
2975 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2976 Arc::new(Phi4MMPrefixer)
2977 }
2978 fn modalities(&self, _config: &str) -> Result<Modalities> {
2979 Ok(Modalities {
2980 input: vec![
2981 SupportedModality::Text,
2982 SupportedModality::Vision,
2983 SupportedModality::Audio,
2984 ],
2985 output: vec![SupportedModality::Text],
2986 })
2987 }
2988}
2989
2990impl IsqModelLoader for Phi4MMLoader {
2991 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2992 Ok(vec![
2993 Regex::new(r"lm_head\.(weight|bias)$")?,
2994 Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
2996 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2997 Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
2999 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3000 ])
3001 }
3002 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3003 self.isq_layer_regexes(config)
3004 }
3005}
3006
3007impl DeviceMappedModelLoader for Phi4MMLoader {
3008 fn mapped_max_act_size_elems(
3009 &self,
3010 config: &str,
3011 params: &AutoDeviceMapParams,
3012 ) -> Result<usize> {
3013 let AutoDeviceMapParams::Vision {
3015 max_seq_len,
3016 max_batch_size,
3017 max_image_shape: _,
3018 max_num_images,
3019 } = params
3020 else {
3021 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3022 };
3023
3024 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3025
3026 let vcfg = &PHI4_MM_VISION_CFG;
3027
3028 let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3029 let img_seq_len = (num_patches + 1) * max_num_images;
3030
3031 let max_text_attn = {
3032 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3034 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3035 };
3036
3037 Ok(max_text_attn)
3038 }
3039
3040 fn non_mapped_max_act_size_elems(
3041 &self,
3042 _config: &str,
3043 params: &AutoDeviceMapParams,
3044 ) -> Result<usize> {
3045 let AutoDeviceMapParams::Vision {
3046 max_seq_len: _,
3047 max_batch_size,
3048 max_image_shape,
3049 max_num_images,
3050 } = params
3051 else {
3052 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3053 };
3054
3055 let vcfg = &PHI4_MM_VISION_CFG;
3056
3057 let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3058 let img_seq_len = num_patches + 1;
3059
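        // Dynamic HD preprocessing tiles each image into base-resolution crops plus one
        // global crop, so scale the effective batch size accordingly.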
3060 let max_batch_size = max_batch_size
3061 * (max_image_shape
3062 .0
3063 .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3064 * max_image_shape
3065 .1
3066 .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3067 + 1);
3068
3069 let max_vision_attn = (max_batch_size * max_num_images)
3070 * vcfg.num_attention_heads
3071 * img_seq_len
3072 * img_seq_len;
3073 let max_qkv = 3
3074 * (max_batch_size
3075 * vcfg.num_attention_heads
3076 * img_seq_len
3077 * (vcfg.hidden_size / vcfg.num_attention_heads));
3078
3079 Ok(max_vision_attn + max_qkv)
3080 }
3081
3082 fn non_mapped_size_in_bytes(
3083 &self,
3084 config: &str,
3085 dtype: DType,
3086 weight_pack_factor: usize,
3087 _matformer_config: Option<&MatformerSliceConfig>,
3088 ) -> Result<usize> {
3089 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3090 let elems = {
3091 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3092 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3094 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3095 } else {
3096 0
3097 };
3098 let norm = cfg.hidden_size;
3099
3100 let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
3101 let projection_cls = img_embed
3102 .projection_cls
3103 .clone()
3104 .unwrap_or("linear".to_string());
3105 let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
3106 let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
3107 let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;
3108
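            // The image projection head's parameter count depends on the configured
            // projection class; the HD transform widens the MLP input by 4x.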
3109 let proj = match (projection_cls.as_str(), use_hd_transform) {
3110 ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
3111 ("mlp", true) => {
3112 let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
3113 let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3114 a + b
3115 }
3116 ("mlp", false) => {
3117 let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
3118 let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3119 a + b
3120 }
3121 _ => {
3122 anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
3123 }
3124 };
3125
3126 let (glb_gn, sub_gn) = if with_learnable_separator {
3127 let glb_gn = image_dim_out * 4;
3128 let sub_gn = image_dim_out * 4;
3129 (glb_gn, sub_gn)
3130 } else {
3131 (0, 0)
3132 };
3133
3134 let vision_transformer = {
3135 let cfg = &PHI4_MM_VISION_CFG;
3136
3137 let post_layernorm = cfg.hidden_size;
3138
3139 let conv_config = Conv2dConfig {
3140 stride: cfg.patch_size,
3141 ..Default::default()
3142 };
3143 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3144 * cfg.patch_size
3145 * cfg.patch_size;
3146
3147 let num_patches_per_side = cfg.image_size / cfg.patch_size;
3148 let num_patches = num_patches_per_side.pow(2);
3149 let position_embedding = num_patches * cfg.hidden_size;
3150
3151 let layer_elems = {
3152 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3153 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3154
3155 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3156 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3157
3158 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3159 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3160 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3161 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3162
3163 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3164 };
3165
3166 post_layernorm
3167 + patch_embedding
3168 + position_embedding
3169 + layer_elems * cfg.num_hidden_layers
3170 };
3171
3172 proj + glb_gn + sub_gn + vision_transformer
3173 } else {
3174 0
3175 };
3176
3177 embed_tokens + lm_head + norm + image_embed
3178 };
3179
3180 Ok(elems * dtype.size_in_bytes())
3181 }
3182
3183 fn layer_sizes_in_bytes(
3184 &self,
3185 config: &str,
3186 dtype: DType,
3187 weight_pack_factor: usize,
3188 _matformer_config: Option<&MatformerSliceConfig>,
3189 ) -> Result<Vec<usize>> {
3190 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3191 let per_layer_elems = {
3192 let input_layernorm = cfg.hidden_size;
3193 let post_attention_layernorm = cfg.hidden_size;
3194
3195 let size_in = cfg.hidden_size;
3196 let head_dim = cfg.head_dim();
3197 let op_size =
3198 cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
3199 let qkv_proj = size_in * op_size / weight_pack_factor;
3200 let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;
3201
3202 let h_size = cfg.hidden_size;
3203 let i_size = cfg.intermediate_size;
3204 let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
3205 let down_proj = h_size * i_size / weight_pack_factor;
3206
3207 input_layernorm
3208 + post_attention_layernorm
3209 + qkv_proj
3210 + o_proj
3211 + gate_up_proj
3212 + down_proj
3213 };
3214 Ok(vec![
3215 per_layer_elems * dtype.size_in_bytes();
3216 cfg.num_hidden_layers
3217 ])
3218 }
3219
3220 fn num_layers(&self, config: &str) -> Result<usize> {
3221 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3222 Ok(cfg.num_hidden_layers)
3223 }
3224
3225 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3226 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3227
3228 let cfg = ModelConfigMetadata {
3229 max_seq_len: cfg.max_position_embeddings,
3230 num_layers: cfg.num_hidden_layers,
3231 hidden_size: cfg.hidden_size,
3232 num_kv_heads: cfg.num_key_value_heads(),
3233 num_attn_heads: cfg.num_attention_heads,
3234 sliding_window: cfg.sliding_window,
3235 k_head_dim: cfg.head_dim(),
3236 v_head_dim: cfg.head_dim(),
3237 };
3238
3239 Ok(Box::new(cfg))
3240 }
3241
3242 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3243 Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
3244 }
3245}
3246
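/// Loader for Qwen2.5-VL vision models (`Qwen2_5VLModel`).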
3247pub struct Qwen2_5VLLoader;
3253
3254pub struct Qwen2_5VLPrefixer;
3255
3256impl MultimodalPromptPrefixer for Qwen2_5VLPrefixer {
3257 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
3258 format!(
3259 "{}{prompt}",
3260 format!(
3261 "{}{}{}",
3262 Qwen2_5VLProcessor::VISION_START,
3263 Qwen2_5VLProcessor::IMAGE_PAD,
3264 Qwen2_5VLProcessor::VISION_END
3265 )
3266 .repeat(image_indexes.len())
3267 )
3268 }
3269}
3270
3271impl VisionModelLoader for Qwen2_5VLLoader {
3272 fn load(
3273 &self,
3274 config: &str,
3275 vb: ShardedVarBuilder,
3276 normal_loading_metadata: NormalLoadingMetadata,
3277 attention_mechanism: AttentionImplementation,
3278 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3279 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3280 Ok(Box::new(Qwen2_5VLModel::new(
3281 &cfg,
3282 vb,
3283 self.is_gptx(config),
3284 normal_loading_metadata,
3285 attention_mechanism,
3286 )?))
3287 }
3288 fn is_gptx(&self, _config: &str) -> bool {
3289 true
3290 }
3291 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3292 let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
3293 Ok(Box::new(config))
3294 }
3295 fn get_processor(
3296 &self,
3297 _model_config: &str,
3298 _processor_config: Option<ProcessorConfig>,
3299 _preprocessor_config: PreProcessorConfig,
3300 max_edge: Option<u32>,
3301 ) -> Arc<dyn Processor + Send + Sync> {
3302 Arc::new(Qwen2_5VLProcessor::new(max_edge))
3303 }
3304 fn supports_paged_attention(&self, _config: &str) -> bool {
3305 false
3306 }
3307 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3308 Arc::new(Qwen2_5VLPrefixer)
3309 }
3310 fn modalities(&self, _config: &str) -> Result<Modalities> {
3311 Ok(Modalities {
3312 input: vec![SupportedModality::Text, SupportedModality::Vision],
3313 output: vec![SupportedModality::Text],
3314 })
3315 }
3316}
3317
3318impl IsqModelLoader for Qwen2_5VLLoader {
3319 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3320 Ok(vec![
3321 Regex::new(r"lm_head\.(weight|bias)$")?,
3322 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3324 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3325 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3326 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3327 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3329 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3330 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3331 ])
3332 }
3333 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3334 self.isq_layer_regexes(config)
3335 }
3336}
3337
3338impl DeviceMappedModelLoader for Qwen2_5VLLoader {
3339 fn mapped_max_act_size_elems(
3340 &self,
3341 config: &str,
3342 params: &AutoDeviceMapParams,
3343 ) -> Result<usize> {
3344 let AutoDeviceMapParams::Vision {
3345 max_seq_len,
3346 max_batch_size,
3347 max_image_shape,
3348 max_num_images,
3349 } = params
3350 else {
3351 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3352 };
3353
3354 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3355
3356 let img_seq_len = {
3357 let cfg = &cfg.vision_config;
3358 let grid_t = max_num_images / cfg.temporal_patch_size;
3359 let grid_h = max_image_shape.0 / cfg.patch_size;
3360 let grid_w = max_image_shape.1 / cfg.patch_size;
3361 grid_t * grid_h * grid_w
3362 };
3363 let img_seq_len = img_seq_len * max_num_images;
3364
3365 let max_text_attn = {
3366 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3368 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3369 };
3370
3371 Ok(max_text_attn)
3372 }
3373
3374 fn non_mapped_max_act_size_elems(
3375 &self,
3376 config: &str,
3377 params: &AutoDeviceMapParams,
3378 ) -> Result<usize> {
3379 let AutoDeviceMapParams::Vision {
3380 max_seq_len: _,
3381 max_batch_size,
3382 max_image_shape,
3383 max_num_images,
3384 } = params
3385 else {
3386 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3387 };
3388
3389 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3390
3391 let img_seq_len = {
3392 let cfg = &cfg.vision_config;
3393 let grid_t = max_num_images / cfg.temporal_patch_size;
3394 let grid_h = max_image_shape.0 / cfg.patch_size;
3395 let grid_w = max_image_shape.1 / cfg.patch_size;
3396 grid_t * grid_h * grid_w
3397 };
3398
3399 let max_vision_attn = {
3400 let cfg = &cfg.vision_config;
3401 (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
3402 };
3403
3404 Ok(max_vision_attn)
3405 }
3406
3407 fn non_mapped_size_in_bytes(
3408 &self,
3409 config: &str,
3410 dtype: DType,
3411 weight_pack_factor: usize,
3412 _matformer_config: Option<&MatformerSliceConfig>,
3413 ) -> Result<usize> {
3414 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3415 let text_elems = {
3416 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3417 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3419 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3420 } else {
3421 0
3422 };
3423 let norm = cfg.hidden_size;
3424 embed_tokens + lm_head + norm
3425 };
3426
3427 let patch_merger = {
3428 let cfg = &cfg.vision_config;
3429 let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
3430
3431 let mlp0 = hidden_size * hidden_size + hidden_size;
3432 let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
3433
3434 let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3435
3436 mlp0 + mlp2 + ln_q
3437 };
3438
3439 let patch_embed = {
3440 let cfg = &cfg.vision_config;
3441 let conv_cfg = Conv3dConfig {
3442 stride: cfg.patch_size,
3443 ..Default::default()
3444 };
3445 let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
3446 cfg.in_chans * cfg.hidden_size / conv_cfg.groups
3447 * kernel_sizes[0]
3448 * kernel_sizes[1]
3449 * kernel_sizes[2]
3450 };
3451
3452 let encoder_layer = {
3453 let cfg = &cfg.vision_config;
3454 let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3455 let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3456
3457 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
3458 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3459 let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
3460
3461 let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
3462 let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3463
3464 norm1 + norm2 + fc1 + fc2 + qkv + out
3465 };
3466
3467 let elems =
3468 text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
3469
3470 Ok(elems * dtype.size_in_bytes())
3471 }
3472
3473 fn layer_sizes_in_bytes(
3474 &self,
3475 config: &str,
3476 dtype: DType,
3477 weight_pack_factor: usize,
3478 _matformer_config: Option<&MatformerSliceConfig>,
3479 ) -> Result<Vec<usize>> {
3480 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3481 let per_layer_elems = {
3482 let input_layernorm = cfg.hidden_size;
3483 let post_attention_layernorm = cfg.hidden_size;
3484
3485 let size_in = cfg.hidden_size;
3486 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3487 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3488 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3489 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3490 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3491 let o_proj = size_q * size_in / weight_pack_factor;
3492
3493 let h_size = cfg.hidden_size;
3494 let i_size = cfg.intermediate_size;
3495 let gate_proj = h_size * i_size / weight_pack_factor;
3496 let up_proj = h_size * i_size / weight_pack_factor;
3497 let down_proj = i_size * h_size / weight_pack_factor;
3498
3499 input_layernorm
3500 + post_attention_layernorm
3501 + q_proj
3502 + k_proj
3503 + v_proj
3504 + o_proj
3505 + gate_proj
3506 + up_proj
3507 + down_proj
3508 };
3509 Ok(vec![
3510 per_layer_elems * dtype.size_in_bytes();
3511 cfg.num_hidden_layers
3512 ])
3513 }
3514
3515 fn num_layers(&self, config: &str) -> Result<usize> {
3516 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3517 Ok(cfg.num_hidden_layers)
3518 }
3519
3520 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3521 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3522
3523 let cfg = ModelConfigMetadata {
3524 max_seq_len: cfg.max_position_embeddings,
3525 num_layers: cfg.num_hidden_layers,
3526 hidden_size: cfg.hidden_size,
3527 num_kv_heads: cfg.num_key_value_heads,
3528 num_attn_heads: cfg.num_attention_heads,
3529 sliding_window: cfg.sliding_window,
3530 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3531 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3532 };
3533
3534 Ok(Box::new(cfg))
3535 }
3536
3537 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3538 Some(vec![NonMappedSubModel::Vision])
3539 }
3540}
3541
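/// Loader for Gemma 3 models (`Gemma3Model`); the config may be text-only
/// (`Gemma3Config::Text`) or include a vision tower (`Gemma3Config::WithVision`).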
3542pub struct Gemma3Loader;
3548
3549pub struct Gemma3Prefixer;
3550
3551impl MultimodalPromptPrefixer for Gemma3Prefixer {
3552 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
3553 prompt.to_string()
3554 }
3555}
3556
3557impl VisionModelLoader for Gemma3Loader {
3558 fn load(
3559 &self,
3560 config: &str,
3561 vb: ShardedVarBuilder,
3562 normal_loading_metadata: NormalLoadingMetadata,
3563 attention_mechanism: AttentionImplementation,
3564 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3565 let cfg: Gemma3Config = serde_json::from_str(config)?;
3566 Ok(Box::new(Gemma3Model::new(
3567 &cfg,
3568 vb,
3569 self.is_gptx(config),
3570 normal_loading_metadata,
3571 attention_mechanism,
3572 )?))
3573 }
3574 fn is_gptx(&self, _config: &str) -> bool {
3575 true
3576 }
3577 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3578 let config: Gemma3Config = serde_json::from_str(config)?;
3579 Ok(Box::new(config))
3580 }
3581 fn get_processor(
3582 &self,
3583 config: &str,
3584 processor_config: Option<ProcessorConfig>,
3585 _preprocessor_config: PreProcessorConfig,
3586 _max_edge: Option<u32>,
3587 ) -> Arc<dyn Processor + Send + Sync> {
3588 let config: Gemma3Config = serde_json::from_str(config).unwrap();
3589 Arc::new(Gemma3Processor::new(
3591 processor_config.unwrap_or_default(),
3592 matches!(config, Gemma3Config::WithVision { .. }),
3593 ))
3594 }
3595 fn supports_paged_attention(&self, _config: &str) -> bool {
3596 true
3597 }
3598 fn supports_prefix_cacher(&self, _config: &str) -> bool {
3599 true
3600 }
3601 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3602 Arc::new(Gemma3Prefixer)
3603 }
3604 fn modalities(&self, _config: &str) -> Result<Modalities> {
3605 Ok(Modalities {
3606 input: vec![SupportedModality::Text, SupportedModality::Vision],
3607 output: vec![SupportedModality::Text],
3608 })
3609 }
3610}
3611
3612impl IsqModelLoader for Gemma3Loader {
3613 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3614 Ok(vec![
3615 Regex::new(r"lm_head\.(weight|bias)$")?,
3616 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3618 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3619 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3620 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3621 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3623 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3624 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3625 ])
3626 }
3627 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
3628 Ok(vec![
3629 Regex::new(r"lm_head\.(weight|bias)$")?,
3630 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3632 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3633 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3634 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3635 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3637 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3638 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3639 ])
3640 }
3641}
3642
3643impl DeviceMappedModelLoader for Gemma3Loader {
3644 fn mapped_max_act_size_elems(
3645 &self,
3646 config: &str,
3647 params: &AutoDeviceMapParams,
3648 ) -> Result<usize> {
3649 let AutoDeviceMapParams::Vision {
3650 max_seq_len,
3651 max_batch_size,
3652 max_image_shape: _,
3653 max_num_images,
3654 } = params
3655 else {
3656 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3657 };
3658
3659 let cfg: Gemma3Config = serde_json::from_str(config)?;
3660
3661 match cfg {
3662 Gemma3Config::Text(text_config) => Ok(max_batch_size
3663 * text_config.num_attention_heads
3664 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)),
3665 Gemma3Config::WithVision {
3666 text_config,
3667 vision_config,
3668 ..
3669 } => {
3670 let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3671 let img_seq_len = (num_patches + 1) * max_num_images;
3672
3673 let max_text_attn = {
3674 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3676 max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
3677 };
3678 Ok(max_text_attn)
3679 }
3680 }
3681 }
3682
3683 fn non_mapped_max_act_size_elems(
3684 &self,
3685 config: &str,
3686 params: &AutoDeviceMapParams,
3687 ) -> Result<usize> {
3688 let AutoDeviceMapParams::Vision {
3689 max_seq_len: _,
3690 max_batch_size,
3691 max_image_shape: _,
3692 max_num_images,
3693 } = params
3694 else {
3695 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3696 };
3697
3698 let cfg: Gemma3Config = serde_json::from_str(config)?;
3699
3700 match cfg {
3701 Gemma3Config::WithVision { vision_config, .. } => {
3702 let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3703 let img_seq_len = num_patches + 1;
3704
3705 let max_vision_attn = {
3706 (max_batch_size * max_num_images)
3707 * vision_config.num_attention_heads
3708 * img_seq_len
3709 * img_seq_len
3710 };
3711
3712 Ok(max_vision_attn)
3713 }
3714 Gemma3Config::Text(_) => Ok(0),
3715 }
3716 }
3717
3718 fn non_mapped_size_in_bytes(
3719 &self,
3720 config: &str,
3721 dtype: DType,
3722 weight_pack_factor: usize,
3723 _matformer_config: Option<&MatformerSliceConfig>,
3724 ) -> Result<usize> {
3725 let cfg: Gemma3Config = serde_json::from_str(config)?;
3726
3727 let text_elems = {
3728 let cfg = match &cfg {
3729 Gemma3Config::Text(cfg) => cfg,
3730 Gemma3Config::WithVision { text_config, .. } => text_config,
3731 };
3732 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3733 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3735 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3736 } else {
3737 0
3738 };
3739 let norm = cfg.hidden_size;
3740 embed_tokens + lm_head + norm
3741 };
3742
3743 let vision_transformer = if let Gemma3Config::WithVision {
3744 vision_config: cfg, ..
3745 } = &cfg
3746 {
3747 let post_layernorm = cfg.hidden_size;
3748
3749 let conv_config = Conv2dConfig {
3750 stride: cfg.patch_size,
3751 ..Default::default()
3752 };
3753 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3754 * cfg.patch_size
3755 * cfg.patch_size;
3756
3757 let num_patches_per_side = cfg.image_size / cfg.patch_size;
3758 let num_patches = num_patches_per_side.pow(2);
3759 let position_embedding = num_patches * cfg.hidden_size;
3760
3761 let layer_elems = {
3762 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3763 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3764
3765 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3766 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3767
3768 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3769 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3770 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3771 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3772
3773 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3774 };
3775
3776 post_layernorm
3777 + patch_embedding
3778 + position_embedding
3779 + layer_elems * cfg.num_hidden_layers
3780 } else {
3781 0
3782 };
3783
3784 let elems = text_elems + vision_transformer;
3785
3786 Ok(elems * dtype.size_in_bytes())
3787 }
3788
3789 fn layer_sizes_in_bytes(
3790 &self,
3791 config: &str,
3792 dtype: DType,
3793 weight_pack_factor: usize,
3794 _matformer_config: Option<&MatformerSliceConfig>,
3795 ) -> Result<Vec<usize>> {
3796 let cfg: Gemma3Config = serde_json::from_str(config)?;
3797
3798 let txt_cfg = match &cfg {
3799 Gemma3Config::Text(cfg) => cfg,
3800 Gemma3Config::WithVision { text_config, .. } => text_config,
3801 };
3802 let per_layer_elems = {
3803 let cfg = txt_cfg;
3804
3805 let input_layernorm = cfg.hidden_size;
3806 let post_attention_layernorm = cfg.hidden_size;
3807
3808 let size_in = cfg.hidden_size;
3809 let size_q = cfg.head_dim * cfg.num_attention_heads;
3810 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
3811 let q_proj =
3812 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
3813 let k_proj =
3814 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3815 let v_proj =
3816 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3817 let o_proj =
3818 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
3819
3820 let h_size = cfg.hidden_size;
3821 let i_size = cfg.intermediate_size;
3822 let gate_proj = h_size * i_size / weight_pack_factor;
3823 let up_proj = h_size * i_size / weight_pack_factor;
3824 let down_proj = i_size * h_size / weight_pack_factor;
3825
3826 input_layernorm
3827 + post_attention_layernorm
3828 + q_proj
3829 + k_proj
3830 + v_proj
3831 + o_proj
3832 + gate_proj
3833 + up_proj
3834 + down_proj
3835 };
3836 Ok(vec![
3837 per_layer_elems * dtype.size_in_bytes();
3838 txt_cfg.num_hidden_layers
3839 ])
3840 }
3841
3842 fn num_layers(&self, config: &str) -> Result<usize> {
3843 let cfg: Gemma3Config = serde_json::from_str(config)?;
3844
3845 let txt_cfg = match &cfg {
3846 Gemma3Config::Text(cfg) => cfg,
3847 Gemma3Config::WithVision { text_config, .. } => text_config,
3848 };
3849
3850 Ok(txt_cfg.num_hidden_layers)
3851 }
3852
3853 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3854 let cfg: Gemma3Config = serde_json::from_str(config)?;
3855
3856 let cfg = match &cfg {
3857 Gemma3Config::Text(cfg) => cfg,
3858 Gemma3Config::WithVision { text_config, .. } => text_config,
3859 };
3860
3861 let cfg = ModelConfigMetadata {
3862 max_seq_len: cfg.max_position_embeddings,
3863 num_layers: cfg.num_hidden_layers,
3864 hidden_size: cfg.hidden_size,
3865 num_kv_heads: cfg.num_key_value_heads,
3866 num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3869 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3870 };
3871
3872 Ok(Box::new(cfg))
3873 }
3874
3875 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3876 Some(vec![NonMappedSubModel::Vision])
3877 }
3878}
3879
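/// Loader for Mistral 3 multimodal models (a vision encoder paired with a Mistral
/// text backbone).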
3880pub struct Mistral3Loader;
3886
3887pub struct Mistral3Prefixer;
3888
3889impl MultimodalPromptPrefixer for Mistral3Prefixer {
3890 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
3891 prompt.to_string()
3892 }
3893}
3894
3895impl VisionModelLoader for Mistral3Loader {
3896 fn load(
3897 &self,
3898 config: &str,
3899 vb: ShardedVarBuilder,
3900 normal_loading_metadata: NormalLoadingMetadata,
3901 attention_mechanism: AttentionImplementation,
3902 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3903 let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
3904 Ok(Box::new(Mistral3Model::new(
3905 &cfg,
3906 vb,
3907 self.is_gptx(config),
3908 normal_loading_metadata,
3909 attention_mechanism,
3910 )?))
3911 }
3912 fn is_gptx(&self, _config: &str) -> bool {
3913 true
3914 }
3915 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3916 let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
3917 Ok(Box::new(cfg))
3918 }
3919 fn get_processor(
3920 &self,
3921 _model_config: &str,
3922 processor_config: Option<ProcessorConfig>,
3923 _preprocessor_config: PreProcessorConfig,
3924 _max_edge: Option<u32>,
3925 ) -> Arc<dyn Processor + Send + Sync> {
3926 Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
3927 }
3928 fn supports_paged_attention(&self, _config: &str) -> bool {
3929 true
3930 }
3931 fn supports_prefix_cacher(&self, _config: &str) -> bool {
3932 true
3933 }
3934 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3935 Arc::new(Mistral3Prefixer)
3936 }
3937 fn modalities(&self, _config: &str) -> Result<Modalities> {
3938 Ok(Modalities {
3939 input: vec![SupportedModality::Text, SupportedModality::Vision],
3940 output: vec![SupportedModality::Text],
3941 })
3942 }
3943}
3944
3945impl IsqModelLoader for Mistral3Loader {
3946 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3947 Ok(vec![
3948 Regex::new(r"lm_head\.(weight|bias)$")?,
3949 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3951 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3952 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3953 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3954 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3956 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3957 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3958 ])
3959 }
3960 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
3961 Ok(vec![
3962 Regex::new(r"lm_head\.(weight|bias)$")?,
3963 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3965 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3966 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3967 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3968 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3970 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3971 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3972 ])
3973 }
3974}
3975
3976#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
3977impl DeviceMappedModelLoader for Mistral3Loader {
3978 fn mapped_max_act_size_elems(
3979 &self,
3980 config: &str,
3981 params: &AutoDeviceMapParams,
3982 ) -> Result<usize> {
3983 let cfg: Mistral3Config = serde_json::from_str(config)?;
3984 let vcfg = &cfg.vision_config;
3985 let tcfg = &cfg.text_config;
3986
3987 let AutoDeviceMapParams::Vision {
3988 max_seq_len,
3989 max_batch_size,
3990 max_image_shape: (mut height, mut width),
3991 max_num_images,
3992 } = params
3993 else {
3994 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3995 };
3996
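        // Estimate the image token count the way the processor would: cap the longest
        // edge at 1540 px (assumed processor default), snap to patch-size multiples,
        // then count one token per patch plus a row-break token per patch row.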
3997 let img_seq_len = {
3998 let (max_height, max_width) = (1540, 1540);
4002 let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4003 if ratio > 1. {
4004 height = (height as f64 / ratio).floor() as usize;
4005 width = (width as f64 / ratio).floor() as usize;
4006 }
4007
4008 let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
4009 let num_width_tokens = (width - 1) / vcfg.patch_size + 1;
4010
4011 height = num_height_tokens * vcfg.patch_size;
4012 width = num_width_tokens * vcfg.patch_size;
4013
4014 let num_height_tokens = height / vcfg.patch_size;
4015 let num_width_tokens = width / vcfg.patch_size;
4016
4017 (num_width_tokens + 1) * num_height_tokens
4018 };
4019
4020 let max_seq_len = img_seq_len * max_num_images + *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
4022 Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
4023 }
4024
4025 fn non_mapped_max_act_size_elems(
4026 &self,
4027 config: &str,
4028 params: &AutoDeviceMapParams,
4029 ) -> Result<usize> {
4030 let cfg: Mistral3Config = serde_json::from_str(config)?;
4031 let cfg = &cfg.vision_config;
4032
4033 let AutoDeviceMapParams::Vision {
4034 max_seq_len: _,
4035 max_batch_size,
4036 max_image_shape: (mut height, mut width),
4037 max_num_images,
4038 } = params
4039 else {
4040 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4041 };
4042
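        // Same image-token estimate as in `mapped_max_act_size_elems`, used here to
        // size the per-image vision attention.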
4043 let img_seq_len = {
4044 let (max_height, max_width) = (1540, 1540);
4048 let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4049 if ratio > 1. {
4050 height = (height as f64 / ratio).floor() as usize;
4051 width = (width as f64 / ratio).floor() as usize;
4052 }
4053
4054 let num_height_tokens = (height - 1) / cfg.patch_size + 1;
4055 let num_width_tokens = (width - 1) / cfg.patch_size + 1;
4056
4057 height = num_height_tokens * cfg.patch_size;
4058 width = num_width_tokens * cfg.patch_size;
4059
4060 let num_height_tokens = height / cfg.patch_size;
4061 let num_width_tokens = width / cfg.patch_size;
4062
4063 (num_width_tokens + 1) * num_height_tokens
4064 };
4065
4066 Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
4067 }
4068
4069 fn non_mapped_size_in_bytes(
4070 &self,
4071 config: &str,
4072 dtype: DType,
4073 weight_pack_factor: usize,
4074 _matformer_config: Option<&MatformerSliceConfig>,
4075 ) -> Result<usize> {
4076 let cfg: Mistral3Config = serde_json::from_str(config)?;
4077
4078 let text_elems = {
4079 let cfg = &cfg.text_config;
4080
4081 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4082 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4084 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4085 } else {
4086 0
4087 };
4088 let norm = cfg.hidden_size;
4089 embed_tokens + lm_head + norm
4090 };
4091
4092 let vision_elems = {
4093 let cfg = &cfg.vision_config;
4094
4095 let patch_embed = {
4096 let conv_cfg = Conv2dConfig {
4097 stride: cfg.patch_size,
4098 ..Default::default()
4099 };
                // Conv2d patch embedding: in_channels x out_channels x patch_size x patch_size.
                cfg.num_channels * cfg.hidden_size / conv_cfg.groups
                    * cfg.patch_size
                    * cfg.patch_size
4104 };
4105 let ln_pre = cfg.hidden_size;
4106 let vision_layer = {
4107 let attn_norm = cfg.hidden_size;
4108 let ffn_norm = cfg.hidden_size;
4109
4110 let gate = cfg.hidden_size * cfg.intermediate_size;
4111 let up = cfg.hidden_size * cfg.intermediate_size;
4112 let down = cfg.hidden_size * cfg.intermediate_size;
4113
4114 let q = cfg.hidden_size * cfg.hidden_size;
4115 let k = cfg.hidden_size * cfg.hidden_size;
4116 let v = cfg.hidden_size * cfg.hidden_size;
4117 let o = cfg.hidden_size * cfg.hidden_size;
4118
4119 attn_norm + ffn_norm + gate + up + down + q + k + v + o
4120 };
4121
4122 patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
4123 };
4124
4125 let elems = text_elems + vision_elems;
4126
4127 Ok(elems * dtype.size_in_bytes())
4128 }
4129
4130 fn layer_sizes_in_bytes(
4131 &self,
4132 config: &str,
4133 dtype: DType,
4134 weight_pack_factor: usize,
4135 _matformer_config: Option<&MatformerSliceConfig>,
4136 ) -> Result<Vec<usize>> {
4137 let cfg: Mistral3Config = serde_json::from_str(config)?;
4138 let cfg = &cfg.text_config;
4139
4140 let per_layer_elems = {
4141 let input_layernorm = cfg.hidden_size;
4142 let post_attention_layernorm = cfg.hidden_size;
4143
4144 let size_in = cfg.hidden_size;
4145 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
4146 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
4147 let q_proj = size_in * size_q / weight_pack_factor;
4148 let k_proj = size_in * size_kv / weight_pack_factor;
4149 let v_proj = size_in * size_kv / weight_pack_factor;
4150 let o_proj = size_q * size_in / weight_pack_factor;
4151
4152 let h_size = cfg.hidden_size;
4153 let i_size = cfg.intermediate_size;
4154 let gate_proj = h_size * i_size / weight_pack_factor;
4155 let up_proj = h_size * i_size / weight_pack_factor;
4156 let down_proj = i_size * h_size / weight_pack_factor;
4157
4158 input_layernorm
4159 + post_attention_layernorm
4160 + q_proj
4161 + k_proj
4162 + v_proj
4163 + o_proj
4164 + gate_proj
4165 + up_proj
4166 + down_proj
4167 };
4168 Ok(vec![
4169 per_layer_elems * dtype.size_in_bytes();
4170 cfg.num_hidden_layers
4171 ])
4172 }
4173
4174 fn num_layers(&self, config: &str) -> Result<usize> {
4175 let cfg: Mistral3Config = serde_json::from_str(config)?;
4176 let cfg = &cfg.text_config;
4177 Ok(cfg.num_hidden_layers)
4178 }
4179
4180 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4181 let cfg: Mistral3Config = serde_json::from_str(config)?;
4182 let cfg = &cfg.text_config;
4183
4184 let cfg = ModelConfigMetadata {
4185 max_seq_len: cfg.max_position_embeddings,
4186 num_layers: cfg.num_hidden_layers,
4187 hidden_size: cfg.hidden_size,
4188 num_kv_heads: cfg.num_key_value_heads,
4189 num_attn_heads: cfg.num_attention_heads,
4190 sliding_window: cfg.sliding_window,
4191 k_head_dim: cfg.head_dim(),
4192 v_head_dim: cfg.head_dim(),
4193 };
4194
4195 Ok(Box::new(cfg))
4196 }
4197
4198 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4199 Some(vec![NonMappedSubModel::Vision])
4200 }
4201}
4202
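/// Loader for Llama 4 vision models (MoE text backbone plus a vision encoder).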
4203pub struct VLlama4Loader;
4209
4210pub struct VLlama4Prefixer;
4211
4212impl MultimodalPromptPrefixer for VLlama4Prefixer {
4213 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
4214 format!(
4215 "{}{prompt}",
4216 llama4::IMAGE_TOKEN.repeat(image_indexes.len())
4217 )
4218 }
4219}
4220
4221impl VisionModelLoader for VLlama4Loader {
4222 fn load(
4223 &self,
4224 config: &str,
4225 vb: ShardedVarBuilder,
4226 normal_loading_metadata: NormalLoadingMetadata,
4227 attention_mechanism: AttentionImplementation,
4228 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
4229 let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4230 Ok(Box::new(Llama4Model::new(
4231 &cfg,
4232 vb,
4233 self.is_gptx(config),
4234 normal_loading_metadata,
4235 attention_mechanism,
4236 )?))
4237 }
4238 fn is_gptx(&self, _config: &str) -> bool {
4239 false
4240 }
4241 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4242 let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4243 Ok(Box::new(cfg))
4244 }
4245 fn get_processor(
4246 &self,
4247 _model_config: &str,
4248 processor_config: Option<ProcessorConfig>,
4249 _preprocessor_config: PreProcessorConfig,
4250 _max_edge: Option<u32>,
4251 ) -> Arc<dyn Processor + Send + Sync> {
4252 Arc::new(Llama4Processor::new(&processor_config.unwrap()))
4253 }
4254 fn supports_paged_attention(&self, _config: &str) -> bool {
4255 true
4256 }
4257 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4258 Arc::new(VLlama4Prefixer)
4259 }
4260 fn modalities(&self, _config: &str) -> Result<Modalities> {
4261 Ok(Modalities {
4262 input: vec![SupportedModality::Text, SupportedModality::Vision],
4263 output: vec![SupportedModality::Text],
4264 })
4265 }
4266}
4267
4268impl IsqModelLoader for VLlama4Loader {
4269 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4270 Ok(vec![
4271 Regex::new(r"lm_head\.(weight|bias)$")?,
4272 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4274 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4275 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4276 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4277 Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$")?,
4279 Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$")?,
4280 Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$")?,
4281 Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$")?,
4282 Regex::new(r"layers\.(\d+)\.feed_forward\.router\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.down_proj\.(weight|bias)$")?,
4286 Regex::new(r"layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$")?,
4288 Regex::new(r"layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$")?,
4289 Regex::new(r"layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$")?,
4290 ])
4291 }
4292 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4293 Ok(vec![
4294 Regex::new(r"lm_head\.(weight|bias)$")?,
4295 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4297 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4298 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4299 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4300 Regex::new(
4302 r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_up_proj\.(weight|bias)$",
4303 )?,
4304 Regex::new(
4305 r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
4306 )?,
4307 Regex::new(
4308 r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.up_proj\.(weight|bias)$",
4309 )?,
4310 Regex::new(
4311 r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.down_proj\.(weight|bias)$",
4312 )?,
4313 Regex::new(
4314 r"language_model\.model\.layers\.(\d+)\.feed_forward\.router\.(weight|bias)$",
4315 )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.gate_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.down_proj\.(weight|bias)$",
            )?,
4325 Regex::new(
4327 r"language_model\.model\.layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$",
4328 )?,
4329 Regex::new(
4330 r"language_model\.model\.layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$",
4331 )?,
4332 Regex::new(
4333 r"language_model\.model\.layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$",
4334 )?,
4335 ])
4336 }
4337}
4338
4339impl VLlama4Loader {
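    /// Run the Llama 4 image processor on a dummy RGB image of the given size to
    /// estimate how many image chunks and image-placeholder tokens it would produce.
    /// Used only for automatic device mapping.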
4340 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
4343 fn run_dummy_processing(
4344 &self,
4345 cfg: &Llama4Config,
4346 height: usize,
4347 width: usize,
4348 max_num_images: usize,
4349 max_batch_size: usize,
4350 ) -> Result<(usize, usize)> {
4351 let cfg = &cfg.vision_config;
4352
4353 let img_processor =
4354 Llama4ImageProcessor::new(Some(cfg.patch_size), Some(cfg.pixel_shuffle_ratio));
4355 let image = DynamicImage::new(width as u32, height as u32, ColorType::Rgb8);
4356 let res = img_processor.preprocess(
4357 vec![image; max_num_images],
4358 vec![],
4359 &PreProcessorConfig::default(),
4360 &Device::Cpu,
4361 (max_batch_size, max_num_images),
4362 )?;
4363
4364 let pixels_batch_size = res.pixel_values.dim(0)?;
4365 let pixels_max_batch_size = pixels_batch_size * max_batch_size;
4366
4367 let (image_h, image_w) = (
4368 res.pixel_values.dim(D::Minus2).unwrap(),
4369 res.pixel_values.dim(D::Minus1).unwrap(),
4370 );
4371 let num_patches_per_chunk = (image_h / img_processor.patch_size)
4372 * (image_w / img_processor.patch_size)
4373 / img_processor.downsample_ratio;
4374
4375 Ok((
4376 pixels_max_batch_size,
4377 num_patches_per_chunk * pixels_max_batch_size,
4378 ))
4379 }
4380}
4381
4382impl DeviceMappedModelLoader for VLlama4Loader {
4383 fn mapped_max_act_size_elems(
4384 &self,
4385 config: &str,
4386 params: &AutoDeviceMapParams,
4387 ) -> Result<usize> {
4388 let AutoDeviceMapParams::Vision {
4389 max_seq_len,
4390 max_batch_size,
4391 max_image_shape: (height, width),
4392 max_num_images,
4393 } = params
4394 else {
4395 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4396 };
4397
4398 let cfg: Llama4Config = serde_json::from_str(config)?;
4399
4400 let (_pixels_batch_size, num_text_image_toks) =
4401 self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4402
4403 let max_seq_len = max_seq_len.min(&ATTENTION_CHUNK_SIZE) + num_text_image_toks;
4404
4405 Ok(max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len)
4406 }
4407 fn non_mapped_max_act_size_elems(
4408 &self,
4409 config: &str,
4410 params: &AutoDeviceMapParams,
4411 ) -> Result<usize> {
4412 let AutoDeviceMapParams::Vision {
4413 max_seq_len: _,
4414 max_batch_size,
4415 max_image_shape: (height, width),
4416 max_num_images,
4417 } = params
4418 else {
4419 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4420 };
4421
4422 let cfg: Llama4Config = serde_json::from_str(config)?;
4423
4424 let (pixels_batch_size, _num_text_image_toks) =
4425 self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4426 let max_seq_len = cfg.vision_config.num_patches();
4427
4428 Ok((max_batch_size * pixels_batch_size)
4429 * cfg.vision_config.num_attention_heads
4430 * max_seq_len
4431 * max_seq_len)
4432 }
4433
4434 fn non_mapped_size_in_bytes(
4435 &self,
4436 config: &str,
4437 dtype: DType,
4438 weight_pack_factor: usize,
4439 _matformer_config: Option<&MatformerSliceConfig>,
4440 ) -> Result<usize> {
4441 let cfg: Llama4Config = serde_json::from_str(config)?;
4442 let tcfg = &cfg.text_config;
4443
4444 let text_elems = {
4445 let embed_tokens = tcfg.hidden_size * tcfg.vocab_size / weight_pack_factor;
4446 let lm_head = if !tcfg.tie_word_embeddings {
4447 tcfg.hidden_size * tcfg.vocab_size
4448 } else {
4449 0
4450 };
4451 let norm = tcfg.hidden_size;
4452 embed_tokens + lm_head + norm
4453 };
4454
4455 let vision_elems = {
4456 let cfg = &cfg.vision_config;
4457
4458 let num_patches = cfg.num_patches();
4459
4460 let unfold_elems =
4461 (cfg.num_channels * cfg.patch_size * cfg.patch_size) * cfg.hidden_size;
4462 let class_embeddng_elems = cfg.hidden_size;
4463 let positional_embedding_vlm_elems = num_patches * cfg.hidden_size;
4464 let layernorm_pre_elems = cfg.hidden_size;
4465 let layernorm_post_elems = cfg.hidden_size;
4466
4467 let pixel_shuffle_elems = cfg.intermediate_size * cfg.projector_input_dim
4468 / weight_pack_factor
4469 + cfg.projector_input_dim * cfg.projector_output_dim / weight_pack_factor;
4470
4471 let encoder_layer = {
4472 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
4473 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
4474
4475 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
4476 let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4477 / weight_pack_factor
4478 + cfg.num_attention_heads * head_dim;
4479 let k_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4480 / weight_pack_factor
4481 + cfg.num_attention_heads * head_dim;
4482 let v_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4483 / weight_pack_factor
4484 + cfg.num_attention_heads * head_dim;
4485 let o_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4486 / weight_pack_factor
4487 + cfg.num_attention_heads * head_dim;
4488
4489 let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
4490 + cfg.intermediate_size;
4491 let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
4492 + cfg.hidden_size;
4493
4494 input_layernorm
4495 + post_attention_layernorm
4496 + q_proj
4497 + k_proj
4498 + v_proj
4499 + o_proj
4500 + fc1
4501 + fc2
4502 };
4503
4504 unfold_elems
4505 + class_embeddng_elems
4506 + positional_embedding_vlm_elems
4507 + layernorm_post_elems
4508 + layernorm_pre_elems
4509 + pixel_shuffle_elems
4510 + encoder_layer * cfg.num_hidden_layers
4511 };
4512
4513 let elems = text_elems + vision_elems;
4514
4515 Ok(elems * dtype.size_in_bytes())
4516 }
4517
4518 fn layer_sizes_in_bytes(
4519 &self,
4520 config: &str,
4521 dtype: DType,
4522 weight_pack_factor: usize,
4523 _matformer_config: Option<&MatformerSliceConfig>,
4524 ) -> Result<Vec<usize>> {
4525 let cfg: Llama4Config = serde_json::from_str(config)?;
4526 let tcfg = &cfg.text_config;
4527
4528 let mut per_layer_elems = Vec::new();
4529
4530 for layer_idx in 0..tcfg.num_hidden_layers {
4531 let input_layernorm = tcfg.hidden_size;
4532 let post_attention_layernorm = tcfg.hidden_size;
4533
4534 let size_in = tcfg.hidden_size;
4535 let size_q = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_attention_heads;
4536 let size_kv = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_key_value_heads;
4537 let q_proj = size_in * size_q / weight_pack_factor;
4538 let k_proj = size_in * size_kv / weight_pack_factor;
4539 let v_proj = size_in * size_kv / weight_pack_factor;
4540 let o_proj = size_q * size_in / weight_pack_factor;
4541
4542 let use_moe = tcfg.moe_layers().contains(&layer_idx);
4543 let moe_block = if use_moe {
4544 let h_size = tcfg.hidden_size;
4545 let i_size = tcfg.intermediate_size;
4546 let gate_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4547 let up_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4548 let down_proj = tcfg.num_local_experts * i_size * h_size / weight_pack_factor;
4549
4550 gate_proj + up_proj + down_proj
4551 } else {
4552 let h_size = tcfg.hidden_size;
4553 let i_size = tcfg.intermediate_size_mlp;
4554 let gate_proj = h_size * i_size / weight_pack_factor;
4555 let up_proj = h_size * i_size / weight_pack_factor;
4556 let down_proj = i_size * h_size / weight_pack_factor;
4557
4558 gate_proj + up_proj + down_proj
4559 };
4560
4561 per_layer_elems.push(
4562 input_layernorm
4563 + post_attention_layernorm
4564 + q_proj
4565 + k_proj
4566 + v_proj
4567 + o_proj
4568 + moe_block,
4569 );
4570 }
4571
4572 Ok(per_layer_elems
4573 .into_iter()
4574 .map(|x| x * dtype.size_in_bytes())
4575 .collect())
4576 }
4577
4578 fn num_layers(&self, config: &str) -> Result<usize> {
4579 let cfg: Llama4Config = serde_json::from_str(config)?;
4580 Ok(cfg.text_config.num_hidden_layers)
4581 }
4582
4583 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4584 let cfg: Llama4Config = serde_json::from_str(config)?;
4585 let cfg = &cfg.text_config;
4586
4587 let cfg = ModelConfigMetadata {
4588 max_seq_len: cfg.max_position_embeddings,
4589 num_layers: cfg.num_hidden_layers,
4590 hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
4592 num_attn_heads: cfg.num_attention_heads,
4593 sliding_window: None,
4594 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4595 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4596 };
4597
4598 Ok(Box::new(cfg))
4599 }
4600
4601 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4602 Some(vec![NonMappedSubModel::Vision])
4603 }
4604}
4605
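/// Loader for Gemma 3n models, which accept text, vision, and audio inputs.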
4606pub struct Gemma3nLoader;
4612
4613#[allow(dead_code)]
4614pub struct Gemma3nPrefixer;
4615
4616impl MultimodalPromptPrefixer for Gemma3nPrefixer {
4617 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
4618 prompt.to_string()
4619 }
4620}
4621
4622impl VisionModelLoader for Gemma3nLoader {
4623 fn load(
4624 &self,
4625 config: &str,
4626 vb: ShardedVarBuilder,
4627 normal_loading_metadata: NormalLoadingMetadata,
4628 attention_mechanism: AttentionImplementation,
4629 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
4630 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4631 Ok(Box::new(Gemma3nModel::new(
4632 &cfg,
4633 vb,
4634 self.is_gptx(config),
4635 normal_loading_metadata,
4636 attention_mechanism,
4637 )?))
4638 }
4639 fn is_gptx(&self, _config: &str) -> bool {
4640 true
4641 }
4642 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4643 let config: Gemma3nConfig = serde_json::from_str(config)?;
4644 Ok(Box::new(config))
4645 }
4646 fn get_processor(
4647 &self,
4648 _config: &str,
4649 processor_config: Option<ProcessorConfig>,
4650 _preprocessor_config: PreProcessorConfig,
4651 _max_edge: Option<u32>,
4652 ) -> Arc<dyn Processor + Send + Sync> {
4653 Arc::new(Gemma3nProcessor::new(
4655 processor_config.unwrap_or_default(),
4656 true,
4657 ))
4658 }
4659 fn supports_paged_attention(&self, _config: &str) -> bool {
4660 false
4661 }
4662 fn supports_prefix_cacher(&self, _config: &str) -> bool {
4663 true
4664 }
4665 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4666 Arc::new(Gemma3Prefixer)
4667 }
4668 fn modalities(&self, _config: &str) -> Result<Modalities> {
4669 Ok(Modalities {
4670 input: vec![
4671 SupportedModality::Text,
4672 SupportedModality::Vision,
4673 SupportedModality::Audio,
4674 ],
4675 output: vec![SupportedModality::Text],
4676 })
4677 }
4678}
4679
4680impl IsqModelLoader for Gemma3nLoader {
4681 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4682 Ok(vec![
4683 Regex::new(r"lm_head\.(weight|bias)$")?,
4684 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4686 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4687 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4688 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4689 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4691 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4692 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4693 Regex::new(r"conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$")?,
4695 Regex::new(r"conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$")?,
4696 Regex::new(r"conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$")?,
4697 Regex::new(
4698 r"conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4699 )?,
4700 Regex::new(r"conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4701 Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$")?,
4703 Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$")?,
4704 Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$")?,
4705 Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$")?,
4706 Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$")?,
4708 Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$")?,
4709 Regex::new(r"subsample_conv_projection\.input_proj_linear\.(weight|bias)$")?,
4711 Regex::new(r"embed_vision\.embedding_projection\.(weight|bias)$")?,
4713 Regex::new(r"embed_audio\.embedding_projection\.(weight|bias)$")?,
4714 ])
4715 }
4716 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4717 Ok(vec![
4718 Regex::new(r"lm_head\.(weight|bias)$")?,
4719 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4721 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4722 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4723 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4724 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4726 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4727 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4728 Regex::new(r"model\.language_model\.per_layer_model_projection\.(weight|bias)$")?,
4730 Regex::new(r"model\.language_model\.altup_projections\.(\d+)\.(weight|bias)$")?,
4731 Regex::new(r"model\.language_model\.altup_unembed_projections\.(\d+)\.(weight|bias)$")?,
4732 Regex::new(
4734 r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$",
4735 )?,
4736 Regex::new(
4737 r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$",
4738 )?,
4739 Regex::new(
4740 r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$",
4741 )?,
4742 Regex::new(
4743 r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4744 )?,
4745 Regex::new(r"model\.audio_tower\.conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4746 Regex::new(
4748 r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$",
4749 )?,
4750 Regex::new(
4751 r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$",
4752 )?,
4753 Regex::new(
4754 r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$",
4755 )?,
4756 Regex::new(
4757 r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$",
4758 )?,
4759 Regex::new(
4761 r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$",
4762 )?,
4763 Regex::new(
4764 r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$",
4765 )?,
4766 Regex::new(
4768 r"model\.audio_tower\.subsample_conv_projection\.input_proj_linear\.(weight|bias)$",
4769 )?,
4770 Regex::new(r"model\.embed_vision\.embedding_projection\.(weight|bias)$")?,
4772 Regex::new(r"model\.embed_audio\.embedding_projection\.(weight|bias)$")?,
4773 ])
4774 }
4775}
4776
4777impl DeviceMappedModelLoader for Gemma3nLoader {
4778 fn mapped_max_act_size_elems(
4779 &self,
4780 config: &str,
4781 params: &AutoDeviceMapParams,
4782 ) -> Result<usize> {
4783 let AutoDeviceMapParams::Vision {
4784 max_seq_len,
4785 max_batch_size,
4786 max_image_shape: _,
4787 max_num_images,
4788 } = params
4789 else {
4790 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4791 };
4792
4793 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4794 let text_cfg = &cfg.text_config;
4795
4796 let mut total_seq_len = *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
4800
4801 {
            // MSFA output is a 16x16 grid (assumed), i.e. 256 soft vision tokens per image.
            let msfa_spatial_size = 16;
            let vision_tokens_per_image = msfa_spatial_size * msfa_spatial_size;
            total_seq_len += vision_tokens_per_image * max_num_images;
4808 }
4809
4810 {
4812 let audio_tokens = cfg.audio_soft_tokens_per_image;
4815 total_seq_len += audio_tokens;
4816 }
4817
4818 let max_text_attn =
4820 max_batch_size * text_cfg.num_attention_heads * total_seq_len * total_seq_len;
4821
4822 Ok(max_text_attn)
4823 }
4824
4825 fn non_mapped_max_act_size_elems(
4826 &self,
4827 config: &str,
4828 params: &AutoDeviceMapParams,
4829 ) -> Result<usize> {
4830 let AutoDeviceMapParams::Vision {
4831 max_seq_len: _,
4832 max_batch_size,
4833 max_image_shape: _,
4834 max_num_images,
4835 } = params
4836 else {
4837 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4838 };
4839
4840 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4841
4842 let mut max_activation = 0;
4844
4845 {
4847 let vision_tower_act = {
                // Assumed worst case for the vision tower attention: 16 heads over a
                // 24x24 spatial grid.
                let num_heads = 16;
                let spatial_size = 24;
                let seq_len = spatial_size * spatial_size;
4863
4864 max_batch_size * max_num_images * num_heads * seq_len * seq_len
4866 };
4867
4868 let vision_embed_act = {
                // MSFA output features (channel count and spatial size are assumed
                // defaults) before projection to the text hidden size.
                let msfa_channels = 2048;
                let spatial_size = 16;
                let vision_features =
4874 max_batch_size * max_num_images * msfa_channels * spatial_size * spatial_size;
4875
4876 let projected = max_batch_size
4878 * max_num_images
4879 * spatial_size
4880 * spatial_size
4881 * cfg.text_config.hidden_size;
4882
4883 vision_features.max(projected)
4884 };
4885
4886 max_activation = max_activation.max(vision_tower_act).max(vision_embed_act);
4887 }
4888
4889 {
4891 let audio_cfg = &cfg.audio_config;
4892
4893 let max_audio_frames = 1280;
4898
4899 let subsample_factor: usize = audio_cfg
4900 .sscp_conv_stride_size
4901 .iter()
                .map(|stride| stride[0])
                .product();
4904 let audio_seq_after_subsample = max_audio_frames / subsample_factor;
4905
4906 let audio_encoder_act = {
                // Largest conformer activation: the feed-forward intermediate tensor.
                let intermediate_size = audio_cfg.hidden_size * 4;
                max_batch_size * audio_seq_after_subsample * intermediate_size
4913 };
4914
4915 let audio_attn_act = {
4917 let chunk_size = audio_cfg.conf_attention_chunk_size;
4919 let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
4920 + audio_cfg.conf_attention_context_right;
4921
4922 let num_chunks = audio_seq_after_subsample.div_ceil(chunk_size);
4924
4925 max_batch_size
4926 * audio_cfg.conf_num_attention_heads
4927 * num_chunks
4928 * chunk_size
4929 * context_size
4930 };
4931
4932 max_activation = max_activation.max(audio_encoder_act).max(audio_attn_act);
4933 }
4934
4935 Ok(max_activation)
4936 }
4937
4938 fn non_mapped_size_in_bytes(
4939 &self,
4940 config: &str,
4941 dtype: DType,
4942 weight_pack_factor: usize,
4943 matformer_config: Option<&MatformerSliceConfig>,
4944 ) -> Result<usize> {
4945 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4946
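        // When a Matformer slice is configured, size against the sliced text config so
        // the parameter counts below match what will actually be loaded.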
4947 let text_cfg = if let Some(matformer_cfg) = matformer_config {
4949 use crate::device_map::DummyDeviceMapper;
4950 use crate::vision_models::gemma3n::text::handle_matformer_slicing;
4951
4952 let dummy_mapper = DummyDeviceMapper {
4953 nm_device: Device::Cpu,
4954 };
4955 let (adjusted_cfg, _, _, _, _) = handle_matformer_slicing(
4956 &cfg.text_config,
4957 &Some(matformer_cfg.clone()),
4958 &dummy_mapper,
4959 )?;
4960 adjusted_cfg
4961 } else {
4962 cfg.text_config.clone()
4963 };
4964
4965 let text_cfg = &text_cfg;
4966
4967 let text_elems = {
4969 let embed_tokens = text_cfg.hidden_size * text_cfg.vocab_size;
4971 let embed_tokens_per_layer = text_cfg.num_hidden_layers
4972 * text_cfg.hidden_size_per_layer_input
4973 * text_cfg.vocab_size_per_layer_input;
4974
4975 let lm_head = if !text_cfg.tie_word_embeddings || weight_pack_factor != 1 {
4977 text_cfg.hidden_size * text_cfg.vocab_size / weight_pack_factor
4978 } else {
4979 0
4980 };
4981
4982 let norm = text_cfg.hidden_size;
4984
4985 let altup_projections =
4987 (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
4988 / weight_pack_factor;
4989 let altup_unembed_projections =
4990 (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
4991 / weight_pack_factor;
4992
4993 let per_layer_model_projection = text_cfg.num_hidden_layers
4995 * text_cfg.hidden_size
4996 * text_cfg.hidden_size_per_layer_input
4997 / weight_pack_factor;
4998 let per_layer_projection_norm = text_cfg.hidden_size;
4999
5000 embed_tokens
5001 + embed_tokens_per_layer
5002 + lm_head
5003 + norm
5004 + altup_projections
5005 + altup_unembed_projections
5006 + per_layer_model_projection
5007 + per_layer_projection_norm
5008 };
5009
5010 let vision_elems = {
5012 let vision_cfg = &cfg.vision_config;
5013 let vision_tower_elems = {
5017 use crate::vision_models::gemma3n::vision::{
5018 gemma3n_mobilenet_def, make_divisible, BlockType, INPUT_CHANNELS,
5019 MSFA_EXPANSION_RATIO, MSFA_IN_CHANNELS, MSFA_OUT_CHANNELS, STEM_KERNEL_SIZE,
5020 STEM_OUT_CHANNELS,
5021 };
5022
5023 let stem_conv =
5025 INPUT_CHANNELS * STEM_OUT_CHANNELS * STEM_KERNEL_SIZE * STEM_KERNEL_SIZE;
                let stem_norm = STEM_OUT_CHANNELS;

                let mut in_chs = STEM_OUT_CHANNELS;
5030 let mut total_elems = stem_conv + stem_norm;
5031
5032 let block_defs = gemma3n_mobilenet_def();
5034
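                // Walk the MobileNet-style block definitions and accumulate parameter
                // counts per block type.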
5035 for stage_blocks in block_defs.iter() {
5036 for block_type in stage_blocks.iter() {
5037 match block_type {
5038 BlockType::EdgeResidual {
5039 out_channels,
5040 kernel_size,
5041 stride: _,
5042 expand_ratio,
5043 ..
5044 } => {
5045 #[allow(clippy::cast_precision_loss)]
5046 let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
                                // EdgeResidual: fused expansion conv (k x k) + norm, then
                                // pointwise projection + norm.
                                total_elems += in_chs * mid_chs * kernel_size * kernel_size;
                                total_elems += mid_chs;
                                total_elems += mid_chs * out_channels;
                                total_elems += out_channels;
                                in_chs = *out_channels;
5053 }
5054 BlockType::UniversalInvertedResidual {
5055 out_channels,
5056 start_kernel_size,
5057 mid_kernel_size,
5058 stride: _,
5059 expand_ratio,
5060 ..
5061 } => {
5062 #[allow(clippy::cast_precision_loss)]
5063 let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
                                // Universal inverted residual: optional pointwise expansion,
                                // optional start/mid depthwise convs, pointwise projection, and norms.
                                if *expand_ratio != 1.0 {
                                    total_elems += in_chs * mid_chs;
                                    total_elems += mid_chs;
                                }
                                if *start_kernel_size > 0 {
                                    total_elems += mid_chs * start_kernel_size * start_kernel_size;
                                    total_elems += mid_chs;
                                }
                                if *mid_kernel_size > 0 {
                                    total_elems += mid_chs * mid_kernel_size * mid_kernel_size;
                                    total_elems += mid_chs;
                                }
                                total_elems += mid_chs * out_channels;
                                total_elems += out_channels;
                                total_elems += out_channels;
                                in_chs = *out_channels;
5081 }
5082 BlockType::MultiQueryAttention {
5083 num_heads,
5084 kv_dim,
5085 kv_stride: _,
5086 ..
5087 } => {
                                // Mobile multi-query attention: norms, query/key/value projections,
                                // a depthwise downsampling conv (kernel size 3 assumed), and the
                                // output projection.
                                let dw_kernel_size = 3;
                                total_elems += in_chs;
                                total_elems += in_chs * num_heads * kv_dim;
                                total_elems += in_chs * kv_dim;
                                total_elems += in_chs * dw_kernel_size * dw_kernel_size;
                                total_elems += *kv_dim;
                                total_elems += 1;
                                total_elems += *kv_dim;
                                total_elems += num_heads * kv_dim * in_chs;
                                total_elems += in_chs;
                            }
5100 }
5101 }
5102 }
5103
5104 let msfa_in = MSFA_IN_CHANNELS.iter().sum::<usize>();
5106 let msfa_out = MSFA_OUT_CHANNELS;
5107 #[allow(clippy::cast_precision_loss)]
5108 let msfa_mid = make_divisible(msfa_in as f64 * MSFA_EXPANSION_RATIO, 8);
5109
                // Multi-scale fusion adapter: expansion and projection 1x1 convs with norms.
                total_elems += msfa_in * msfa_mid;
                total_elems += msfa_mid;
                total_elems += msfa_mid * msfa_out;
                total_elems += msfa_out;
                total_elems += msfa_out;

                total_elems
5118 };
5119
5120 let embed_vision_elems = {
5122 let embedding = vision_cfg.vocab_size * vision_cfg.hidden_size;
5124
5125 let hard_norm = vision_cfg.hidden_size;
5127 let soft_norm = vision_cfg.hidden_size;
5128
5129 let projection = vision_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5131
5132 let post_norm = text_cfg.hidden_size;
5134
5135 embedding + hard_norm + soft_norm + projection + post_norm
5136 };
5137
5138 vision_tower_elems + embed_vision_elems
5139 };
5140
5141 let audio_elems = {
5143 let audio_cfg = &cfg.audio_config;
5144
5145 let subsample_conv_projection_elems = {
5147 let mut conv_elems = 0;
5149
5150 let in_ch_0 = 1;
5152 let out_ch_0 = audio_cfg.sscp_conv_channel_size[0];
5153 let kernel_0 = &audio_cfg.sscp_conv_kernel_size[0];
5154 conv_elems += in_ch_0 * out_ch_0 * kernel_0[0] * kernel_0[1];
5155
5156 let in_ch_1 = out_ch_0;
5158 let out_ch_1 = audio_cfg.sscp_conv_channel_size[1];
5159 let kernel_1 = &audio_cfg.sscp_conv_kernel_size[1];
5160 conv_elems += in_ch_1 * out_ch_1 * kernel_1[0] * kernel_1[1];
5161
                let norm_0 = out_ch_0;
                let norm_1 = out_ch_1;

                // Frequency-dimension size after the two subsampling convs (the loop below
                // assumes padding of 1 on each side).
                let mut f_out = audio_cfg.input_feat_size;
5168 for i in 0..2 {
5169 let kernel_w = audio_cfg.sscp_conv_kernel_size[i][1];
5170 let stride_w = audio_cfg.sscp_conv_stride_size[i][1];
5171 let pad_left = 1;
5172 let pad_right = 1;
5173 f_out = (f_out + pad_left + pad_right + stride_w - kernel_w) / stride_w;
5174 }
5175 let input_proj_in_features = out_ch_1 * f_out;
5176 let input_proj_linear =
5177 input_proj_in_features * audio_cfg.hidden_size / weight_pack_factor;
5178
5179 conv_elems + norm_0 + norm_1 + input_proj_linear
5180 };
5181
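            // Per-layer conformer parameters: attention, the two feed-forward modules,
            // the lightweight conv module, and the final block norm.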
5182 let conformer_elems = {
5184 let mut total = 0;
5185
5186 for _ in 0..audio_cfg.conf_num_hidden_layers {
5187 let attention_elems = {
5189 let pre_attn_norm = audio_cfg.hidden_size;
5191 let post_norm = audio_cfg.hidden_size;
5192
5193 let q_proj =
5195 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5196 let k_proj =
5197 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5198 let v_proj =
5199 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5200 let post =
5201 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5202
5203 let pos_proj =
5205 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5206 let per_dim_scale =
                            audio_cfg.hidden_size / audio_cfg.conf_num_attention_heads;
                        let inv_timescales = audio_cfg.hidden_size / 2;
                        let pos_indices = audio_cfg.conf_attention_context_left
5210 + audio_cfg.conf_attention_context_right
5211 + 1;
5212
5213 let chunk_size = audio_cfg.conf_attention_chunk_size;
5215 let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
5216 + audio_cfg.conf_attention_context_right;
                        let local_causal_valid_mask = chunk_size * context_size;
                        let invalid_logits_tensor = 1;

                        pre_attn_norm
5221 + post_norm
5222 + q_proj
5223 + k_proj
5224 + v_proj
5225 + post
5226 + pos_proj
5227 + per_dim_scale
5228 + inv_timescales
5229 + pos_indices
5230 + local_causal_valid_mask
5231 + invalid_logits_tensor
5232 };
5233
5234 let ffw_elems = {
                        // Assumed feed-forward expansion factor of 4x the hidden size.
                        let intermediate_size = audio_cfg.hidden_size * 4;
5242
5243 let ffw_start = {
5244 let pre_norm = audio_cfg.hidden_size;
5245 let layer_1 =
5246 audio_cfg.hidden_size * intermediate_size / weight_pack_factor;
5247 let layer_2 =
5248 intermediate_size * audio_cfg.hidden_size / weight_pack_factor;
5249 let post_norm = audio_cfg.hidden_size;
5250 pre_norm + layer_1 + layer_2 + post_norm
5251 };
5252
                        // ffw_layer_end has the same shape as ffw_layer_start.
                        let ffw_end = ffw_start;

                        ffw_start + ffw_end
5256 };
5257
5258 let lconv1d_elems = {
5260 let pre_layer_norm = audio_cfg.hidden_size;
5262 let conv_norm = audio_cfg.hidden_size;
5263
5264 let linear_start = audio_cfg.hidden_size * (audio_cfg.hidden_size * 2)
5266 / weight_pack_factor;
5267 let linear_end =
5268 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5269
5270 let depthwise = audio_cfg.hidden_size * audio_cfg.conf_conv_kernel_size;
5272
5273 pre_layer_norm + conv_norm + linear_start + linear_end + depthwise
5274 };
5275
5276 let block_norm = audio_cfg.hidden_size;
5278
5279 total += attention_elems + ffw_elems + lconv1d_elems + block_norm;
5280 }
5281
5282 total
5283 };
5284
5285 let embed_audio_elems = {
5287 let embedding = audio_cfg.vocab_size * audio_cfg.hidden_size;
5289
                let hard_embedding_norm = audio_cfg.hidden_size;
                let soft_embedding_norm = audio_cfg.hidden_size;
                let embedding_post_projection_norm = text_cfg.hidden_size;
                let embedding_projection =
5297 audio_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5298
5299 embedding
5300 + hard_embedding_norm
5301 + soft_embedding_norm
5302 + embedding_post_projection_norm
5303 + embedding_projection
5304 };
5305
5306 subsample_conv_projection_elems + conformer_elems + embed_audio_elems
5307 };
5308
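        // The vision tower is sized in F32 when F16 is requested (assumed: its weights
        // are not loaded in F16).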
5309 let vision_dtype = if dtype == DType::F16 {
5310 DType::F32
5312 } else {
5313 dtype
5314 };
5315
5316 let total_elems = text_elems * dtype.size_in_bytes()
5317 + vision_elems * vision_dtype.size_in_bytes()
5318 + audio_elems * dtype.size_in_bytes();
5319
5320 Ok(total_elems)
5321 }
5322
5323 fn layer_sizes_in_bytes(
5324 &self,
5325 config: &str,
5326 dtype: DType,
5327 weight_pack_factor: usize,
5328 matformer_config: Option<&MatformerSliceConfig>,
5329 ) -> Result<Vec<usize>> {
5330 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5331
5332 let (text_cfg, _layer_rename_map, _layers_skipped) = if let Some(matformer_cfg) =
5334 matformer_config
5335 {
5336 use crate::device_map::DummyDeviceMapper;
5337 use crate::vision_models::gemma3n::text::handle_matformer_slicing;
5338
5339 let dummy_mapper = DummyDeviceMapper {
5340 nm_device: Device::Cpu,
5341 };
5342 let (adjusted_cfg, _, _, layer_rename_map, layers_skipped) = handle_matformer_slicing(
5343 &cfg.text_config,
5344 &Some(matformer_cfg.clone()),
5345 &dummy_mapper,
5346 )?;
5347 (adjusted_cfg, layer_rename_map, layers_skipped)
5348 } else {
5349 (cfg.text_config.clone(), None, None)
5350 };
5351
5352 let text_cfg = &text_cfg;
5353
5354 let mut layer_sizes = Vec::new();
5356
5357 for layer_idx in 0..text_cfg.num_hidden_layers {
5361 let per_layer_elems = {
5362 let input_layernorm = text_cfg.hidden_size;
5364 let post_attention_layernorm = text_cfg.hidden_size;
5365 let pre_feedforward_layernorm = text_cfg.hidden_size;
5366 let post_feedforward_layernorm = text_cfg.hidden_size;
5367 let post_per_layer_input_norm = text_cfg.hidden_size;
5368
5369 let size_in = text_cfg.hidden_size;
5371 let size_q = text_cfg.num_attention_heads * text_cfg.head_dim;
5372 let size_kv = text_cfg.num_key_value_heads * text_cfg.head_dim;
5373
5374 let q_proj = size_in * size_q / weight_pack_factor;
5375 let k_proj = size_in * size_kv / weight_pack_factor;
5376 let v_proj = size_in * size_kv / weight_pack_factor;
5377 let o_proj = size_q * size_in / weight_pack_factor;
5378
5379 let q_norm = text_cfg.head_dim;
5381 let k_norm = text_cfg.head_dim;
                let v_norm = text_cfg.head_dim;

                let intermediate_size = match &text_cfg.intermediate_size {
5386 IntermediateSize::Single(size) => *size,
5387 IntermediateSize::PerLayer(sizes) => sizes[layer_idx],
5388 IntermediateSize::Matformer(sizes, _) => sizes[layer_idx],
5389 };
5390 let gate_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
5391 let up_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
5392 let down_proj = intermediate_size * text_cfg.hidden_size / weight_pack_factor;
5393
5394 let altup_elems = {
5396 let correct_output_scale = text_cfg.hidden_size;
5397 let correction_coefs = text_cfg.altup_num_inputs * text_cfg.altup_num_inputs;
5398 let prediction_coefs =
5399 text_cfg.altup_num_inputs * text_cfg.altup_num_inputs.pow(2);
5400 let modality_router = text_cfg.hidden_size * text_cfg.altup_num_inputs;
5401 let router_norm = text_cfg.hidden_size;
5402
5403 correct_output_scale
5404 + correction_coefs
5405 + prediction_coefs
5406 + modality_router
5407 + router_norm
5408 };
5409
5410 let laurel_elems = {
5412 let left = text_cfg.hidden_size * text_cfg.laurel_rank;
5413 let right = text_cfg.laurel_rank * text_cfg.hidden_size;
5414 let post_norm = text_cfg.hidden_size;
5415
5416 left + right + post_norm
5417 };
5418
5419 let per_layer_input_gate =
5421 text_cfg.hidden_size * text_cfg.hidden_size_per_layer_input;
5422 let per_layer_projection =
5423 text_cfg.hidden_size_per_layer_input * text_cfg.hidden_size;
5424
5425 input_layernorm
5426 + post_attention_layernorm
5427 + pre_feedforward_layernorm
5428 + post_feedforward_layernorm
5429 + post_per_layer_input_norm
5430 + q_proj
5431 + k_proj
5432 + v_proj
5433 + o_proj
5434 + q_norm
5435 + k_norm
5436 + v_norm
5437 + gate_proj
5438 + up_proj
5439 + down_proj
5440 + altup_elems
5441 + laurel_elems
5442 + per_layer_input_gate
5443 + per_layer_projection
5444 };
5445
5446 layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
5447 }
5448
5449 Ok(layer_sizes)
5450 }
5451
5452 fn num_layers(&self, config: &str) -> Result<usize> {
5453 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5454 Ok(cfg.text_config.num_hidden_layers)
5455 }
5456
5457 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
5458 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5459 let cfg = cfg.text_config;
5460
5461 let cfg = ModelConfigMetadata {
5462 max_seq_len: cfg.max_position_embeddings,
5463 num_layers: cfg.num_hidden_layers,
5464 hidden_size: cfg.hidden_size,
5465 num_kv_heads: cfg.num_key_value_heads,
5466 num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5469 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5470 };
5471
5472 Ok(Box::new(cfg))
5473 }
5474
5475 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
5476 Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
5477 }
5478}