use std::any::Any;
use std::sync::Arc;
use std::{fmt::Debug, str::FromStr};

use anyhow::Result;
use candle_core::{DType, Device, Tensor, D};
use candle_nn::Conv2dConfig;
use image::{ColorType, DynamicImage};
use itertools::Itertools;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::ShardedVarBuilder;

#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};

use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
use crate::amoe::AnyMoeBaseModelMixin;
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::DeviceMapper;
use crate::layers::Conv3dConfig;
use crate::matformer::MatformerSliceConfig;
use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
use crate::pipeline::isq::IsqModelLoader;
use crate::pipeline::loaders::AutoDeviceMapParams;
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::{
    EitherCache, IsqModel, Modalities, MultimodalPromptPrefixer, Processor, ProcessorCreator,
    SupportedModality,
};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::vision_models::clip::ClipConfig;
use crate::vision_models::gemma3::config::Gemma3Config;
use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
use crate::vision_models::gemma3n::config::{Gemma3nConfig, IntermediateSize};
use crate::vision_models::gemma3n::{Gemma3nModel, Gemma3nProcessor};
use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
use crate::vision_models::idefics2_input_processor::Idefics2Processor;
use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
use crate::vision_models::image_processor::ImagePreProcessor;
use crate::vision_models::inputs_processor::Phi4MMProcessor;
use crate::vision_models::llama4::{
    self, Llama4Config, Llama4ImageProcessor, Llama4Model, Llama4Processor,
};
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::llava15::Model as LLaVA;
use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
use crate::vision_models::llava_next::Model as LLaVANext;
use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
use crate::vision_models::phi3_inputs_processor::Phi3Processor;
use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::qwen2_5_vl::{
    Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
};
use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
use crate::vision_models::qwen3_vl::{Config as Qwen3VLConfig, Qwen3VLModel, Qwen3VLProcessor};
use crate::vision_models::{minicpmo, phi4};
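
/// The runtime interface every loaded vision model exposes to the pipeline: a single
/// multimodal `forward` pass plus accessors for its cache, device, maximum sequence
/// length, and configuration metadata.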
pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
}
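
/// Per-architecture factory: parses the raw `config.json` text, constructs the model,
/// and supplies its processor, prompt prefixer, and capability flags.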
pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> bool;
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_processor(
        &self,
        model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync>;
    fn supports_paged_attention(&self, config: &str) -> bool;
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        false
    }
    fn modalities(&self, config: &str) -> Result<Modalities>;
    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
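            // For example, "model.layers.12.self_attn.q_proj.weight" is routed to
            // `DeviceForLoadTensor::Idx(12)`; names without a `.layers.N.` component
            // fall back to the base device.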
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}

#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub enum VisionLoaderType {
    #[serde(rename = "phi3v")]
    Phi3V,
    #[serde(rename = "idefics2")]
    Idefics2,
    #[serde(rename = "llava_next")]
    LLaVANext,
    #[serde(rename = "llava")]
    LLaVA,
    #[serde(rename = "vllama")]
    VLlama,
    #[serde(rename = "qwen2vl")]
    Qwen2VL,
    #[serde(rename = "idefics3")]
    Idefics3,
    #[serde(rename = "minicpmo")]
    MiniCpmO,
    #[serde(rename = "phi4mm")]
    Phi4MM,
    #[serde(rename = "qwen2_5vl")]
    Qwen2_5VL,
    #[serde(rename = "gemma3")]
    Gemma3,
    #[serde(rename = "mistral3")]
    Mistral3,
    #[serde(rename = "llama4")]
    Llama4,
    #[serde(rename = "gemma3n")]
    Gemma3n,
    #[serde(rename = "qwen3vl")]
    Qwen3VL,
}

impl VisionLoaderType {
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "Phi3VForCausalLM" => Ok(Self::Phi3V),
            "Idefics2ForConditionalGeneration" => Ok(Self::Idefics2),
            "LlavaNextForConditionalGeneration" => Ok(Self::LLaVANext),
            "LlavaForConditionalGeneration" => Ok(Self::LLaVA),
            "MllamaForConditionalGeneration" => Ok(Self::VLlama),
            "Qwen2VLForConditionalGeneration" => Ok(Self::Qwen2VL),
            "Idefics3ForConditionalGeneration" => Ok(Self::Idefics3),
            "MiniCPMO" => Ok(Self::MiniCpmO),
            "Phi4MMForCausalLM" => Ok(Self::Phi4MM),
            "Qwen2_5_VLForConditionalGeneration" => Ok(Self::Qwen2_5VL),
            "Gemma3ForConditionalGeneration" | "Gemma3ForCausalLM" => Ok(Self::Gemma3),
            "Mistral3ForConditionalGeneration" => Ok(Self::Mistral3),
            "Llama4ForConditionalGeneration" => Ok(Self::Llama4),
            "Gemma3nForConditionalGeneration" => Ok(Self::Gemma3n),
            "Qwen3VLForConditionalGeneration" => Ok(Self::Qwen3VL),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for VisionLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "phi3v" => Ok(Self::Phi3V),
            "idefics2" => Ok(Self::Idefics2),
            "llava_next" => Ok(Self::LLaVANext),
            "llava" => Ok(Self::LLaVA),
            "vllama" => Ok(Self::VLlama),
            "qwen2vl" => Ok(Self::Qwen2VL),
            "idefics3" => Ok(Self::Idefics3),
            "minicpmo" => Ok(Self::MiniCpmO),
            "phi4mm" => Ok(Self::Phi4MM),
            "qwen2_5vl" => Ok(Self::Qwen2_5VL),
            "gemma3" => Ok(Self::Gemma3),
            "mistral3" => Ok(Self::Mistral3),
            "llama4" => Ok(Self::Llama4),
            "gemma3n" => Ok(Self::Gemma3n),
            "qwen3vl" => Ok(Self::Qwen3VL),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`, `llama4`, `gemma3n`, `qwen3vl`.")),
        }
    }
}

impl std::fmt::Display for VisionLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            VisionLoaderType::Phi3V => "phi3v",
            VisionLoaderType::Idefics2 => "idefics2",
            VisionLoaderType::LLaVANext => "llava_next",
            VisionLoaderType::LLaVA => "llava",
            VisionLoaderType::VLlama => "vllama",
            VisionLoaderType::Qwen2VL => "qwen2vl",
            VisionLoaderType::Idefics3 => "idefics3",
            VisionLoaderType::MiniCpmO => "minicpmo",
            VisionLoaderType::Phi4MM => "phi4mm",
            VisionLoaderType::Qwen2_5VL => "qwen2_5vl",
            VisionLoaderType::Gemma3 => "gemma3",
            VisionLoaderType::Mistral3 => "mistral3",
            VisionLoaderType::Llama4 => "llama4",
            VisionLoaderType::Gemma3n => "gemma3n",
            VisionLoaderType::Qwen3VL => "qwen3vl",
        };
        write!(f, "{name}")
    }
}
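
// A minimal, illustrative round-trip check for the loader-type string mapping. This is
// a sketch only: it assumes the crate's normal test harness and exercises just the
// `FromStr` and `Display` impls above for a few representative variants.
#[cfg(test)]
mod vision_loader_type_roundtrip {
    use super::VisionLoaderType;
    use std::str::FromStr;

    #[test]
    fn parse_and_display_round_trip() {
        for name in ["phi3v", "qwen2vl", "gemma3", "llama4"] {
            let tp = VisionLoaderType::from_str(name).expect("known loader type");
            assert_eq!(tp.to_string(), name);
        }
    }
}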

#[derive(Deserialize)]
struct AutoVisionLoaderConfig {
    architectures: Vec<String>,
}
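
/// Selects a concrete [`VisionModelLoader`] at runtime from the single entry of the
/// `architectures` array in the model's `config.json` (for example,
/// `"Qwen2VLForConditionalGeneration"` resolves to the Qwen2-VL loader).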
pub struct AutoVisionLoader;

impl AutoVisionLoader {
    fn get_loader(config: &str) -> Result<Box<dyn VisionModelLoader>> {
        let auto_cfg: AutoVisionLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one architecture in config");
        }

        let name = &auto_cfg.architectures[0];
        let tp = VisionLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        Ok(match tp {
            VisionLoaderType::Phi3V => Box::new(Phi3VLoader),
            VisionLoaderType::Idefics2 => Box::new(Idefics2Loader),
            VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader),
            VisionLoaderType::LLaVA => Box::new(LLaVALoader),
            VisionLoaderType::VLlama => Box::new(VLlamaLoader),
            VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
            VisionLoaderType::Idefics3 => Box::new(Idefics3Loader),
            VisionLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
            VisionLoaderType::Phi4MM => Box::new(Phi4MMLoader),
            VisionLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
            VisionLoaderType::Gemma3 => Box::new(Gemma3Loader),
            VisionLoaderType::Mistral3 => Box::new(Mistral3Loader),
            VisionLoaderType::Llama4 => Box::new(VLlama4Loader),
            VisionLoaderType::Gemma3n => Box::new(Gemma3nLoader),
            VisionLoaderType::Qwen3VL => Box::new(Qwen3VLLoader),
        })
    }
}

impl VisionModelLoader for AutoVisionLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }

    fn is_gptx(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader get_loader")
            .is_gptx(config)
    }

    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }

    fn get_processor(
        &self,
        model_config: &str,
        proc_cfg: Option<ProcessorConfig>,
        preproc_cfg: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Self::get_loader(model_config)
            .expect("AutoVisionLoader get_loader")
            .get_processor(model_config, proc_cfg, preproc_cfg, max_edge)
    }

    fn supports_paged_attention(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_paged_attention(config)
    }

    fn modalities(&self, config: &str) -> Result<Modalities> {
        Self::get_loader(config)?.modalities(config)
    }

    fn supports_prefix_cacher(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_prefix_cacher(config)
    }

    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .prefixer(config)
    }

    fn get_device_for_tensor(
        &self,
        config: &str,
        mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        Self::get_loader(config)?.get_device_for_tensor(config, mapper, loading_isq)
    }
}

impl IsqModelLoader for AutoVisionLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
}

impl DeviceMappedModelLoader for AutoVisionLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}
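
// A minimal sketch of how a loader is typically driven (illustrative only; the real
// pipeline supplies the var builder, device mapper, and loading metadata):
//
//     let loader = AutoVisionLoader;
//     let config = std::fs::read_to_string("config.json")?;
//     let modalities = loader.modalities(&config)?;
//     let config_repr = loader.get_config_repr(&config)?;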
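
// Helper for the size estimators below: contributes `$size` elements only when the
// corresponding bias tensor exists, e.g. `bias_if!(true, cfg.hidden_size)`.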
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}
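
// Rough per-tower parameter count (in elements, not bytes) for a CLIP ViT vision
// encoder: embeddings, positional tables, and `num_hidden_layers` encoder blocks.
// Used by the device-mapping size estimators below.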
fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
    let pre_layer_norm = cfg.hidden_size;
    let final_layer_norm = cfg.hidden_size;

    let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
    let num_positions = num_patches + 1;

    let class_embedding = cfg.hidden_size;

    let position_ids = num_positions;
    let position_embedding = num_positions * cfg.hidden_size;

    let conv2dconfig = Conv2dConfig {
        stride: cfg.patch_size,
        ..Default::default()
    };
    let patch_embedding =
        cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;

    let encoder_layer_elems = {
        let layer_norm1 = cfg.hidden_size;
        let layer_norm2 = cfg.hidden_size;

        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

        layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
    };

    pre_layer_norm
        + final_layer_norm
        + class_embedding
        + position_ids
        + position_embedding
        + patch_embedding
        + cfg.num_hidden_layers * encoder_layer_elems
}

pub struct Phi3VLoader;

pub struct Phi3VPrefixer;

impl MultimodalPromptPrefixer for Phi3VPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
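        // For example, two images in front of "Describe the images" become
        // "<|image_1|><|image_2|>Describe the images".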
        format!(
            "{}{prompt}",
            image_indexes
                .into_iter()
                .map(|image_index| format!("<|image_{}|>", image_index + 1))
                .join("")
        )
    }
}

impl VisionModelLoader for Phi3VLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(Phi3::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi3Processor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Phi3VPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Phi3VLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi3VLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = {
                let projection_cls = cfg
                    .embd_layer
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator =
                    cfg.embd_layer.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
                let image_dim_out = cfg.img_processor.image_dim_out;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);

                proj + glb_gn + sub_gn + clip_vit
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

pub struct Idefics2Loader;

pub struct Idefics2Prefixer;

impl MultimodalPromptPrefixer for Idefics2Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(Idefics2::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics2Processor::new(
            processor_config.unwrap(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Idefics2Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Idefics2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Idefics2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let text_elems = {
            let tie_word_embeddings = cfg.tie_word_embeddings;
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let tcfg = &cfg.text_config;
            let vcfg = &cfg.vision_config;
            let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let down_proj = tcfg.intermediate_size * tcfg.hidden_size;

            let perceiver_elems = {
                let tcfg = &cfg.text_config;
                let pcfg = &cfg.perceiver_config;

                let n_latents = pcfg.resampler_n_latents;
                let hidden_size = tcfg.hidden_size;
                let depth = pcfg.resampler_depth;

                let norm = tcfg.hidden_size;
                let latents = n_latents * hidden_size;

                let layer_elems = {
                    let input_latents_norm = hidden_size;
                    let input_context_norm = hidden_size;
                    let post_attn_norm = hidden_size;

                    let num_heads = pcfg.resampler_n_heads;
                    let head_dim = pcfg.resampler_head_dim;
                    let num_key_value_heads = pcfg.num_key_value_heads;

                    let q_proj = hidden_size * num_heads * head_dim;
                    let k_proj = hidden_size * num_key_value_heads * head_dim;
                    let v_proj = hidden_size * num_key_value_heads * head_dim;
                    let o_proj = num_heads * head_dim * hidden_size;

                    let gate_proj = hidden_size * hidden_size * 4;
                    let up_proj = hidden_size * hidden_size * 4;
                    let down_proj = hidden_size * 4 * hidden_size;

                    input_latents_norm
                        + input_context_norm
                        + post_attn_norm
                        + q_proj
                        + k_proj
                        + v_proj
                        + o_proj
                        + gate_proj
                        + up_proj
                        + down_proj
                };

                norm + latents + layer_elems * depth
            };

            gate_proj + up_proj + down_proj + perceiver_elems
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm + patch_embedding + position_embedding + layer_elems
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

pub struct LLaVANextLoader;

pub struct LLaVANextPrefixer;

impl MultimodalPromptPrefixer for LLaVANextPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVANextLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(LLaVANext::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVANextProcessor::new(model_config))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(LLaVANextPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for LLaVANextLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVANextLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

pub struct LLaVALoader;

pub struct LLaVAPrefixer;

impl MultimodalPromptPrefixer for LLaVAPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVALoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(LLaVA::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVAProcessor::new(model_config))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(LLaVAPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for LLaVALoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVALoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        let img_seq_len =
            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        let img_seq_len =
            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

pub struct VLlamaLoader;

pub struct VLlamaPrefixer;

impl MultimodalPromptPrefixer for VLlamaPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<|image|>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for VLlamaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
        Ok(Box::new(MLlamaModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(MLlamaProcessor::new())
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        false
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(VLlamaPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for VLlamaLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let cross_attn_layers = &config.text_config.cross_attention_layers;
        let transformer_layers =
            (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
        let mut text_regexes = Vec::new();
        for layer in transformer_layers {
            text_regexes.extend(vec![
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
                ))?,
            ]);
        }
        let vision_regexes = vec![
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
            )?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
        ];

        Ok([text_regexes, vision_regexes].concat())
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}
1769
1770impl DeviceMappedModelLoader for VLlamaLoader {
1771 fn mapped_max_act_size_elems(
1772 &self,
1773 config: &str,
1774 params: &AutoDeviceMapParams,
1775 ) -> Result<usize> {
1776 let AutoDeviceMapParams::Vision {
1777 max_seq_len,
1778 max_batch_size,
1779 max_image_shape: _,
1780 max_num_images,
1781 } = params
1782 else {
1783 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1784 };
1785
1786 let config: MLlamaConfig = serde_json::from_str(config)?;
1787
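// Per-image sequence length for the vision tower: patches (plus the class token), padded up to
// a multiple of 8, multiplied by the maximum number of tiles.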
1788 let img_seq_len = {
1789 let cfg = &config.vision_config;
1790 let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1791 let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1792 cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1793 };
1794 let img_seq_len = img_seq_len * max_num_images;
1795
1796 let max_cross_text_attn = {
1797 let cfg = &config.text_config;
1798 max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
1799 };
1800
1801 let max_self_text_attn = {
1802 let cfg = &config.text_config;
1803 max_batch_size * cfg.num_attention_heads * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)
1804 };
1805
1806 Ok(max_self_text_attn.max(max_cross_text_attn))
1807 }
1808
1809 fn non_mapped_max_act_size_elems(
1810 &self,
1811 config: &str,
1812 params: &AutoDeviceMapParams,
1813 ) -> Result<usize> {
1814 let AutoDeviceMapParams::Vision {
1815 max_seq_len: _,
1816 max_batch_size,
1817 max_image_shape: _,
1818 max_num_images,
1819 } = params
1820 else {
1821 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1822 };
1823
1824 let config: MLlamaConfig = serde_json::from_str(config)?;
1825
1826 let img_seq_len = {
1827 let cfg = &config.vision_config;
1828 let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1829 let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1830 cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1831 };
1832 let max_vision_attn = {
1833 let cfg = &config.vision_config;
1834 (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
1835 };
1836
1837 Ok(max_vision_attn)
1838 }
1839
1840 fn non_mapped_size_in_bytes(
1841 &self,
1842 config: &str,
1843 dtype: DType,
1844 weight_pack_factor: usize,
1845 _matformer_config: Option<&MatformerSliceConfig>,
1846 ) -> Result<usize> {
1847 let config: MLlamaConfig = serde_json::from_str(config)?;
1848 let text_elems = {
1849 let cfg = &config.text_config;
1850 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1851 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1853 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1854 } else {
1855 0
1856 };
1857 let norm = cfg.hidden_size;
1858 embed_tokens + lm_head + norm
1859 };
1860
1861 let vision_elems = {
1862 let cfg = &config.vision_config;
1863
1864 let conv_cfg = Conv2dConfig {
1865 stride: cfg.patch_size,
1866 ..Default::default()
1867 };
1868 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
1869 * cfg.patch_size
1870 * cfg.patch_size;
1871
1872 let class_embedding = cfg.hidden_size;
1873
1874 let gated_positional_embedding = {
1875 let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1876 let embedding = num_patches * cfg.hidden_size;
1877 let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
1878 * (cfg.max_num_tiles * num_patches * cfg.hidden_size);
1879
1880 embedding + tile_embedding
1881 };
1882
1883 let pre_tile_positional_embedding =
1884 (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1885 let post_tile_positional_embedding =
1886 (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1887
1888 let layernorm_pre = cfg.hidden_size;
1889 let layernorm_post = cfg.hidden_size;
1890
1891 let encoder_layer = {
1892 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1893 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
1894
1895 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1896 let q_proj =
1897 cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1898 let k_proj =
1899 cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1900 let v_proj =
1901 cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1902 let o_proj =
1903 cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1904
1905 let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
1906 + cfg.intermediate_size;
1907 let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
1908 + cfg.hidden_size;
1909
1910 input_layernorm
1911 + post_attention_layernorm
1912 + q_proj
1913 + k_proj
1914 + v_proj
1915 + o_proj
1916 + fc1
1917 + fc2
1918 };
1919
1920 patch_embedding
1921 + class_embedding
1922 + gated_positional_embedding
1923 + pre_tile_positional_embedding
1924 + post_tile_positional_embedding
1925 + layernorm_pre
1926 + layernorm_post
1927 + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
1928 };
1929
1930 let elems = text_elems + vision_elems;
1931 Ok(elems * dtype.size_in_bytes())
1932 }
1933
1934 fn layer_sizes_in_bytes(
1935 &self,
1936 config: &str,
1937 dtype: DType,
1938 weight_pack_factor: usize,
1939 _matformer_config: Option<&MatformerSliceConfig>,
1940 ) -> Result<Vec<usize>> {
1941 let config: MLlamaConfig = serde_json::from_str(config)?;
1942 let cfg = &config.text_config;
1943
1944 let mut layer_sizes = Vec::new();
1945
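// Cross-attention layers are not ISQ-quantized (see `isq_layer_regexes`), so they keep a weight
// pack factor of 1 when estimating their size.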
1946 for i in 0..cfg.num_hidden_layers {
1947 let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
1948 1
1950 } else {
1951 weight_pack_factor
1952 };
1953
1954 let per_layer_elems = {
1955 let input_layernorm = cfg.hidden_size;
1956 let post_attention_layernorm = cfg.hidden_size;
1957
1958 let size_in = cfg.hidden_size;
1959 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1960 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1961 let q_proj = size_in * size_q / weight_pack_factor;
1962 let k_proj = size_in * size_kv / weight_pack_factor;
1963 let v_proj = size_in * size_kv / weight_pack_factor;
1964 let o_proj = size_q * size_in / weight_pack_factor;
1965
1966 let h_size = cfg.hidden_size;
1967 let i_size = cfg.intermediate_size;
1968 let gate_proj = h_size * i_size / weight_pack_factor;
1969 let up_proj = h_size * i_size / weight_pack_factor;
1970 let down_proj = i_size * h_size / weight_pack_factor;
1971
1972 input_layernorm
1973 + post_attention_layernorm
1974 + q_proj
1975 + k_proj
1976 + v_proj
1977 + o_proj
1978 + gate_proj
1979 + up_proj
1980 + down_proj
1981 };
1982
1983 layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
1984 }
1985
1986 Ok(layer_sizes)
1987 }
1988
1989 fn num_layers(&self, config: &str) -> Result<usize> {
1990 let config: MLlamaConfig = serde_json::from_str(config)?;
1991 Ok(config.text_config.num_hidden_layers)
1992 }
1993
1994 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1995 let cfg: MLlamaConfig = serde_json::from_str(config)?;
1996 let cfg = &cfg.text_config;
1997
1998 let cfg = ModelConfigMetadata {
1999 max_seq_len: cfg.max_position_embeddings,
2000 num_layers: cfg.num_hidden_layers,
2001 hidden_size: cfg.hidden_size,
2002 num_kv_heads: cfg.num_key_value_heads,
2003 num_attn_heads: cfg.num_attention_heads,
2004 sliding_window: None,
2005 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2006 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2007 };
2008
2009 Ok(Box::new(cfg))
2010 }
2011
2012 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2013 Some(vec![NonMappedSubModel::Vision])
2014 }
2015}
2016
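/// Loader for Qwen2-VL vision-language models.
///
/// A minimal usage sketch (hypothetical call site; the actual entry points are the pipeline
/// builders, and `config_json` stands in for the model's `config.json` contents):
///
/// ```ignore
/// let loader = Qwen2VLLoader;
/// let n_layers = loader.num_layers(&config_json)?;
/// let isq_targets = loader.isq_layer_regexes(&config_json)?;
/// ```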
2017pub struct Qwen2VLLoader;
2023
2024pub struct Qwen2VLPrefixer;
2025
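// Each image is announced to the language model as VISION_START + IMAGE_PAD + VISION_END,
// repeated once per image and prepended to the prompt.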
2026impl MultimodalPromptPrefixer for Qwen2VLPrefixer {
2027 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2028 format!(
2029 "{}{prompt}",
2030 format!(
2031 "{}{}{}",
2032 Qwen2VLProcessor::VISION_START,
2033 Qwen2VLProcessor::IMAGE_PAD,
2034 Qwen2VLProcessor::VISION_END
2035 )
2036 .repeat(image_indexes.len())
2037 )
2038 }
2039}
2040
2041impl VisionModelLoader for Qwen2VLLoader {
2042 fn load(
2043 &self,
2044 config: &str,
2045 vb: ShardedVarBuilder,
2046 normal_loading_metadata: NormalLoadingMetadata,
2047 attention_mechanism: AttentionImplementation,
2048 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2049 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2050 Ok(Box::new(Qwen2VLModel::new(
2051 &cfg,
2052 vb,
2053 self.is_gptx(config),
2054 normal_loading_metadata,
2055 attention_mechanism,
2056 )?))
2057 }
2058 fn is_gptx(&self, _config: &str) -> bool {
2059 true
2060 }
2061 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2062 let config: Qwen2VLConfig = serde_json::from_str(config)?;
2063 Ok(Box::new(config))
2064 }
2065 fn get_processor(
2066 &self,
2067 _model_config: &str,
2068 _processor_config: Option<ProcessorConfig>,
2069 _preprocessor_config: PreProcessorConfig,
2070 max_edge: Option<u32>,
2071 ) -> Arc<dyn Processor + Send + Sync> {
2072 Arc::new(Qwen2VLProcessor::new(max_edge))
2073 }
2074 fn supports_paged_attention(&self, _config: &str) -> bool {
2075 false
2076 }
2077 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2078 Arc::new(Qwen2VLPrefixer)
2079 }
2080 fn modalities(&self, _config: &str) -> Result<Modalities> {
2081 Ok(Modalities {
2082 input: vec![SupportedModality::Text, SupportedModality::Vision],
2083 output: vec![SupportedModality::Text],
2084 })
2085 }
2086}
2087
2088impl IsqModelLoader for Qwen2VLLoader {
2089 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2090 Ok(vec![
2091 Regex::new(r"lm_head\.(weight|bias)$")?,
2092 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2094 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2095 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2096 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2097 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2099 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2100 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2101 ])
2102 }
2103 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2104 self.isq_layer_regexes(config)
2105 }
2106}
2107
2108impl DeviceMappedModelLoader for Qwen2VLLoader {
2109 fn mapped_max_act_size_elems(
2110 &self,
2111 config: &str,
2112 params: &AutoDeviceMapParams,
2113 ) -> Result<usize> {
2114 let AutoDeviceMapParams::Vision {
2115 max_seq_len,
2116 max_batch_size,
2117 max_image_shape,
2118 max_num_images,
2119 } = params
2120 else {
2121 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2122 };
2123
2124 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2125
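// Rough image-token count from the ViT patch grid: temporal x height x width patches.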
2126 let img_seq_len = {
2127 let cfg = &cfg.vision_config;
2128 let grid_t = max_num_images / cfg.temporal_patch_size;
2129 let grid_h = max_image_shape.0 / cfg.patch_size;
2130 let grid_w = max_image_shape.1 / cfg.patch_size;
2131 grid_t * grid_h * grid_w
2132 };
2133 let img_seq_len = img_seq_len * max_num_images;
2134
2135 let max_text_attn = {
2136 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2138 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
2139 };
2140
2141 Ok(max_text_attn)
2142 }
2143
2144 fn non_mapped_max_act_size_elems(
2145 &self,
2146 config: &str,
2147 params: &AutoDeviceMapParams,
2148 ) -> Result<usize> {
2149 let AutoDeviceMapParams::Vision {
2150 max_seq_len: _,
2151 max_batch_size,
2152 max_image_shape,
2153 max_num_images,
2154 } = params
2155 else {
2156 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2157 };
2158
2159 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2160
2161 let img_seq_len = {
2162 let cfg = &cfg.vision_config;
2163 let grid_t = max_num_images / cfg.temporal_patch_size;
2164 let grid_h = max_image_shape.0 / cfg.patch_size;
2165 let grid_w = max_image_shape.1 / cfg.patch_size;
2166 grid_t * grid_h * grid_w
2167 };
2168
2169 let max_vision_attn = {
2170 let cfg = &cfg.vision_config;
2171 (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
2172 };
2173
2174 Ok(max_vision_attn)
2175 }
2176
2177 fn non_mapped_size_in_bytes(
2178 &self,
2179 config: &str,
2180 dtype: DType,
2181 weight_pack_factor: usize,
2182 _matformer_config: Option<&MatformerSliceConfig>,
2183 ) -> Result<usize> {
2184 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2185 let text_elems = {
2186 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2187 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2189 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2190 } else {
2191 0
2192 };
2193 let norm = cfg.hidden_size;
2194 embed_tokens + lm_head + norm
2195 };
2196
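// The patch merger flattens a spatial_merge_size x spatial_merge_size block of embed_dim
// features before projecting back out to `cfg.hidden_size`.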
2197 let patch_merger = {
2198 let cfg = &cfg.vision_config;
2199 let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);
2200
2201 let mlp0 = hidden_size * hidden_size + hidden_size;
2202 let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
2203
2204 let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2205
2206 mlp0 + mlp2 + ln_q
2207 };
2208
2209 let patch_embed = {
2210 let cfg = &cfg.vision_config;
2211 let conv_cfg = Conv3dConfig {
2212 stride: cfg.patch_size,
2213 ..Default::default()
2214 };
2215 let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
2216 cfg.in_channels * cfg.embed_dim / conv_cfg.groups
2217 * kernel_sizes[0]
2218 * kernel_sizes[1]
2219 * kernel_sizes[2]
2220 };
2221
2222 let encoder_layer = {
2223 let cfg = &cfg.vision_config;
2224 let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2225 let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2226
2227 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2228 let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
2229 let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
2230 let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;
2231
2232 let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
2233 let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;
2234
2235 norm1 + norm2 + fc1 + fc2 + qkv + out
2236 };
2237
2238 let elems =
2239 text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
2240
2241 Ok(elems * dtype.size_in_bytes())
2242 }
2243
2244 fn layer_sizes_in_bytes(
2245 &self,
2246 config: &str,
2247 dtype: DType,
2248 weight_pack_factor: usize,
2249 _matformer_config: Option<&MatformerSliceConfig>,
2250 ) -> Result<Vec<usize>> {
2251 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2252 let per_layer_elems = {
2253 let input_layernorm = cfg.hidden_size;
2254 let post_attention_layernorm = cfg.hidden_size;
2255
2256 let size_in = cfg.hidden_size;
2257 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2258 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2259 let q_proj = size_in * size_q / weight_pack_factor + size_q;
2260 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
2261 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
2262 let o_proj = size_q * size_in / weight_pack_factor;
2263
2264 let h_size = cfg.hidden_size;
2265 let i_size = cfg.intermediate_size;
2266 let gate_proj = h_size * i_size / weight_pack_factor;
2267 let up_proj = h_size * i_size / weight_pack_factor;
2268 let down_proj = i_size * h_size / weight_pack_factor;
2269
2270 input_layernorm
2271 + post_attention_layernorm
2272 + q_proj
2273 + k_proj
2274 + v_proj
2275 + o_proj
2276 + gate_proj
2277 + up_proj
2278 + down_proj
2279 };
2280 Ok(vec![
2281 per_layer_elems * dtype.size_in_bytes();
2282 cfg.num_hidden_layers
2283 ])
2284 }
2285
2286 fn num_layers(&self, config: &str) -> Result<usize> {
2287 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2288 Ok(cfg.num_hidden_layers)
2289 }
2290
2291 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2292 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2293
2294 let cfg = ModelConfigMetadata {
2295 max_seq_len: cfg.max_position_embeddings,
2296 num_layers: cfg.num_hidden_layers,
2297 hidden_size: cfg.hidden_size,
2298 num_kv_heads: cfg.num_key_value_heads,
2299 num_attn_heads: cfg.num_attention_heads,
2300 sliding_window: cfg.sliding_window,
2301 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2302 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2303 };
2304
2305 Ok(Box::new(cfg))
2306 }
2307
2308 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2309 Some(vec![NonMappedSubModel::Vision])
2310 }
2311}
2312
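/// Loader for Idefics3 vision-language models.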
2313pub struct Idefics3Loader;
2319
2320pub struct Idefics3Prefixer;
2321
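// Prompt prefixing is a no-op for Idefics3; image placeholders are presumably inserted by the
// processor / chat template instead.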
2322impl MultimodalPromptPrefixer for Idefics3Prefixer {
2323 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
2324 prompt.to_string()
2326 }
2327}
2328
2329impl VisionModelLoader for Idefics3Loader {
2330 fn load(
2331 &self,
2332 config: &str,
2333 vb: ShardedVarBuilder,
2334 normal_loading_metadata: NormalLoadingMetadata,
2335 attention_mechanism: AttentionImplementation,
2336 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2337 let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2338 Ok(Box::new(Idefics3Model::new(
2339 &cfg,
2340 vb,
2341 self.is_gptx(config),
2342 normal_loading_metadata,
2343 attention_mechanism,
2344 )?))
2345 }
2346 fn is_gptx(&self, _config: &str) -> bool {
2347 true
2348 }
2349 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2350 let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2351 Ok(Box::new(cfg))
2352 }
2353 fn get_processor(
2354 &self,
2355 _model_config: &str,
2356 processor_config: Option<ProcessorConfig>,
2357 preprocessor_config: PreProcessorConfig,
2358 max_edge: Option<u32>,
2359 ) -> Arc<dyn Processor + Send + Sync> {
2360 Arc::new(Idefics3Processor::new(
2361 processor_config.unwrap_or_default(),
2362 preprocessor_config,
2363 max_edge,
2364 ))
2365 }
2366 fn supports_paged_attention(&self, _config: &str) -> bool {
2367 true
2368 }
2369 fn supports_prefix_cacher(&self, _config: &str) -> bool {
2370 true
2371 }
2372 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2373 Arc::new(Idefics3Prefixer)
2374 }
2375 fn modalities(&self, _config: &str) -> Result<Modalities> {
2376 Ok(Modalities {
2377 input: vec![SupportedModality::Text, SupportedModality::Vision],
2378 output: vec![SupportedModality::Text],
2379 })
2380 }
2381}
2382
2383impl IsqModelLoader for Idefics3Loader {
2384 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2385 Ok(vec![
2386 Regex::new(r"lm_head\.(weight|bias)$")?,
2387 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2389 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2390 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2391 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2392 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2394 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2395 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2396 ])
2397 }
2398 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
2399 Ok(vec![
2400 Regex::new(r"lm_head\.(weight|bias)$")?,
2401 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2403 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2404 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2405 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2406 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2408 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2409 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2410 ])
2427 }
2428}
2429
2430impl DeviceMappedModelLoader for Idefics3Loader {
2431 fn mapped_max_act_size_elems(
2432 &self,
2433 config: &str,
2434 params: &AutoDeviceMapParams,
2435 ) -> Result<usize> {
2436 let AutoDeviceMapParams::Vision {
2437 max_seq_len,
2438 max_batch_size,
2439 max_image_shape: _,
2440 max_num_images,
2441 } = params
2442 else {
2443 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2444 };
2445
2446 let cfg: Idefics3Config = serde_json::from_str(config)?;
2447
2448 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2449 let img_seq_len = (num_patches + 1) * max_num_images;
2450
2451 let max_text_attn = {
2452 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2454 max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2455 };
2456
2457 Ok(max_text_attn)
2458 }
2459
2460 fn non_mapped_max_act_size_elems(
2461 &self,
2462 config: &str,
2463 params: &AutoDeviceMapParams,
2464 ) -> Result<usize> {
2465 let AutoDeviceMapParams::Vision {
2466 max_seq_len: _,
2467 max_batch_size,
2468 max_image_shape: _,
2469 max_num_images,
2470 } = params
2471 else {
2472 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2473 };
2474
2475 let cfg: Idefics3Config = serde_json::from_str(config)?;
2476
2477 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2478 let img_seq_len = num_patches + 1;
2479
2480 let max_vision_attn = {
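// Heuristic allowance for image splitting: budget for each input image expanding into several
// sub-images (e.g. crops plus a global view).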
2481 let images_factor = 5;
2483
2484 (max_batch_size * images_factor * max_num_images)
2485 * cfg.vision_config.num_attention_heads
2486 * img_seq_len
2487 * img_seq_len
2488 };
2489
2490 Ok(max_vision_attn)
2491 }
2492
2493 fn non_mapped_size_in_bytes(
2494 &self,
2495 config: &str,
2496 dtype: DType,
2497 weight_pack_factor: usize,
2498 _matformer_config: Option<&MatformerSliceConfig>,
2499 ) -> Result<usize> {
2500 let cfg: Idefics3Config = serde_json::from_str(config)?;
2501 let text_elems = {
2502 let cfg = &cfg.text_config;
2503
2504 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2505 let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2506 let norm = cfg.hidden_size;
2507 embed_tokens + lm_head + norm
2508 };
2509
2510 let connector_elems = {
2511 let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
2512 let out_dim = cfg.text_config.hidden_size;
2513
2514 in_dim * out_dim
2515 };
2516
2517 let vision_transformer = {
2518 let cfg = &cfg.vision_config;
2519
2520 let post_layernorm = cfg.hidden_size;
2521
2522 let conv_config = Conv2dConfig {
2523 stride: cfg.patch_size,
2524 ..Default::default()
2525 };
2526 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2527 * cfg.patch_size
2528 * cfg.patch_size;
2529
2530 let num_patches_per_side = cfg.image_size / cfg.patch_size;
2531 let num_patches = num_patches_per_side.pow(2);
2532 let position_embedding = num_patches * cfg.hidden_size;
2533
2534 let layer_elems = {
2535 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2536 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2537
2538 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2539 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2540
2541 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2542 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2543 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2544 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2545
2546 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2547 };
2548
2549 post_layernorm
2550 + patch_embedding
2551 + position_embedding
2552 + layer_elems * cfg.num_hidden_layers
2553 };
2554
2555 let elems = text_elems + connector_elems + vision_transformer;
2556
2557 Ok(elems * dtype.size_in_bytes())
2558 }
2559
2560 fn layer_sizes_in_bytes(
2561 &self,
2562 config: &str,
2563 dtype: DType,
2564 weight_pack_factor: usize,
2565 _matformer_config: Option<&MatformerSliceConfig>,
2566 ) -> Result<Vec<usize>> {
2567 let cfg: Idefics3Config = serde_json::from_str(config)?;
2568 let cfg = cfg.text_config;
2569 let per_layer_elems = {
2570 let input_layernorm = cfg.hidden_size;
2571 let post_attention_layernorm = cfg.hidden_size;
2572
2573 let size_in = cfg.hidden_size;
2574 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2575 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2576 let q_proj = size_in * size_q / weight_pack_factor;
2577 let k_proj = size_in * size_kv / weight_pack_factor;
2578 let v_proj = size_in * size_kv / weight_pack_factor;
2579 let o_proj = size_q * size_in / weight_pack_factor;
2580
2581 let h_size = cfg.hidden_size;
2582 let i_size = cfg.intermediate_size;
2583 let gate_proj = h_size * i_size / weight_pack_factor;
2584 let up_proj = h_size * i_size / weight_pack_factor;
2585 let down_proj = i_size * h_size / weight_pack_factor;
2586
2587 input_layernorm
2588 + post_attention_layernorm
2589 + q_proj
2590 + k_proj
2591 + v_proj
2592 + o_proj
2593 + gate_proj
2594 + up_proj
2595 + down_proj
2596 };
2597 Ok(vec![
2598 per_layer_elems * dtype.size_in_bytes();
2599 cfg.num_hidden_layers
2600 ])
2601 }
2602
2603 fn num_layers(&self, config: &str) -> Result<usize> {
2604 let cfg: Idefics3Config = serde_json::from_str(config)?;
2605 Ok(cfg.text_config.num_hidden_layers)
2606 }
2607 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2608 let cfg: Idefics3Config = serde_json::from_str(config)?;
2609 let cfg = &cfg.text_config;
2610
2611 let cfg = ModelConfigMetadata {
2612 max_seq_len: cfg.max_position_embeddings,
2613 num_layers: cfg.num_hidden_layers,
2614 hidden_size: cfg.hidden_size,
2615 num_kv_heads: cfg.num_key_value_heads,
2616 num_attn_heads: cfg.num_attention_heads,
2617 sliding_window: None,
2618 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2619 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2620 };
2621
2622 Ok(Box::new(cfg))
2623 }
2624
2625 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2626 Some(vec![NonMappedSubModel::Vision])
2627 }
2628}
2629
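/// Loader for MiniCPM-o vision-language models.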
2630pub struct MiniCpmOLoader;
2636
2637pub struct MiniCpmOPrefixer;
2638
2639impl MultimodalPromptPrefixer for MiniCpmOPrefixer {
2640 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2641 format!(
2642 "{}{prompt}",
2643 "(<image>./</image>)".repeat(image_indexes.len())
2644 )
2645 }
2646}
2647
2648impl VisionModelLoader for MiniCpmOLoader {
2649 fn load(
2650 &self,
2651 config: &str,
2652 vb: ShardedVarBuilder,
2653 normal_loading_metadata: NormalLoadingMetadata,
2654 attention_mechanism: AttentionImplementation,
2655 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2656 let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2657 Ok(Box::new(MiniCpmOModel::new(
2658 &cfg,
2659 vb,
2660 self.is_gptx(config),
2661 normal_loading_metadata,
2662 attention_mechanism,
2663 )?))
2664 }
2665 fn is_gptx(&self, _config: &str) -> bool {
2666 true
2667 }
2668 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2669 let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2670 Ok(Box::new(cfg))
2671 }
2672 fn get_processor(
2673 &self,
2674 _model_config: &str,
2675 processor_config: Option<ProcessorConfig>,
2676 preprocessor_config: PreProcessorConfig,
2677 max_edge: Option<u32>,
2678 ) -> Arc<dyn Processor + Send + Sync> {
2679 Arc::new(MiniCpmOProcessor::new(
2680 processor_config.unwrap_or_default(),
2681 preprocessor_config,
2682 max_edge,
2683 ))
2684 }
2685 fn supports_paged_attention(&self, _config: &str) -> bool {
2686 true
2687 }
2688 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2689 Arc::new(MiniCpmOPrefixer)
2690 }
2691 fn modalities(&self, _config: &str) -> Result<Modalities> {
2692 Ok(Modalities {
2693 input: vec![SupportedModality::Text, SupportedModality::Vision],
2694 output: vec![SupportedModality::Text],
2695 })
2696 }
2697}
2698
2699impl IsqModelLoader for MiniCpmOLoader {
2700 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2701 Ok(vec![
2702 Regex::new(r"llm\.lm_head\.(weight|bias)$")?,
2703 Regex::new(r"llm\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2705 Regex::new(r"llm\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2706 Regex::new(r"llm\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2707 Regex::new(r"llm\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2708 Regex::new(r"llm\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2710 Regex::new(r"llm\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2711 Regex::new(r"llm\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2712 ])
2713 }
2714 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2715 self.isq_layer_regexes(config)
2716 }
2717}
2718
2719impl DeviceMappedModelLoader for MiniCpmOLoader {
2720 fn mapped_max_act_size_elems(
2721 &self,
2722 config: &str,
2723 params: &AutoDeviceMapParams,
2724 ) -> Result<usize> {
2725 let AutoDeviceMapParams::Vision {
2726 max_seq_len,
2727 max_batch_size,
2728 max_image_shape: _,
2729 max_num_images,
2730 } = params
2731 else {
2732 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2733 };
2734
2735 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2736
2737 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2738 let img_seq_len = (num_patches + 1) * max_num_images;
2739
2740 let max_text_attn = {
2741 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2743 max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2744 };
2745
2746 Ok(max_text_attn)
2747 }
2748
2749 fn non_mapped_max_act_size_elems(
2750 &self,
2751 config: &str,
2752 params: &AutoDeviceMapParams,
2753 ) -> Result<usize> {
2754 let AutoDeviceMapParams::Vision {
2755 max_seq_len: _,
2756 max_batch_size,
2757 max_image_shape: _,
2758 max_num_images,
2759 } = params
2760 else {
2761 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2762 };
2763
2764 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2765
2766 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2767 let img_seq_len = num_patches + 1;
2768
2769 let max_vision_attn = {
2770 let images_factor = 5;
2772
2773 (max_batch_size * images_factor * max_num_images)
2774 * cfg.vision_config.num_attention_heads
2775 * img_seq_len
2776 * img_seq_len
2777 };
2778
2779 Ok(max_vision_attn)
2780 }
2781
2782 fn non_mapped_size_in_bytes(
2783 &self,
2784 config: &str,
2785 dtype: DType,
2786 weight_pack_factor: usize,
2787 _matformer_config: Option<&MatformerSliceConfig>,
2788 ) -> Result<usize> {
2789 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2790 let text_elems = {
2791 let cfg = &cfg.text_config;
2792
2793 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2794 let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2795 let norm = cfg.hidden_size;
2796 embed_tokens + lm_head + norm
2797 };
2798
2799 let vision_transformer = {
2800 let cfg = &cfg.vision_config;
2801
2802 let post_layernorm = cfg.hidden_size;
2803
2804 let conv_config = Conv2dConfig {
2805 stride: cfg.patch_size,
2806 ..Default::default()
2807 };
2808 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2809 * cfg.patch_size
2810 * cfg.patch_size;
2811
2812 let num_patches_per_side = cfg.image_size / cfg.patch_size;
2813 let num_patches = num_patches_per_side.pow(2);
2814 let position_embedding = num_patches * cfg.hidden_size;
2815
2816 let layer_elems = {
2817 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2818 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2819
2820 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2821 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2822
2823 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2824 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2825 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2826 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2827
2828 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2829 };
2830
2831 post_layernorm
2832 + patch_embedding
2833 + position_embedding
2834 + layer_elems * cfg.num_hidden_layers
2835 };
2836
2837 let elems = text_elems + vision_transformer;
2838
2839 Ok(elems * dtype.size_in_bytes())
2840 }
2841
2842 fn layer_sizes_in_bytes(
2843 &self,
2844 config: &str,
2845 dtype: DType,
2846 weight_pack_factor: usize,
2847 _matformer_config: Option<&MatformerSliceConfig>,
2848 ) -> Result<Vec<usize>> {
2849 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2850 let cfg = cfg.text_config;
2851 let per_layer_elems = {
2852 let input_layernorm = cfg.hidden_size;
2853 let post_attention_layernorm = cfg.hidden_size;
2854
2855 let size_in = cfg.hidden_size;
2856 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2857 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2858 let q_proj = size_in * size_q / weight_pack_factor;
2859 let k_proj = size_in * size_kv / weight_pack_factor;
2860 let v_proj = size_in * size_kv / weight_pack_factor;
2861 let o_proj = size_q * size_in / weight_pack_factor;
2862
2863 let h_size = cfg.hidden_size;
2864 let i_size = cfg.intermediate_size;
2865 let gate_proj = h_size * i_size / weight_pack_factor;
2866 let up_proj = h_size * i_size / weight_pack_factor;
2867 let down_proj = i_size * h_size / weight_pack_factor;
2868
2869 input_layernorm
2870 + post_attention_layernorm
2871 + q_proj
2872 + k_proj
2873 + v_proj
2874 + o_proj
2875 + gate_proj
2876 + up_proj
2877 + down_proj
2878 };
2879 Ok(vec![
2880 per_layer_elems * dtype.size_in_bytes();
2881 cfg.num_hidden_layers
2882 ])
2883 }
2884
2885 fn num_layers(&self, config: &str) -> Result<usize> {
2886 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2887 Ok(cfg.text_config.num_hidden_layers)
2888 }
2889 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2890 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2891 let cfg = &cfg.text_config;
2892
2893 let cfg = ModelConfigMetadata {
2894 max_seq_len: cfg.max_position_embeddings,
2895 num_layers: cfg.num_hidden_layers,
2896 hidden_size: cfg.hidden_size,
2897 num_kv_heads: cfg.num_key_value_heads,
2898 num_attn_heads: cfg.num_attention_heads,
2899 sliding_window: None,
2900 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2901 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2902 };
2903
2904 Ok(Box::new(cfg))
2905 }
2906}
2907
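/// Loader for Phi-4 multimodal models (text, vision, and audio inputs).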
2908pub struct Phi4MMLoader;
2914
2915pub struct Phi4MMPrefixer;
2916
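// Phi-4 multimodal prompts reference attachments positionally (`<|image_1|>`, `<|audio_1|>`, ...),
// so the prefixer prepends one tag per attached image or audio clip.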
2917impl MultimodalPromptPrefixer for Phi4MMPrefixer {
2918 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2919 format!(
2922 "{}{prompt}",
2923 image_indexes
2924 .into_iter()
2925 .map(|image_index| format!("<|image_{}|>", image_index + 1))
2926 .join("")
2927 )
2928 }
2929 fn prefix_audio(&self, audio_indexes: Vec<usize>, prompt: &str) -> String {
2930 format!(
2933 "{}{prompt}",
2934 audio_indexes
2935 .into_iter()
2936 .map(|audio_index| format!("<|audio_{}|>", audio_index + 1))
2937 .join("")
2938 )
2939 }
2940}
2941
2942impl VisionModelLoader for Phi4MMLoader {
2943 fn load(
2944 &self,
2945 config: &str,
2946 vb: ShardedVarBuilder,
2947 normal_loading_metadata: NormalLoadingMetadata,
2948 attention_mechanism: AttentionImplementation,
2949 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2950 let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
2951 Ok(Box::new(Phi4MMModel::new(
2952 &cfg,
2953 vb,
2954 self.is_gptx(config),
2955 normal_loading_metadata,
2956 attention_mechanism,
2957 )?))
2958 }
2959 fn is_gptx(&self, _config: &str) -> bool {
2960 true
2961 }
2962 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2963 let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
2964 Ok(Box::new(cfg))
2965 }
2966 fn get_processor(
2967 &self,
2968 _model_config: &str,
2969 processor_config: Option<ProcessorConfig>,
2970 preprocessor_config: PreProcessorConfig,
2971 _max_edge: Option<u32>,
2972 ) -> Arc<dyn Processor + Send + Sync> {
2973 Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
2974 }
2975 fn supports_paged_attention(&self, _config: &str) -> bool {
2976 true
2977 }
2978 fn supports_prefix_cacher(&self, _config: &str) -> bool {
2979 true
2980 }
2981 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2982 Arc::new(Phi4MMPrefixer)
2983 }
2984 fn modalities(&self, _config: &str) -> Result<Modalities> {
2985 Ok(Modalities {
2986 input: vec![
2987 SupportedModality::Text,
2988 SupportedModality::Vision,
2989 SupportedModality::Audio,
2990 ],
2991 output: vec![SupportedModality::Text],
2992 })
2993 }
2994}
2995
2996impl IsqModelLoader for Phi4MMLoader {
2997 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2998 Ok(vec![
2999 Regex::new(r"lm_head\.(weight|bias)$")?,
3000 Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
3002 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3003 Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
3005 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3006 ])
3007 }
3008 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3009 self.isq_layer_regexes(config)
3010 }
3011}
3012
3013impl DeviceMappedModelLoader for Phi4MMLoader {
3014 fn mapped_max_act_size_elems(
3015 &self,
3016 config: &str,
3017 params: &AutoDeviceMapParams,
3018 ) -> Result<usize> {
3019 let AutoDeviceMapParams::Vision {
3021 max_seq_len,
3022 max_batch_size,
3023 max_image_shape: _,
3024 max_num_images,
3025 } = params
3026 else {
3027 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3028 };
3029
3030 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3031
3032 let vcfg = &PHI4_MM_VISION_CFG;
3033
3034 let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3035 let img_seq_len = (num_patches + 1) * max_num_images;
3036
3037 let max_text_attn = {
3038 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3040 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3041 };
3042
3043 Ok(max_text_attn)
3044 }
3045
3046 fn non_mapped_max_act_size_elems(
3047 &self,
3048 _config: &str,
3049 params: &AutoDeviceMapParams,
3050 ) -> Result<usize> {
3051 let AutoDeviceMapParams::Vision {
3052 max_seq_len: _,
3053 max_batch_size,
3054 max_image_shape,
3055 max_num_images,
3056 } = params
3057 else {
3058 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3059 };
3060
3061 let vcfg = &PHI4_MM_VISION_CFG;
3062
3063 let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3064 let img_seq_len = num_patches + 1;
3065
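// Dynamic HD cropping: each image contributes ceil(h / base) * ceil(w / base) crops plus one
// global view, so scale the effective batch size accordingly.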
3066 let max_batch_size = max_batch_size
3067 * (max_image_shape
3068 .0
3069 .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3070 * max_image_shape
3071 .1
3072 .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3073 + 1);
3074
3075 let max_vision_attn = (max_batch_size * max_num_images)
3076 * vcfg.num_attention_heads
3077 * img_seq_len
3078 * img_seq_len;
3079 let max_qkv = 3
3080 * (max_batch_size
3081 * vcfg.num_attention_heads
3082 * img_seq_len
3083 * (vcfg.hidden_size / vcfg.num_attention_heads));
3084
3085 Ok(max_vision_attn + max_qkv)
3086 }
3087
3088 fn non_mapped_size_in_bytes(
3089 &self,
3090 config: &str,
3091 dtype: DType,
3092 weight_pack_factor: usize,
3093 _matformer_config: Option<&MatformerSliceConfig>,
3094 ) -> Result<usize> {
3095 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3096 let elems = {
3097 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3098 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3100 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3101 } else {
3102 0
3103 };
3104 let norm = cfg.hidden_size;
3105
3106 let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
3107 let projection_cls = img_embed
3108 .projection_cls
3109 .clone()
3110 .unwrap_or("linear".to_string());
3111 let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
3112 let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
3113 let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;
3114
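// Projector size depends on the projection class and on whether the HD transform is used
// (which concatenates features, giving an `image_dim_out * 4` input width).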
3115 let proj = match (projection_cls.as_str(), use_hd_transform) {
3116 ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
3117 ("mlp", true) => {
3118 let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
3119 let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3120 a + b
3121 }
3122 ("mlp", false) => {
3123 let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
3124 let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3125 a + b
3126 }
3127 _ => {
3128 anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
3129 }
3130 };
3131
3132 let (glb_gn, sub_gn) = if with_learnable_separator {
3133 let glb_gn = image_dim_out * 4;
3134 let sub_gn = image_dim_out * 4;
3135 (glb_gn, sub_gn)
3136 } else {
3137 (0, 0)
3138 };
3139
3140 let vision_transformer = {
3141 let cfg = &PHI4_MM_VISION_CFG;
3142
3143 let post_layernorm = cfg.hidden_size;
3144
3145 let conv_config = Conv2dConfig {
3146 stride: cfg.patch_size,
3147 ..Default::default()
3148 };
3149 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3150 * cfg.patch_size
3151 * cfg.patch_size;
3152
3153 let num_patches_per_side = cfg.image_size / cfg.patch_size;
3154 let num_patches = num_patches_per_side.pow(2);
3155 let position_embedding = num_patches * cfg.hidden_size;
3156
3157 let layer_elems = {
3158 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3159 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3160
3161 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3162 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3163
3164 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3165 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3166 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3167 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3168
3169 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3170 };
3171
3172 post_layernorm
3173 + patch_embedding
3174 + position_embedding
3175 + layer_elems * cfg.num_hidden_layers
3176 };
3177
3178 proj + glb_gn + sub_gn + vision_transformer
3179 } else {
3180 0
3181 };
3182
3183 embed_tokens + lm_head + norm + image_embed
3184 };
3185
3186 Ok(elems * dtype.size_in_bytes())
3187 }
3188
3189 fn layer_sizes_in_bytes(
3190 &self,
3191 config: &str,
3192 dtype: DType,
3193 weight_pack_factor: usize,
3194 _matformer_config: Option<&MatformerSliceConfig>,
3195 ) -> Result<Vec<usize>> {
3196 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3197 let per_layer_elems = {
3198 let input_layernorm = cfg.hidden_size;
3199 let post_attention_layernorm = cfg.hidden_size;
3200
3201 let size_in = cfg.hidden_size;
3202 let head_dim = cfg.head_dim();
3203 let op_size =
3204 cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
3205 let qkv_proj = size_in * op_size / weight_pack_factor;
3206 let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;
3207
3208 let h_size = cfg.hidden_size;
3209 let i_size = cfg.intermediate_size;
3210 let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
3211 let down_proj = h_size * i_size / weight_pack_factor;
3212
3213 input_layernorm
3214 + post_attention_layernorm
3215 + qkv_proj
3216 + o_proj
3217 + gate_up_proj
3218 + down_proj
3219 };
3220 Ok(vec![
3221 per_layer_elems * dtype.size_in_bytes();
3222 cfg.num_hidden_layers
3223 ])
3224 }
3225
3226 fn num_layers(&self, config: &str) -> Result<usize> {
3227 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3228 Ok(cfg.num_hidden_layers)
3229 }
3230
3231 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3232 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3233
3234 let cfg = ModelConfigMetadata {
3235 max_seq_len: cfg.max_position_embeddings,
3236 num_layers: cfg.num_hidden_layers,
3237 hidden_size: cfg.hidden_size,
3238 num_kv_heads: cfg.num_key_value_heads(),
3239 num_attn_heads: cfg.num_attention_heads,
3240 sliding_window: cfg.sliding_window,
3241 k_head_dim: cfg.head_dim(),
3242 v_head_dim: cfg.head_dim(),
3243 };
3244
3245 Ok(Box::new(cfg))
3246 }
3247
3248 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3249 Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
3250 }
3251}
3252
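/// Loader for Qwen2.5-VL vision-language models.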
3253pub struct Qwen2_5VLLoader;
3259
3260pub struct Qwen2_5VLPrefixer;
3261
3262impl MultimodalPromptPrefixer for Qwen2_5VLPrefixer {
3263 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
3264 format!(
3265 "{}{prompt}",
3266 format!(
3267 "{}{}{}",
3268 Qwen2_5VLProcessor::VISION_START,
3269 Qwen2_5VLProcessor::IMAGE_PAD,
3270 Qwen2_5VLProcessor::VISION_END
3271 )
3272 .repeat(image_indexes.len())
3273 )
3274 }
3275}
3276
3277impl VisionModelLoader for Qwen2_5VLLoader {
3278 fn load(
3279 &self,
3280 config: &str,
3281 vb: ShardedVarBuilder,
3282 normal_loading_metadata: NormalLoadingMetadata,
3283 attention_mechanism: AttentionImplementation,
3284 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3285 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3286 Ok(Box::new(Qwen2_5VLModel::new(
3287 &cfg,
3288 vb,
3289 self.is_gptx(config),
3290 normal_loading_metadata,
3291 attention_mechanism,
3292 )?))
3293 }
3294 fn is_gptx(&self, _config: &str) -> bool {
3295 true
3296 }
3297 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3298 let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
3299 Ok(Box::new(config))
3300 }
3301 fn get_processor(
3302 &self,
3303 _model_config: &str,
3304 _processor_config: Option<ProcessorConfig>,
3305 _preprocessor_config: PreProcessorConfig,
3306 max_edge: Option<u32>,
3307 ) -> Arc<dyn Processor + Send + Sync> {
3308 Arc::new(Qwen2_5VLProcessor::new(max_edge))
3309 }
3310 fn supports_paged_attention(&self, _config: &str) -> bool {
3311 false
3312 }
3313 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3314 Arc::new(Qwen2_5VLPrefixer)
3315 }
3316 fn modalities(&self, _config: &str) -> Result<Modalities> {
3317 Ok(Modalities {
3318 input: vec![SupportedModality::Text, SupportedModality::Vision],
3319 output: vec![SupportedModality::Text],
3320 })
3321 }
3322}
3323
3324impl IsqModelLoader for Qwen2_5VLLoader {
3325 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3326 Ok(vec![
3327 Regex::new(r"lm_head\.(weight|bias)$")?,
3328 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3330 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3331 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3332 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3333 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3335 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3336 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3337 ])
3338 }
3339 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3340 self.isq_layer_regexes(config)
3341 }
3342}
3343
3344impl DeviceMappedModelLoader for Qwen2_5VLLoader {
3345 fn mapped_max_act_size_elems(
3346 &self,
3347 config: &str,
3348 params: &AutoDeviceMapParams,
3349 ) -> Result<usize> {
3350 let AutoDeviceMapParams::Vision {
3351 max_seq_len,
3352 max_batch_size,
3353 max_image_shape,
3354 max_num_images,
3355 } = params
3356 else {
3357 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3358 };
3359
3360 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3361
3362 let img_seq_len = {
3363 let cfg = &cfg.vision_config;
3364 let grid_t = max_num_images / cfg.temporal_patch_size;
3365 let grid_h = max_image_shape.0 / cfg.patch_size;
3366 let grid_w = max_image_shape.1 / cfg.patch_size;
3367 grid_t * grid_h * grid_w
3368 };
3369 let img_seq_len = img_seq_len * max_num_images;
3370
3371 let max_text_attn = {
3372 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3374 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3375 };
3376
3377 Ok(max_text_attn)
3378 }
3379
3380 fn non_mapped_max_act_size_elems(
3381 &self,
3382 config: &str,
3383 params: &AutoDeviceMapParams,
3384 ) -> Result<usize> {
3385 let AutoDeviceMapParams::Vision {
3386 max_seq_len: _,
3387 max_batch_size,
3388 max_image_shape,
3389 max_num_images,
3390 } = params
3391 else {
3392 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3393 };
3394
3395 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3396
3397 let img_seq_len = {
3398 let cfg = &cfg.vision_config;
3399 let grid_t = max_num_images / cfg.temporal_patch_size;
3400 let grid_h = max_image_shape.0 / cfg.patch_size;
3401 let grid_w = max_image_shape.1 / cfg.patch_size;
3402 grid_t * grid_h * grid_w
3403 };
3404
3405 let max_vision_attn = {
3406 let cfg = &cfg.vision_config;
3407 (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
3408 };
3409
3410 Ok(max_vision_attn)
3411 }
3412
3413 fn non_mapped_size_in_bytes(
3414 &self,
3415 config: &str,
3416 dtype: DType,
3417 weight_pack_factor: usize,
3418 _matformer_config: Option<&MatformerSliceConfig>,
3419 ) -> Result<usize> {
3420 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3421 let text_elems = {
3422 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3423 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3425 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3426 } else {
3427 0
3428 };
3429 let norm = cfg.hidden_size;
3430 embed_tokens + lm_head + norm
3431 };
3432
3433 let patch_merger = {
3434 let cfg = &cfg.vision_config;
3435 let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
3436
3437 let mlp0 = hidden_size * hidden_size + hidden_size;
3438 let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
3439
3440 let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3441
3442 mlp0 + mlp2 + ln_q
3443 };
3444
3445 let patch_embed = {
3446 let cfg = &cfg.vision_config;
3447 let conv_cfg = Conv3dConfig {
3448 stride: cfg.patch_size,
3449 ..Default::default()
3450 };
3451 let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
3452 cfg.in_chans * cfg.hidden_size / conv_cfg.groups
3453 * kernel_sizes[0]
3454 * kernel_sizes[1]
3455 * kernel_sizes[2]
3456 };
3457
3458 let encoder_layer = {
3459 let cfg = &cfg.vision_config;
3460 let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3461 let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3462
3464 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3465 let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
3466
3467 let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
3468 let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3469
3470 norm1 + norm2 + fc1 + fc2 + qkv + out
3471 };
3472
3473 let elems =
3474 text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
3475
3476 Ok(elems * dtype.size_in_bytes())
3477 }
3478
3479 fn layer_sizes_in_bytes(
3480 &self,
3481 config: &str,
3482 dtype: DType,
3483 weight_pack_factor: usize,
3484 _matformer_config: Option<&MatformerSliceConfig>,
3485 ) -> Result<Vec<usize>> {
3486 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3487 let per_layer_elems = {
3488 let input_layernorm = cfg.hidden_size;
3489 let post_attention_layernorm = cfg.hidden_size;
3490
3491 let size_in = cfg.hidden_size;
3492 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3493 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3494 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3495 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3496 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3497 let o_proj = size_q * size_in / weight_pack_factor;
3498
3499 let h_size = cfg.hidden_size;
3500 let i_size = cfg.intermediate_size;
3501 let gate_proj = h_size * i_size / weight_pack_factor;
3502 let up_proj = h_size * i_size / weight_pack_factor;
3503 let down_proj = i_size * h_size / weight_pack_factor;
3504
3505 input_layernorm
3506 + post_attention_layernorm
3507 + q_proj
3508 + k_proj
3509 + v_proj
3510 + o_proj
3511 + gate_proj
3512 + up_proj
3513 + down_proj
3514 };
3515 Ok(vec![
3516 per_layer_elems * dtype.size_in_bytes();
3517 cfg.num_hidden_layers
3518 ])
3519 }
3520
3521 fn num_layers(&self, config: &str) -> Result<usize> {
3522 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3523 Ok(cfg.num_hidden_layers)
3524 }
3525
3526 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3527 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3528
3529 let cfg = ModelConfigMetadata {
3530 max_seq_len: cfg.max_position_embeddings,
3531 num_layers: cfg.num_hidden_layers,
3532 hidden_size: cfg.hidden_size,
3533 num_kv_heads: cfg.num_key_value_heads,
3534 num_attn_heads: cfg.num_attention_heads,
3535 sliding_window: cfg.sliding_window,
3536 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3537 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3538 };
3539
3540 Ok(Box::new(cfg))
3541 }
3542
3543 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3544 Some(vec![NonMappedSubModel::Vision])
3545 }
3546}
3547
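/// Loader for Gemma 3 models, covering both text-only and vision-enabled configurations.
///
/// A sketch of the config handling this loader relies on (names taken from the match arms below):
///
/// ```ignore
/// let cfg: Gemma3Config = serde_json::from_str(&config_json)?;
/// match cfg {
///     Gemma3Config::Text(_) => { /* no vision tower */ }
///     Gemma3Config::WithVision { .. } => { /* text + vision */ }
/// }
/// ```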
3548pub struct Gemma3Loader;
3554
3555pub struct Gemma3Prefixer;
3556
3557impl MultimodalPromptPrefixer for Gemma3Prefixer {
3558 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
3559 prompt.to_string()
3560 }
3561}
3562
3563impl VisionModelLoader for Gemma3Loader {
3564 fn load(
3565 &self,
3566 config: &str,
3567 vb: ShardedVarBuilder,
3568 normal_loading_metadata: NormalLoadingMetadata,
3569 attention_mechanism: AttentionImplementation,
3570 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3571 let cfg: Gemma3Config = serde_json::from_str(config)?;
3572 Ok(Box::new(Gemma3Model::new(
3573 &cfg,
3574 vb,
3575 self.is_gptx(config),
3576 normal_loading_metadata,
3577 attention_mechanism,
3578 )?))
3579 }
3580 fn is_gptx(&self, _config: &str) -> bool {
3581 true
3582 }
3583 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3584 let config: Gemma3Config = serde_json::from_str(config)?;
3585 Ok(Box::new(config))
3586 }
3587 fn get_processor(
3588 &self,
3589 config: &str,
3590 processor_config: Option<ProcessorConfig>,
3591 _preprocessor_config: PreProcessorConfig,
3592 _max_edge: Option<u32>,
3593 ) -> Arc<dyn Processor + Send + Sync> {
3594 let config: Gemma3Config = serde_json::from_str(config).unwrap();
3595 Arc::new(Gemma3Processor::new(
3597 processor_config.unwrap_or_default(),
3598 matches!(config, Gemma3Config::WithVision { .. }),
3599 ))
3600 }
3601 fn supports_paged_attention(&self, _config: &str) -> bool {
3602 true
3603 }
3604 fn supports_prefix_cacher(&self, _config: &str) -> bool {
3605 true
3606 }
3607 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3608 Arc::new(Gemma3Prefixer)
3609 }
3610 fn modalities(&self, _config: &str) -> Result<Modalities> {
3611 Ok(Modalities {
3612 input: vec![SupportedModality::Text, SupportedModality::Vision],
3613 output: vec![SupportedModality::Text],
3614 })
3615 }
3616}
3617
3618impl IsqModelLoader for Gemma3Loader {
3619 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3620 Ok(vec![
3621 Regex::new(r"lm_head\.(weight|bias)$")?,
3622 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3624 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3625 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3626 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3627 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3629 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3630 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3631 ])
3632 }
3633 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
3634 Ok(vec![
3635 Regex::new(r"lm_head\.(weight|bias)$")?,
3636 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3638 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3639 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3640 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3641 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3643 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3644 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3645 ])
3646 }
3647}
3648
3649impl DeviceMappedModelLoader for Gemma3Loader {
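    // Hedged worked example for the activation estimate below: assuming a typical Gemma 3
    // vision config with image_size = 896 and patch_size = 14 (check the actual config),
    // num_patches = (896 / 14)^2 = 4096, so each image contributes 4096 + 1 tokens. The
    // dominant term is then max_batch_size * num_attention_heads * total_seq_len^2, where
    // total_seq_len = img_seq_len + min(max_seq_len, ATTENTION_CHUNK_SIZE).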
3650 fn mapped_max_act_size_elems(
3651 &self,
3652 config: &str,
3653 params: &AutoDeviceMapParams,
3654 ) -> Result<usize> {
3655 let AutoDeviceMapParams::Vision {
3656 max_seq_len,
3657 max_batch_size,
3658 max_image_shape: _,
3659 max_num_images,
3660 } = params
3661 else {
3662 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3663 };
3664
3665 let cfg: Gemma3Config = serde_json::from_str(config)?;
3666
3667 match cfg {
3668 Gemma3Config::Text(text_config) => Ok(max_batch_size
3669 * text_config.num_attention_heads
3670 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)),
3671 Gemma3Config::WithVision {
3672 text_config,
3673 vision_config,
3674 ..
3675 } => {
3676 let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3677 let img_seq_len = (num_patches + 1) * max_num_images;
3678
3679 let max_text_attn = {
3680 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3682 max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
3683 };
3684 Ok(max_text_attn)
3685 }
3686 }
3687 }
3688
3689 fn non_mapped_max_act_size_elems(
3690 &self,
3691 config: &str,
3692 params: &AutoDeviceMapParams,
3693 ) -> Result<usize> {
3694 let AutoDeviceMapParams::Vision {
3695 max_seq_len: _,
3696 max_batch_size,
3697 max_image_shape: _,
3698 max_num_images,
3699 } = params
3700 else {
3701 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3702 };
3703
3704 let cfg: Gemma3Config = serde_json::from_str(config)?;
3705
3706 match cfg {
3707 Gemma3Config::WithVision { vision_config, .. } => {
3708 let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3709 let img_seq_len = num_patches + 1;
3710
3711 let max_vision_attn = {
3712 (max_batch_size * max_num_images)
3713 * vision_config.num_attention_heads
3714 * img_seq_len
3715 * img_seq_len
3716 };
3717
3718 Ok(max_vision_attn)
3719 }
3720 Gemma3Config::Text(_) => Ok(0),
3721 }
3722 }
3723
3724 fn non_mapped_size_in_bytes(
3725 &self,
3726 config: &str,
3727 dtype: DType,
3728 weight_pack_factor: usize,
3729 _matformer_config: Option<&MatformerSliceConfig>,
3730 ) -> Result<usize> {
3731 let cfg: Gemma3Config = serde_json::from_str(config)?;
3732
3733 let text_elems = {
3734 let cfg = match &cfg {
3735 Gemma3Config::Text(cfg) => cfg,
3736 Gemma3Config::WithVision { text_config, .. } => text_config,
3737 };
3738 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3739 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3741 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3742 } else {
3743 0
3744 };
3745 let norm = cfg.hidden_size;
3746 embed_tokens + lm_head + norm
3747 };
3748
3749 let vision_transformer = if let Gemma3Config::WithVision {
3750 vision_config: cfg, ..
3751 } = &cfg
3752 {
3753 let post_layernorm = cfg.hidden_size;
3754
3755 let conv_config = Conv2dConfig {
3756 stride: cfg.patch_size,
3757 ..Default::default()
3758 };
3759 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3760 * cfg.patch_size
3761 * cfg.patch_size;
3762
3763 let num_patches_per_side = cfg.image_size / cfg.patch_size;
3764 let num_patches = num_patches_per_side.pow(2);
3765 let position_embedding = num_patches * cfg.hidden_size;
3766
3767 let layer_elems = {
3768 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3769 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3770
3771 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3772 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3773
3774 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3775 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3776 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3777 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3778
3779 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3780 };
3781
3782 post_layernorm
3783 + patch_embedding
3784 + position_embedding
3785 + layer_elems * cfg.num_hidden_layers
3786 } else {
3787 0
3788 };
3789
3790 let elems = text_elems + vision_transformer;
3791
3792 Ok(elems * dtype.size_in_bytes())
3793 }
3794
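    // Per decoder layer, the element count computed below is roughly:
    //   2 * hidden_size (the two norms)
    //   + (q + k + v + o projections) / weight_pack_factor (+ biases when attention_bias is set)
    //   + 3 * hidden_size * intermediate_size / weight_pack_factor (gate/up/down MLP)
    // and the returned Vec repeats that figure (converted to bytes) num_hidden_layers times.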
3795 fn layer_sizes_in_bytes(
3796 &self,
3797 config: &str,
3798 dtype: DType,
3799 weight_pack_factor: usize,
3800 _matformer_config: Option<&MatformerSliceConfig>,
3801 ) -> Result<Vec<usize>> {
3802 let cfg: Gemma3Config = serde_json::from_str(config)?;
3803
3804 let txt_cfg = match &cfg {
3805 Gemma3Config::Text(cfg) => cfg,
3806 Gemma3Config::WithVision { text_config, .. } => text_config,
3807 };
3808 let per_layer_elems = {
3809 let cfg = txt_cfg;
3810
3811 let input_layernorm = cfg.hidden_size;
3812 let post_attention_layernorm = cfg.hidden_size;
3813
3814 let size_in = cfg.hidden_size;
3815 let size_q = cfg.head_dim * cfg.num_attention_heads;
3816 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
3817 let q_proj =
3818 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
3819 let k_proj =
3820 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3821 let v_proj =
3822 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3823 let o_proj =
3824 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
3825
3826 let h_size = cfg.hidden_size;
3827 let i_size = cfg.intermediate_size;
3828 let gate_proj = h_size * i_size / weight_pack_factor;
3829 let up_proj = h_size * i_size / weight_pack_factor;
3830 let down_proj = i_size * h_size / weight_pack_factor;
3831
3832 input_layernorm
3833 + post_attention_layernorm
3834 + q_proj
3835 + k_proj
3836 + v_proj
3837 + o_proj
3838 + gate_proj
3839 + up_proj
3840 + down_proj
3841 };
3842 Ok(vec![
3843 per_layer_elems * dtype.size_in_bytes();
3844 txt_cfg.num_hidden_layers
3845 ])
3846 }
3847
3848 fn num_layers(&self, config: &str) -> Result<usize> {
3849 let cfg: Gemma3Config = serde_json::from_str(config)?;
3850
3851 let txt_cfg = match &cfg {
3852 Gemma3Config::Text(cfg) => cfg,
3853 Gemma3Config::WithVision { text_config, .. } => text_config,
3854 };
3855
3856 Ok(txt_cfg.num_hidden_layers)
3857 }
3858
3859 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3860 let cfg: Gemma3Config = serde_json::from_str(config)?;
3861
3862 let cfg = match &cfg {
3863 Gemma3Config::Text(cfg) => cfg,
3864 Gemma3Config::WithVision { text_config, .. } => text_config,
3865 };
3866
3867 let cfg = ModelConfigMetadata {
3868 max_seq_len: cfg.max_position_embeddings,
3869 num_layers: cfg.num_hidden_layers,
3870 hidden_size: cfg.hidden_size,
3871 num_kv_heads: cfg.num_key_value_heads,
3872 num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3875 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3876 };
3877
3878 Ok(Box::new(cfg))
3879 }
3880
3881 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3882 Some(vec![NonMappedSubModel::Vision])
3883 }
3884}
3885
3886pub struct Mistral3Loader;
3892
3893pub struct Mistral3Prefixer;
3894
3895impl MultimodalPromptPrefixer for Mistral3Prefixer {
3896 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
3897 prompt.to_string()
3898 }
3899}
3900
3901impl VisionModelLoader for Mistral3Loader {
3902 fn load(
3903 &self,
3904 config: &str,
3905 vb: ShardedVarBuilder,
3906 normal_loading_metadata: NormalLoadingMetadata,
3907 attention_mechanism: AttentionImplementation,
3908 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3909 let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
3910 Ok(Box::new(Mistral3Model::new(
3911 &cfg,
3912 vb,
3913 self.is_gptx(config),
3914 normal_loading_metadata,
3915 attention_mechanism,
3916 )?))
3917 }
3918 fn is_gptx(&self, _config: &str) -> bool {
3919 true
3920 }
3921 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3922 let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
3923 Ok(Box::new(cfg))
3924 }
3925 fn get_processor(
3926 &self,
3927 _model_config: &str,
3928 processor_config: Option<ProcessorConfig>,
3929 _preprocessor_config: PreProcessorConfig,
3930 _max_edge: Option<u32>,
3931 ) -> Arc<dyn Processor + Send + Sync> {
3932 Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
3933 }
3934 fn supports_paged_attention(&self, _config: &str) -> bool {
3935 true
3936 }
3937 fn supports_prefix_cacher(&self, _config: &str) -> bool {
3938 true
3939 }
3940 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3941 Arc::new(Mistral3Prefixer)
3942 }
3943 fn modalities(&self, _config: &str) -> Result<Modalities> {
3944 Ok(Modalities {
3945 input: vec![SupportedModality::Text, SupportedModality::Vision],
3946 output: vec![SupportedModality::Text],
3947 })
3948 }
3949}
3950
3951impl IsqModelLoader for Mistral3Loader {
3952 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3953 Ok(vec![
3954 Regex::new(r"lm_head\.(weight|bias)$")?,
3955 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3957 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3958 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3959 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3960 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3962 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3963 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3964 ])
3965 }
3966 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
3967 Ok(vec![
3968 Regex::new(r"lm_head\.(weight|bias)$")?,
3969 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3971 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3972 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3973 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3974 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3976 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3977 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3978 ])
3979 }
3980}
3981
3982#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
3983impl DeviceMappedModelLoader for Mistral3Loader {
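    // Hedged worked example for the image-token math below: assuming patch_size = 16 (a
    // common Pixtral value; check the actual vision config), a 1540x1540 input needs no
    // downscale, rounds up to 1552x1552 (97 patches per side), and yields
    // (97 + 1) * 97 = 9506 image tokens, the "+ 1" presumably being the per-row break token.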
3984 fn mapped_max_act_size_elems(
3985 &self,
3986 config: &str,
3987 params: &AutoDeviceMapParams,
3988 ) -> Result<usize> {
3989 let cfg: Mistral3Config = serde_json::from_str(config)?;
3990 let vcfg = &cfg.vision_config;
3991 let tcfg = &cfg.text_config;
3992
3993 let AutoDeviceMapParams::Vision {
3994 max_seq_len,
3995 max_batch_size,
3996 max_image_shape: (mut height, mut width),
3997 max_num_images,
3998 } = params
3999 else {
4000 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4001 };
4002
4003 let img_seq_len = {
4004 let (max_height, max_width) = (1540, 1540);
4008 let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4009 if ratio > 1. {
4010 height = (height as f64 / ratio).floor() as usize;
4011 width = (width as f64 / ratio).floor() as usize;
4012 }
4013
4014 let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
4015 let num_width_tokens = (width - 1) / vcfg.patch_size + 1;
4016
4017 height = num_height_tokens * vcfg.patch_size;
4018 width = num_width_tokens * vcfg.patch_size;
4019
4020 let num_height_tokens = height / vcfg.patch_size;
4021 let num_width_tokens = width / vcfg.patch_size;
4022
4023 (num_width_tokens + 1) * num_height_tokens
4024 };
4025
4026 let max_seq_len = img_seq_len * max_num_images + *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
4028 Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
4029 }
4030
4031 fn non_mapped_max_act_size_elems(
4032 &self,
4033 config: &str,
4034 params: &AutoDeviceMapParams,
4035 ) -> Result<usize> {
4036 let cfg: Mistral3Config = serde_json::from_str(config)?;
4037 let cfg = &cfg.vision_config;
4038
4039 let AutoDeviceMapParams::Vision {
4040 max_seq_len: _,
4041 max_batch_size,
4042 max_image_shape: (mut height, mut width),
4043 max_num_images,
4044 } = params
4045 else {
4046 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4047 };
4048
4049 let img_seq_len = {
4050 let (max_height, max_width) = (1540, 1540);
4054 let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4055 if ratio > 1. {
4056 height = (height as f64 / ratio).floor() as usize;
4057 width = (width as f64 / ratio).floor() as usize;
4058 }
4059
4060 let num_height_tokens = (height - 1) / cfg.patch_size + 1;
4061 let num_width_tokens = (width - 1) / cfg.patch_size + 1;
4062
4063 height = num_height_tokens * cfg.patch_size;
4064 width = num_width_tokens * cfg.patch_size;
4065
4066 let num_height_tokens = height / cfg.patch_size;
4067 let num_width_tokens = width / cfg.patch_size;
4068
4069 (num_width_tokens + 1) * num_height_tokens
4070 };
4071
4072 Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
4073 }
4074
4075 fn non_mapped_size_in_bytes(
4076 &self,
4077 config: &str,
4078 dtype: DType,
4079 weight_pack_factor: usize,
4080 _matformer_config: Option<&MatformerSliceConfig>,
4081 ) -> Result<usize> {
4082 let cfg: Mistral3Config = serde_json::from_str(config)?;
4083
4084 let text_elems = {
4085 let cfg = &cfg.text_config;
4086
4087 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4088 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4090 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4091 } else {
4092 0
4093 };
4094 let norm = cfg.hidden_size;
4095 embed_tokens + lm_head + norm
4096 };
4097
4098 let vision_elems = {
4099 let cfg = &cfg.vision_config;
4100
4101 let patch_embed = {
4102 let conv_cfg = Conv2dConfig {
4103 stride: cfg.patch_size,
4104 ..Default::default()
4105 };
                // Conv2d patch embedding: out_channels * (in_channels / groups) * patch_size^2.
                cfg.num_channels * cfg.hidden_size / conv_cfg.groups
                    * cfg.patch_size
                    * cfg.patch_size
4110 };
4111 let ln_pre = cfg.hidden_size;
4112 let vision_layer = {
4113 let attn_norm = cfg.hidden_size;
4114 let ffn_norm = cfg.hidden_size;
4115
4116 let gate = cfg.hidden_size * cfg.intermediate_size;
4117 let up = cfg.hidden_size * cfg.intermediate_size;
4118 let down = cfg.hidden_size * cfg.intermediate_size;
4119
4120 let q = cfg.hidden_size * cfg.hidden_size;
4121 let k = cfg.hidden_size * cfg.hidden_size;
4122 let v = cfg.hidden_size * cfg.hidden_size;
4123 let o = cfg.hidden_size * cfg.hidden_size;
4124
4125 attn_norm + ffn_norm + gate + up + down + q + k + v + o
4126 };
4127
4128 patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
4129 };
4130
4131 let elems = text_elems + vision_elems;
4132
4133 Ok(elems * dtype.size_in_bytes())
4134 }
4135
4136 fn layer_sizes_in_bytes(
4137 &self,
4138 config: &str,
4139 dtype: DType,
4140 weight_pack_factor: usize,
4141 _matformer_config: Option<&MatformerSliceConfig>,
4142 ) -> Result<Vec<usize>> {
4143 let cfg: Mistral3Config = serde_json::from_str(config)?;
4144 let cfg = &cfg.text_config;
4145
4146 let per_layer_elems = {
4147 let input_layernorm = cfg.hidden_size;
4148 let post_attention_layernorm = cfg.hidden_size;
4149
4150 let size_in = cfg.hidden_size;
4151 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
4152 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
4153 let q_proj = size_in * size_q / weight_pack_factor;
4154 let k_proj = size_in * size_kv / weight_pack_factor;
4155 let v_proj = size_in * size_kv / weight_pack_factor;
4156 let o_proj = size_q * size_in / weight_pack_factor;
4157
4158 let h_size = cfg.hidden_size;
4159 let i_size = cfg.intermediate_size;
4160 let gate_proj = h_size * i_size / weight_pack_factor;
4161 let up_proj = h_size * i_size / weight_pack_factor;
4162 let down_proj = i_size * h_size / weight_pack_factor;
4163
4164 input_layernorm
4165 + post_attention_layernorm
4166 + q_proj
4167 + k_proj
4168 + v_proj
4169 + o_proj
4170 + gate_proj
4171 + up_proj
4172 + down_proj
4173 };
4174 Ok(vec![
4175 per_layer_elems * dtype.size_in_bytes();
4176 cfg.num_hidden_layers
4177 ])
4178 }
4179
4180 fn num_layers(&self, config: &str) -> Result<usize> {
4181 let cfg: Mistral3Config = serde_json::from_str(config)?;
4182 let cfg = &cfg.text_config;
4183 Ok(cfg.num_hidden_layers)
4184 }
4185
4186 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4187 let cfg: Mistral3Config = serde_json::from_str(config)?;
4188 let cfg = &cfg.text_config;
4189
4190 let cfg = ModelConfigMetadata {
4191 max_seq_len: cfg.max_position_embeddings,
4192 num_layers: cfg.num_hidden_layers,
4193 hidden_size: cfg.hidden_size,
4194 num_kv_heads: cfg.num_key_value_heads,
4195 num_attn_heads: cfg.num_attention_heads,
4196 sliding_window: cfg.sliding_window,
4197 k_head_dim: cfg.head_dim(),
4198 v_head_dim: cfg.head_dim(),
4199 };
4200
4201 Ok(Box::new(cfg))
4202 }
4203
4204 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4205 Some(vec![NonMappedSubModel::Vision])
4206 }
4207}
4208
4209pub struct VLlama4Loader;
4215
4216pub struct VLlama4Prefixer;
4217
4218impl MultimodalPromptPrefixer for VLlama4Prefixer {
4219 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
4220 format!(
4221 "{}{prompt}",
4222 llama4::IMAGE_TOKEN.repeat(image_indexes.len())
4223 )
4224 }
4225}
4226
4227impl VisionModelLoader for VLlama4Loader {
4228 fn load(
4229 &self,
4230 config: &str,
4231 vb: ShardedVarBuilder,
4232 normal_loading_metadata: NormalLoadingMetadata,
4233 attention_mechanism: AttentionImplementation,
4234 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
4235 let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4236 Ok(Box::new(Llama4Model::new(
4237 &cfg,
4238 vb,
4239 self.is_gptx(config),
4240 normal_loading_metadata,
4241 attention_mechanism,
4242 )?))
4243 }
4244 fn is_gptx(&self, _config: &str) -> bool {
4245 false
4246 }
4247 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4248 let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4249 Ok(Box::new(cfg))
4250 }
4251 fn get_processor(
4252 &self,
4253 _model_config: &str,
4254 processor_config: Option<ProcessorConfig>,
4255 _preprocessor_config: PreProcessorConfig,
4256 _max_edge: Option<u32>,
4257 ) -> Arc<dyn Processor + Send + Sync> {
4258 Arc::new(Llama4Processor::new(&processor_config.unwrap()))
4259 }
4260 fn supports_paged_attention(&self, _config: &str) -> bool {
4261 true
4262 }
4263 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4264 Arc::new(VLlama4Prefixer)
4265 }
4266 fn modalities(&self, _config: &str) -> Result<Modalities> {
4267 Ok(Modalities {
4268 input: vec![SupportedModality::Text, SupportedModality::Vision],
4269 output: vec![SupportedModality::Text],
4270 })
4271 }
4272}
4273
4274impl IsqModelLoader for VLlama4Loader {
4275 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4276 Ok(vec![
4277 Regex::new(r"lm_head\.(weight|bias)$")?,
4278 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4280 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4281 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4282 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4283 Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$")?,
4285 Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$")?,
4286 Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$")?,
4287 Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$")?,
4288 Regex::new(r"layers\.(\d+)\.feed_forward\.router\.(weight|bias)$")?,
4289 Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4290 Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4291 Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4292 Regex::new(r"layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$")?,
4294 Regex::new(r"layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$")?,
4295 Regex::new(r"layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$")?,
4296 ])
4297 }
4298 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4299 Ok(vec![
4300 Regex::new(r"lm_head\.(weight|bias)$")?,
4301 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4303 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4304 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4305 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4306 Regex::new(
4308 r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_up_proj\.(weight|bias)$",
4309 )?,
4310 Regex::new(
4311 r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
4312 )?,
4313 Regex::new(
4314 r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.up_proj\.(weight|bias)$",
4315 )?,
4316 Regex::new(
4317 r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.down_proj\.(weight|bias)$",
4318 )?,
4319 Regex::new(
4320 r"language_model\.model\.layers\.(\d+)\.feed_forward\.router\.(weight|bias)$",
4321 )?,
4322 Regex::new(
4323 r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$",
4324 )?,
4325 Regex::new(
4326 r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$",
4327 )?,
4328 Regex::new(
4329 r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$",
4330 )?,
4331 Regex::new(
4333 r"language_model\.model\.layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$",
4334 )?,
4335 Regex::new(
4336 r"language_model\.model\.layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$",
4337 )?,
4338 Regex::new(
4339 r"language_model\.model\.layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$",
4340 )?,
4341 ])
4342 }
4343}
4344
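// The helper below runs the real Llama4ImageProcessor on a dummy RGB image so that the
// device-mapping estimates use the exact chunk/patch counts the processor would produce at
// runtime, instead of re-deriving that tiling math in closed form here.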
4345impl VLlama4Loader {
4346 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
4349 fn run_dummy_processing(
4350 &self,
4351 cfg: &Llama4Config,
4352 height: usize,
4353 width: usize,
4354 max_num_images: usize,
4355 max_batch_size: usize,
4356 ) -> Result<(usize, usize)> {
4357 let cfg = &cfg.vision_config;
4358
4359 let img_processor =
4360 Llama4ImageProcessor::new(Some(cfg.patch_size), Some(cfg.pixel_shuffle_ratio));
4361 let image = DynamicImage::new(width as u32, height as u32, ColorType::Rgb8);
4362 let res = img_processor.preprocess(
4363 vec![image; max_num_images],
4364 vec![],
4365 &PreProcessorConfig::default(),
4366 &Device::Cpu,
4367 (max_batch_size, max_num_images),
4368 )?;
4369
4370 let pixels_batch_size = res.pixel_values.dim(0)?;
4371 let pixels_max_batch_size = pixels_batch_size * max_batch_size;
4372
4373 let (image_h, image_w) = (
4374 res.pixel_values.dim(D::Minus2).unwrap(),
4375 res.pixel_values.dim(D::Minus1).unwrap(),
4376 );
4377 let num_patches_per_chunk = (image_h / img_processor.patch_size)
4378 * (image_w / img_processor.patch_size)
4379 / img_processor.downsample_ratio;
4380
4381 Ok((
4382 pixels_max_batch_size,
4383 num_patches_per_chunk * pixels_max_batch_size,
4384 ))
4385 }
4386}
4387
4388impl DeviceMappedModelLoader for VLlama4Loader {
4389 fn mapped_max_act_size_elems(
4390 &self,
4391 config: &str,
4392 params: &AutoDeviceMapParams,
4393 ) -> Result<usize> {
4394 let AutoDeviceMapParams::Vision {
4395 max_seq_len,
4396 max_batch_size,
4397 max_image_shape: (height, width),
4398 max_num_images,
4399 } = params
4400 else {
4401 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4402 };
4403
4404 let cfg: Llama4Config = serde_json::from_str(config)?;
4405
4406 let (_pixels_batch_size, num_text_image_toks) =
4407 self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4408
4409 let max_seq_len = max_seq_len.min(&ATTENTION_CHUNK_SIZE) + num_text_image_toks;
4410
4411 Ok(max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len)
4412 }
4413 fn non_mapped_max_act_size_elems(
4414 &self,
4415 config: &str,
4416 params: &AutoDeviceMapParams,
4417 ) -> Result<usize> {
4418 let AutoDeviceMapParams::Vision {
4419 max_seq_len: _,
4420 max_batch_size,
4421 max_image_shape: (height, width),
4422 max_num_images,
4423 } = params
4424 else {
4425 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4426 };
4427
4428 let cfg: Llama4Config = serde_json::from_str(config)?;
4429
4430 let (pixels_batch_size, _num_text_image_toks) =
4431 self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4432 let max_seq_len = cfg.vision_config.num_patches();
4433
4434 Ok((max_batch_size * pixels_batch_size)
4435 * cfg.vision_config.num_attention_heads
4436 * max_seq_len
4437 * max_seq_len)
4438 }
4439
4440 fn non_mapped_size_in_bytes(
4441 &self,
4442 config: &str,
4443 dtype: DType,
4444 weight_pack_factor: usize,
4445 _matformer_config: Option<&MatformerSliceConfig>,
4446 ) -> Result<usize> {
4447 let cfg: Llama4Config = serde_json::from_str(config)?;
4448 let tcfg = &cfg.text_config;
4449
4450 let text_elems = {
4451 let embed_tokens = tcfg.hidden_size * tcfg.vocab_size / weight_pack_factor;
4452 let lm_head = if !tcfg.tie_word_embeddings {
4453 tcfg.hidden_size * tcfg.vocab_size
4454 } else {
4455 0
4456 };
4457 let norm = tcfg.hidden_size;
4458 embed_tokens + lm_head + norm
4459 };
4460
4461 let vision_elems = {
4462 let cfg = &cfg.vision_config;
4463
4464 let num_patches = cfg.num_patches();
4465
4466 let unfold_elems =
4467 (cfg.num_channels * cfg.patch_size * cfg.patch_size) * cfg.hidden_size;
4468 let class_embeddng_elems = cfg.hidden_size;
4469 let positional_embedding_vlm_elems = num_patches * cfg.hidden_size;
4470 let layernorm_pre_elems = cfg.hidden_size;
4471 let layernorm_post_elems = cfg.hidden_size;
4472
4473 let pixel_shuffle_elems = cfg.intermediate_size * cfg.projector_input_dim
4474 / weight_pack_factor
4475 + cfg.projector_input_dim * cfg.projector_output_dim / weight_pack_factor;
4476
4477 let encoder_layer = {
4478 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
4479 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
4480
4481 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
4482 let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4483 / weight_pack_factor
4484 + cfg.num_attention_heads * head_dim;
4485 let k_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4486 / weight_pack_factor
4487 + cfg.num_attention_heads * head_dim;
4488 let v_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4489 / weight_pack_factor
4490 + cfg.num_attention_heads * head_dim;
4491 let o_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4492 / weight_pack_factor
4493 + cfg.num_attention_heads * head_dim;
4494
4495 let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
4496 + cfg.intermediate_size;
4497 let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
4498 + cfg.hidden_size;
4499
4500 input_layernorm
4501 + post_attention_layernorm
4502 + q_proj
4503 + k_proj
4504 + v_proj
4505 + o_proj
4506 + fc1
4507 + fc2
4508 };
4509
4510 unfold_elems
4511 + class_embeddng_elems
4512 + positional_embedding_vlm_elems
4513 + layernorm_post_elems
4514 + layernorm_pre_elems
4515 + pixel_shuffle_elems
4516 + encoder_layer * cfg.num_hidden_layers
4517 };
4518
4519 let elems = text_elems + vision_elems;
4520
4521 Ok(elems * dtype.size_in_bytes())
4522 }
4523
4524 fn layer_sizes_in_bytes(
4525 &self,
4526 config: &str,
4527 dtype: DType,
4528 weight_pack_factor: usize,
4529 _matformer_config: Option<&MatformerSliceConfig>,
4530 ) -> Result<Vec<usize>> {
4531 let cfg: Llama4Config = serde_json::from_str(config)?;
4532 let tcfg = &cfg.text_config;
4533
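        // Layer sizes differ per layer here: MoE layers carry num_local_experts copies of the
        // gate/up/down projections (sized by intermediate_size), while dense layers carry a
        // single MLP sized by intermediate_size_mlp, so each layer is pushed individually.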
4534 let mut per_layer_elems = Vec::new();
4535
4536 for layer_idx in 0..tcfg.num_hidden_layers {
4537 let input_layernorm = tcfg.hidden_size;
4538 let post_attention_layernorm = tcfg.hidden_size;
4539
4540 let size_in = tcfg.hidden_size;
4541 let size_q = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_attention_heads;
4542 let size_kv = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_key_value_heads;
4543 let q_proj = size_in * size_q / weight_pack_factor;
4544 let k_proj = size_in * size_kv / weight_pack_factor;
4545 let v_proj = size_in * size_kv / weight_pack_factor;
4546 let o_proj = size_q * size_in / weight_pack_factor;
4547
4548 let use_moe = tcfg.moe_layers().contains(&layer_idx);
4549 let moe_block = if use_moe {
4550 let h_size = tcfg.hidden_size;
4551 let i_size = tcfg.intermediate_size;
4552 let gate_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4553 let up_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4554 let down_proj = tcfg.num_local_experts * i_size * h_size / weight_pack_factor;
4555
4556 gate_proj + up_proj + down_proj
4557 } else {
4558 let h_size = tcfg.hidden_size;
4559 let i_size = tcfg.intermediate_size_mlp;
4560 let gate_proj = h_size * i_size / weight_pack_factor;
4561 let up_proj = h_size * i_size / weight_pack_factor;
4562 let down_proj = i_size * h_size / weight_pack_factor;
4563
4564 gate_proj + up_proj + down_proj
4565 };
4566
4567 per_layer_elems.push(
4568 input_layernorm
4569 + post_attention_layernorm
4570 + q_proj
4571 + k_proj
4572 + v_proj
4573 + o_proj
4574 + moe_block,
4575 );
4576 }
4577
4578 Ok(per_layer_elems
4579 .into_iter()
4580 .map(|x| x * dtype.size_in_bytes())
4581 .collect())
4582 }
4583
4584 fn num_layers(&self, config: &str) -> Result<usize> {
4585 let cfg: Llama4Config = serde_json::from_str(config)?;
4586 Ok(cfg.text_config.num_hidden_layers)
4587 }
4588
4589 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4590 let cfg: Llama4Config = serde_json::from_str(config)?;
4591 let cfg = &cfg.text_config;
4592
4593 let cfg = ModelConfigMetadata {
4594 max_seq_len: cfg.max_position_embeddings,
4595 num_layers: cfg.num_hidden_layers,
4596 hidden_size: cfg.hidden_size,
4597 num_kv_heads: cfg.num_attention_heads,
4598 num_attn_heads: cfg.num_attention_heads,
4599 sliding_window: None,
4600 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4601 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4602 };
4603
4604 Ok(Box::new(cfg))
4605 }
4606
4607 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4608 Some(vec![NonMappedSubModel::Vision])
4609 }
4610}
4611
4612pub struct Gemma3nLoader;
4618
4619#[allow(dead_code)]
4620pub struct Gemma3nPrefixer;
4621
4622impl MultimodalPromptPrefixer for Gemma3nPrefixer {
4623 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
4624 prompt.to_string()
4625 }
4626}
4627
4628impl VisionModelLoader for Gemma3nLoader {
4629 fn load(
4630 &self,
4631 config: &str,
4632 vb: ShardedVarBuilder,
4633 normal_loading_metadata: NormalLoadingMetadata,
4634 attention_mechanism: AttentionImplementation,
4635 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
4636 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4637 Ok(Box::new(Gemma3nModel::new(
4638 &cfg,
4639 vb,
4640 self.is_gptx(config),
4641 normal_loading_metadata,
4642 attention_mechanism,
4643 )?))
4644 }
4645 fn is_gptx(&self, _config: &str) -> bool {
4646 true
4647 }
4648 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4649 let config: Gemma3nConfig = serde_json::from_str(config)?;
4650 Ok(Box::new(config))
4651 }
4652 fn get_processor(
4653 &self,
4654 _config: &str,
4655 processor_config: Option<ProcessorConfig>,
4656 _preprocessor_config: PreProcessorConfig,
4657 _max_edge: Option<u32>,
4658 ) -> Arc<dyn Processor + Send + Sync> {
4659 Arc::new(Gemma3nProcessor::new(
4661 processor_config.unwrap_or_default(),
4662 true,
4663 ))
4664 }
4665 fn supports_paged_attention(&self, _config: &str) -> bool {
4666 false
4667 }
4668 fn supports_prefix_cacher(&self, _config: &str) -> bool {
4669 true
4670 }
4671 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4672 Arc::new(Gemma3Prefixer)
4673 }
4674 fn modalities(&self, _config: &str) -> Result<Modalities> {
4675 Ok(Modalities {
4676 input: vec![
4677 SupportedModality::Text,
4678 SupportedModality::Vision,
4679 SupportedModality::Audio,
4680 ],
4681 output: vec![SupportedModality::Text],
4682 })
4683 }
4684}
4685
4686impl IsqModelLoader for Gemma3nLoader {
4687 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4688 Ok(vec![
4689 Regex::new(r"lm_head\.(weight|bias)$")?,
4690 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4692 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4693 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4694 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4695 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4697 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4698 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4699 Regex::new(r"conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$")?,
4701 Regex::new(r"conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$")?,
4702 Regex::new(r"conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$")?,
4703 Regex::new(
4704 r"conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4705 )?,
4706 Regex::new(r"conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4707 Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$")?,
4709 Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$")?,
4710 Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$")?,
4711 Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$")?,
4712 Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$")?,
4714 Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$")?,
4715 Regex::new(r"subsample_conv_projection\.input_proj_linear\.(weight|bias)$")?,
4717 Regex::new(r"embed_vision\.embedding_projection\.(weight|bias)$")?,
4719 Regex::new(r"embed_audio\.embedding_projection\.(weight|bias)$")?,
4720 ])
4721 }
4722 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4723 Ok(vec![
4724 Regex::new(r"lm_head\.(weight|bias)$")?,
4725 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4727 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4728 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4729 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4730 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4732 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4733 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4734 Regex::new(r"model\.language_model\.per_layer_model_projection\.(weight|bias)$")?,
4736 Regex::new(r"model\.language_model\.altup_projections\.(\d+)\.(weight|bias)$")?,
4737 Regex::new(r"model\.language_model\.altup_unembed_projections\.(\d+)\.(weight|bias)$")?,
4738 Regex::new(
4740 r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$",
4741 )?,
4742 Regex::new(
4743 r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$",
4744 )?,
4745 Regex::new(
4746 r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$",
4747 )?,
4748 Regex::new(
4749 r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4750 )?,
4751 Regex::new(r"model\.audio_tower\.conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4752 Regex::new(
4754 r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$",
4755 )?,
4756 Regex::new(
4757 r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$",
4758 )?,
4759 Regex::new(
4760 r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$",
4761 )?,
4762 Regex::new(
4763 r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$",
4764 )?,
4765 Regex::new(
4767 r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$",
4768 )?,
4769 Regex::new(
4770 r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$",
4771 )?,
4772 Regex::new(
4774 r"model\.audio_tower\.subsample_conv_projection\.input_proj_linear\.(weight|bias)$",
4775 )?,
4776 Regex::new(r"model\.embed_vision\.embedding_projection\.(weight|bias)$")?,
4778 Regex::new(r"model\.embed_audio\.embedding_projection\.(weight|bias)$")?,
4779 ])
4780 }
4781}
4782
4783impl DeviceMappedModelLoader for Gemma3nLoader {
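    // Token budget sketch for the estimate below: the capped text length, plus 256 vision
    // tokens per image (the 16x16 MSFA output grid), plus the model's audio soft tokens. The
    // activation estimate is then max_batch_size * num_attention_heads * total_seq_len^2.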
4784 fn mapped_max_act_size_elems(
4785 &self,
4786 config: &str,
4787 params: &AutoDeviceMapParams,
4788 ) -> Result<usize> {
4789 let AutoDeviceMapParams::Vision {
4790 max_seq_len,
4791 max_batch_size,
4792 max_image_shape: _,
4793 max_num_images,
4794 } = params
4795 else {
4796 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4797 };
4798
4799 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4800 let text_cfg = &cfg.text_config;
4801
4802 let mut total_seq_len = *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
4806
4807 {
            // Vision: the MSFA feature map is a 16x16 spatial grid, i.e. 256 soft tokens per image.
            let msfa_spatial_size = 16;
            let vision_tokens_per_image = msfa_spatial_size * msfa_spatial_size;
            total_seq_len += vision_tokens_per_image * max_num_images;
4814 }
4815
4816 {
4818 let audio_tokens = cfg.audio_soft_tokens_per_image;
4821 total_seq_len += audio_tokens;
4822 }
4823
4824 let max_text_attn =
4826 max_batch_size * text_cfg.num_attention_heads * total_seq_len * total_seq_len;
4827
4828 Ok(max_text_attn)
4829 }
4830
4831 fn non_mapped_max_act_size_elems(
4832 &self,
4833 config: &str,
4834 params: &AutoDeviceMapParams,
4835 ) -> Result<usize> {
4836 let AutoDeviceMapParams::Vision {
4837 max_seq_len: _,
4838 max_batch_size,
4839 max_image_shape: _,
4840 max_num_images,
4841 } = params
4842 else {
4843 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4844 };
4845
4846 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4847
4848 let mut max_activation = 0;
4850
4851 {
4853 let vision_tower_act = {
                // MobileNet-v5 attention stages: assume up to 16 heads over a 24x24 feature
                // map (576 positions).
                let num_heads = 16;
                let spatial_size = 24;
                let seq_len = spatial_size * spatial_size;
4869
4870 max_batch_size * max_num_images * num_heads * seq_len * seq_len
4872 };
4873
4874 let vision_embed_act = {
                // MSFA output: 2048 channels over a 16x16 grid.
                let msfa_channels = 2048;
                let spatial_size = 16;
                let vision_features =
4880 max_batch_size * max_num_images * msfa_channels * spatial_size * spatial_size;
4881
4882 let projected = max_batch_size
4884 * max_num_images
4885 * spatial_size
4886 * spatial_size
4887 * cfg.text_config.hidden_size;
4888
4889 vision_features.max(projected)
4890 };
4891
4892 max_activation = max_activation.max(vision_tower_act).max(vision_embed_act);
4893 }
4894
4895 {
4897 let audio_cfg = &cfg.audio_config;
4898
4899 let max_audio_frames = 1280;
4904
4905 let subsample_factor: usize = audio_cfg
4906 .sscp_conv_stride_size
4907 .iter()
                .map(|stride| stride[0])
                .product();
4910 let audio_seq_after_subsample = max_audio_frames / subsample_factor;
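            // Hedged example: with sscp_conv_stride_size = [[2, 2], [2, 2]] (the usual Gemma 3n
            // audio setting; check the actual config), subsample_factor = 4, so the 1280 frames
            // assumed above reduce to 320 steps after subsampling.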
4911
4912 let audio_encoder_act = {
                // Conformer feed-forward expansion is 4x the hidden size.
                let intermediate_size = audio_cfg.hidden_size * 4;
                max_batch_size * audio_seq_after_subsample * intermediate_size
4919 };
4920
4921 let audio_attn_act = {
4923 let chunk_size = audio_cfg.conf_attention_chunk_size;
4925 let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
4926 + audio_cfg.conf_attention_context_right;
4927
4928 let num_chunks = audio_seq_after_subsample.div_ceil(chunk_size);
4930
4931 max_batch_size
4932 * audio_cfg.conf_num_attention_heads
4933 * num_chunks
4934 * chunk_size
4935 * context_size
4936 };
4937
4938 max_activation = max_activation.max(audio_encoder_act).max(audio_attn_act);
4939 }
4940
4941 Ok(max_activation)
4942 }
4943
4944 fn non_mapped_size_in_bytes(
4945 &self,
4946 config: &str,
4947 dtype: DType,
4948 weight_pack_factor: usize,
4949 matformer_config: Option<&MatformerSliceConfig>,
4950 ) -> Result<usize> {
4951 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4952
4953 let text_cfg = if let Some(matformer_cfg) = matformer_config {
4955 use crate::device_map::DummyDeviceMapper;
4956 use crate::vision_models::gemma3n::text::handle_matformer_slicing;
4957
4958 let dummy_mapper = DummyDeviceMapper {
4959 nm_device: Device::Cpu,
4960 };
4961 let (adjusted_cfg, _, _, _, _) = handle_matformer_slicing(
4962 &cfg.text_config,
4963 &Some(matformer_cfg.clone()),
4964 &dummy_mapper,
4965 )?;
4966 adjusted_cfg
4967 } else {
4968 cfg.text_config.clone()
4969 };
4970
4971 let text_cfg = &text_cfg;
4972
4973 let text_elems = {
4975 let embed_tokens = text_cfg.hidden_size * text_cfg.vocab_size;
4977 let embed_tokens_per_layer = text_cfg.num_hidden_layers
4978 * text_cfg.hidden_size_per_layer_input
4979 * text_cfg.vocab_size_per_layer_input;
4980
4981 let lm_head = if !text_cfg.tie_word_embeddings || weight_pack_factor != 1 {
4983 text_cfg.hidden_size * text_cfg.vocab_size / weight_pack_factor
4984 } else {
4985 0
4986 };
4987
4988 let norm = text_cfg.hidden_size;
4990
4991 let altup_projections =
4993 (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
4994 / weight_pack_factor;
4995 let altup_unembed_projections =
4996 (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
4997 / weight_pack_factor;
4998
4999 let per_layer_model_projection = text_cfg.num_hidden_layers
5001 * text_cfg.hidden_size
5002 * text_cfg.hidden_size_per_layer_input
5003 / weight_pack_factor;
5004 let per_layer_projection_norm = text_cfg.hidden_size;
5005
5006 embed_tokens
5007 + embed_tokens_per_layer
5008 + lm_head
5009 + norm
5010 + altup_projections
5011 + altup_unembed_projections
5012 + per_layer_model_projection
5013 + per_layer_projection_norm
5014 };
5015
5016 let vision_elems = {
5018 let vision_cfg = &cfg.vision_config;
5019 let vision_tower_elems = {
5023 use crate::vision_models::gemma3n::vision::{
5024 gemma3n_mobilenet_def, make_divisible, BlockType, INPUT_CHANNELS,
5025 MSFA_EXPANSION_RATIO, MSFA_IN_CHANNELS, MSFA_OUT_CHANNELS, STEM_KERNEL_SIZE,
5026 STEM_OUT_CHANNELS,
5027 };
5028
5029 let stem_conv =
5031 INPUT_CHANNELS * STEM_OUT_CHANNELS * STEM_KERNEL_SIZE * STEM_KERNEL_SIZE;
                let stem_norm = STEM_OUT_CHANNELS;

                let mut in_chs = STEM_OUT_CHANNELS;
5036 let mut total_elems = stem_conv + stem_norm;
5037
5038 let block_defs = gemma3n_mobilenet_def();
5040
5041 for stage_blocks in block_defs.iter() {
5042 for block_type in stage_blocks.iter() {
5043 match block_type {
5044 BlockType::EdgeResidual {
5045 out_channels,
5046 kernel_size,
5047 stride: _,
5048 expand_ratio,
5049 ..
5050 } => {
5051 #[allow(clippy::cast_precision_loss)]
5052 let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
                            // Expansion conv + norm, then pointwise-linear conv + norm.
                            total_elems += in_chs * mid_chs * kernel_size * kernel_size;
                            total_elems += mid_chs;
                            total_elems += mid_chs * out_channels;
                            total_elems += out_channels;
                            in_chs = *out_channels;
5059 }
5060 BlockType::UniversalInvertedResidual {
5061 out_channels,
5062 start_kernel_size,
5063 mid_kernel_size,
5064 stride: _,
5065 expand_ratio,
5066 ..
5067 } => {
5068 #[allow(clippy::cast_precision_loss)]
5069 let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
                            // Optional expansion conv, optional depthwise convs, then the
                            // pointwise projection, each with its per-channel norm.
                            if *expand_ratio != 1.0 {
                                total_elems += in_chs * mid_chs;
                                total_elems += mid_chs;
                            }
                            if *start_kernel_size > 0 {
                                total_elems += mid_chs * start_kernel_size * start_kernel_size;
                                total_elems += mid_chs;
                            }
                            if *mid_kernel_size > 0 {
                                total_elems += mid_chs * mid_kernel_size * mid_kernel_size;
                                total_elems += mid_chs;
                            }
                            total_elems += mid_chs * out_channels;
                            total_elems += out_channels;
                            total_elems += out_channels;
                            in_chs = *out_channels;
5087 }
5088 BlockType::MultiQueryAttention {
5089 num_heads,
5090 kv_dim,
5091 kv_stride: _,
5092 ..
5093 } => {
                            // Rough accounting for the multi-query attention block: norms,
                            // q/k/v projections, depthwise downsampling convs, and the output
                            // projection.
                            let dw_kernel_size = 3;
                            total_elems += in_chs;
                            total_elems += in_chs * num_heads * kv_dim;
                            total_elems += in_chs * kv_dim;
                            total_elems += in_chs * dw_kernel_size * dw_kernel_size;
                            total_elems += *kv_dim;
                            total_elems += 1;
                            total_elems += *kv_dim;
                            total_elems += num_heads * kv_dim * in_chs;
                            total_elems += in_chs;
                        }
5106 }
5107 }
5108 }
5109
5110 let msfa_in = MSFA_IN_CHANNELS.iter().sum::<usize>();
5112 let msfa_out = MSFA_OUT_CHANNELS;
5113 #[allow(clippy::cast_precision_loss)]
5114 let msfa_mid = make_divisible(msfa_in as f64 * MSFA_EXPANSION_RATIO, 8);
5115
                // MSFA expand/project convs with their per-channel norms, plus the final norm.
                total_elems += msfa_in * msfa_mid;
                total_elems += msfa_mid;
                total_elems += msfa_mid * msfa_out;
                total_elems += msfa_out;
                total_elems += msfa_out;

                total_elems
5124 };
5125
5126 let embed_vision_elems = {
5128 let embedding = vision_cfg.vocab_size * vision_cfg.hidden_size;
5130
5131 let hard_norm = vision_cfg.hidden_size;
5133 let soft_norm = vision_cfg.hidden_size;
5134
5135 let projection = vision_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5137
5138 let post_norm = text_cfg.hidden_size;
5140
5141 embedding + hard_norm + soft_norm + projection + post_norm
5142 };
5143
5144 vision_tower_elems + embed_vision_elems
5145 };
5146
5147 let audio_elems = {
5149 let audio_cfg = &cfg.audio_config;
5150
5151 let subsample_conv_projection_elems = {
5153 let mut conv_elems = 0;
5155
5156 let in_ch_0 = 1;
5158 let out_ch_0 = audio_cfg.sscp_conv_channel_size[0];
5159 let kernel_0 = &audio_cfg.sscp_conv_kernel_size[0];
5160 conv_elems += in_ch_0 * out_ch_0 * kernel_0[0] * kernel_0[1];
5161
5162 let in_ch_1 = out_ch_0;
5164 let out_ch_1 = audio_cfg.sscp_conv_channel_size[1];
5165 let kernel_1 = &audio_cfg.sscp_conv_kernel_size[1];
5166 conv_elems += in_ch_1 * out_ch_1 * kernel_1[0] * kernel_1[1];
5167
                let norm_0 = out_ch_0;
                let norm_1 = out_ch_1;

                // Track the frequency dimension through both strided convs to size the
                // input projection.
                let mut f_out = audio_cfg.input_feat_size;
5174 for i in 0..2 {
5175 let kernel_w = audio_cfg.sscp_conv_kernel_size[i][1];
5176 let stride_w = audio_cfg.sscp_conv_stride_size[i][1];
5177 let pad_left = 1;
5178 let pad_right = 1;
5179 f_out = (f_out + pad_left + pad_right + stride_w - kernel_w) / stride_w;
5180 }
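                // Hedged example: with input_feat_size = 128, kernel width 3, stride 2, and the
                // fixed 1 + 1 padding above, f_out goes 128 -> 64 -> 32, so the input projection
                // sees out_ch_1 * 32 features. These config values are assumptions for illustration.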
5181 let input_proj_in_features = out_ch_1 * f_out;
5182 let input_proj_linear =
5183 input_proj_in_features * audio_cfg.hidden_size / weight_pack_factor;
5184
5185 conv_elems + norm_0 + norm_1 + input_proj_linear
5186 };
5187
5188 let conformer_elems = {
5190 let mut total = 0;
5191
5192 for _ in 0..audio_cfg.conf_num_hidden_layers {
5193 let attention_elems = {
5195 let pre_attn_norm = audio_cfg.hidden_size;
5197 let post_norm = audio_cfg.hidden_size;
5198
5199 let q_proj =
5201 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5202 let k_proj =
5203 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5204 let v_proj =
5205 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5206 let post =
5207 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5208
5209 let pos_proj =
5211 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5212 let per_dim_scale =
                        audio_cfg.hidden_size / audio_cfg.conf_num_attention_heads;
                    let inv_timescales = audio_cfg.hidden_size / 2;
                    let pos_indices = audio_cfg.conf_attention_context_left
5216 + audio_cfg.conf_attention_context_right
5217 + 1;
5218
5219 let chunk_size = audio_cfg.conf_attention_chunk_size;
5221 let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
5222 + audio_cfg.conf_attention_context_right;
                    let local_causal_valid_mask = chunk_size * context_size;
                    let invalid_logits_tensor = 1;

                    pre_attn_norm
5227 + post_norm
5228 + q_proj
5229 + k_proj
5230 + v_proj
5231 + post
5232 + pos_proj
5233 + per_dim_scale
5234 + inv_timescales
5235 + pos_indices
5236 + local_causal_valid_mask
5237 + invalid_logits_tensor
5238 };
5239
5240 let ffw_elems = {
5242 let intermediate_size = audio_cfg.hidden_size * 4;
5248
5249 let ffw_start = {
5250 let pre_norm = audio_cfg.hidden_size;
5251 let layer_1 =
5252 audio_cfg.hidden_size * intermediate_size / weight_pack_factor;
5253 let layer_2 =
5254 intermediate_size * audio_cfg.hidden_size / weight_pack_factor;
5255 let post_norm = audio_cfg.hidden_size;
5256 pre_norm + layer_1 + layer_2 + post_norm
5257 };
5258
                        // ffw_layer_end mirrors ffw_layer_start.
                        let ffw_end = ffw_start;

                        ffw_start + ffw_end
5262 };
5263
5264 let lconv1d_elems = {
5266 let pre_layer_norm = audio_cfg.hidden_size;
5268 let conv_norm = audio_cfg.hidden_size;
5269
5270 let linear_start = audio_cfg.hidden_size * (audio_cfg.hidden_size * 2)
5272 / weight_pack_factor;
5273 let linear_end =
5274 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5275
5276 let depthwise = audio_cfg.hidden_size * audio_cfg.conf_conv_kernel_size;
5278
5279 pre_layer_norm + conv_norm + linear_start + linear_end + depthwise
5280 };
5281
5282 let block_norm = audio_cfg.hidden_size;
5284
5285 total += attention_elems + ffw_elems + lconv1d_elems + block_norm;
5286 }
5287
5288 total
5289 };
5290
5291 let embed_audio_elems = {
5293 let embedding = audio_cfg.vocab_size * audio_cfg.hidden_size;
5295
                let hard_embedding_norm = audio_cfg.hidden_size;
                let soft_embedding_norm = audio_cfg.hidden_size;
                let embedding_post_projection_norm = text_cfg.hidden_size;
                let embedding_projection =
5303 audio_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5304
5305 embedding
5306 + hard_embedding_norm
5307 + soft_embedding_norm
5308 + embedding_post_projection_norm
5309 + embedding_projection
5310 };
5311
5312 subsample_conv_projection_elems + conformer_elems + embed_audio_elems
5313 };
5314
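        // The vision tower is budgeted at F32 when the requested dtype is F16, presumably
        // because it is kept in higher precision at load time; other dtypes use `dtype` as-is.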
5315 let vision_dtype = if dtype == DType::F16 {
5316 DType::F32
5318 } else {
5319 dtype
5320 };
5321
5322 let total_elems = text_elems * dtype.size_in_bytes()
5323 + vision_elems * vision_dtype.size_in_bytes()
5324 + audio_elems * dtype.size_in_bytes();
5325
5326 Ok(total_elems)
5327 }
5328
5329 fn layer_sizes_in_bytes(
5330 &self,
5331 config: &str,
5332 dtype: DType,
5333 weight_pack_factor: usize,
5334 matformer_config: Option<&MatformerSliceConfig>,
5335 ) -> Result<Vec<usize>> {
5336 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5337
5338 let (text_cfg, _layer_rename_map, _layers_skipped) = if let Some(matformer_cfg) =
5340 matformer_config
5341 {
5342 use crate::device_map::DummyDeviceMapper;
5343 use crate::vision_models::gemma3n::text::handle_matformer_slicing;
5344
5345 let dummy_mapper = DummyDeviceMapper {
5346 nm_device: Device::Cpu,
5347 };
5348 let (adjusted_cfg, _, _, layer_rename_map, layers_skipped) = handle_matformer_slicing(
5349 &cfg.text_config,
5350 &Some(matformer_cfg.clone()),
5351 &dummy_mapper,
5352 )?;
5353 (adjusted_cfg, layer_rename_map, layers_skipped)
5354 } else {
5355 (cfg.text_config.clone(), None, None)
5356 };
5357
5358 let text_cfg = &text_cfg;
5359
5360 let mut layer_sizes = Vec::new();
5362
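        // Each Gemma 3n decoder layer bundles the usual attention + MLP weights with its
        // AltUp coefficients, LAUREL low-rank projections, and per-layer-input projections,
        // and the MLP intermediate size may vary per layer, so sizes are computed per index.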
5363 for layer_idx in 0..text_cfg.num_hidden_layers {
5367 let per_layer_elems = {
5368 let input_layernorm = text_cfg.hidden_size;
5370 let post_attention_layernorm = text_cfg.hidden_size;
5371 let pre_feedforward_layernorm = text_cfg.hidden_size;
5372 let post_feedforward_layernorm = text_cfg.hidden_size;
5373 let post_per_layer_input_norm = text_cfg.hidden_size;
5374
5375 let size_in = text_cfg.hidden_size;
5377 let size_q = text_cfg.num_attention_heads * text_cfg.head_dim;
5378 let size_kv = text_cfg.num_key_value_heads * text_cfg.head_dim;
5379
5380 let q_proj = size_in * size_q / weight_pack_factor;
5381 let k_proj = size_in * size_kv / weight_pack_factor;
5382 let v_proj = size_in * size_kv / weight_pack_factor;
5383 let o_proj = size_q * size_in / weight_pack_factor;
5384
5385 let q_norm = text_cfg.head_dim;
5387 let k_norm = text_cfg.head_dim;
                let v_norm = text_cfg.head_dim;

                // The MLP width can differ per layer (e.g. under Matformer slicing).
                let intermediate_size = match &text_cfg.intermediate_size {
5392 IntermediateSize::Single(size) => *size,
5393 IntermediateSize::PerLayer(sizes) => sizes[layer_idx],
5394 IntermediateSize::Matformer(sizes, _) => sizes[layer_idx],
5395 };
5396 let gate_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
5397 let up_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
5398 let down_proj = intermediate_size * text_cfg.hidden_size / weight_pack_factor;
5399
5400 let altup_elems = {
5402 let correct_output_scale = text_cfg.hidden_size;
5403 let correction_coefs = text_cfg.altup_num_inputs * text_cfg.altup_num_inputs;
5404 let prediction_coefs =
5405 text_cfg.altup_num_inputs * text_cfg.altup_num_inputs.pow(2);
5406 let modality_router = text_cfg.hidden_size * text_cfg.altup_num_inputs;
5407 let router_norm = text_cfg.hidden_size;
5408
5409 correct_output_scale
5410 + correction_coefs
5411 + prediction_coefs
5412 + modality_router
5413 + router_norm
5414 };
5415
5416 let laurel_elems = {
5418 let left = text_cfg.hidden_size * text_cfg.laurel_rank;
5419 let right = text_cfg.laurel_rank * text_cfg.hidden_size;
5420 let post_norm = text_cfg.hidden_size;
5421
5422 left + right + post_norm
5423 };
5424
5425 let per_layer_input_gate =
5427 text_cfg.hidden_size * text_cfg.hidden_size_per_layer_input;
5428 let per_layer_projection =
5429 text_cfg.hidden_size_per_layer_input * text_cfg.hidden_size;
5430
5431 input_layernorm
5432 + post_attention_layernorm
5433 + pre_feedforward_layernorm
5434 + post_feedforward_layernorm
5435 + post_per_layer_input_norm
5436 + q_proj
5437 + k_proj
5438 + v_proj
5439 + o_proj
5440 + q_norm
5441 + k_norm
5442 + v_norm
5443 + gate_proj
5444 + up_proj
5445 + down_proj
5446 + altup_elems
5447 + laurel_elems
5448 + per_layer_input_gate
5449 + per_layer_projection
5450 };
5451
5452 layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
5453 }
5454
5455 Ok(layer_sizes)
5456 }
5457
5458 fn num_layers(&self, config: &str) -> Result<usize> {
5459 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5460 Ok(cfg.text_config.num_hidden_layers)
5461 }
5462
5463 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
5464 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5465 let cfg = cfg.text_config;
5466
5467 let cfg = ModelConfigMetadata {
5468 max_seq_len: cfg.max_position_embeddings,
5469 num_layers: cfg.num_hidden_layers,
5470 hidden_size: cfg.hidden_size,
5471 num_kv_heads: cfg.num_key_value_heads,
5472 num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5475 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5476 };
5477
5478 Ok(Box::new(cfg))
5479 }
5480
5481 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
5482 Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
5483 }
5484}
5485
5486pub struct Qwen3VLLoader;
5492
5493pub struct Qwen3VLPrefixer;
5494
impl MultimodalPromptPrefixer for Qwen3VLPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            format!(
                "{}{}{}",
                Qwen3VLProcessor::VISION_START,
                Qwen3VLProcessor::IMAGE_PAD,
                Qwen3VLProcessor::VISION_END
            )
            .repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for Qwen3VLLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen3VLModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Qwen3VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen3VLProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Qwen3VLPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Qwen3VLLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Qwen3VLLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;

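        // Estimated image token count: spatio-temporal patch grid (t x h x w) per image.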
        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };
        let img_seq_len = img_seq_len * max_num_images;

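        // Peak text-attention activation: batch * heads * seq^2, where the text length is
        // capped at ATTENTION_CHUNK_SIZE and image tokens are added on top.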
        let max_text_attn = {
            let cfg = &cfg.text_config;
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

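        // Peak vision-tower attention activation across all images in the batch.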
        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
        let tie = cfg.tie_word_embeddings;
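        // Text-side weights that stay off the device map: token embeddings, the lm_head
        // (skipped only when weights are tied and not packed), and the final norm.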
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !tie || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

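        // Patch merger: ln_q norm plus a two-layer MLP over spatial_merge_size^2
        // concatenated patches.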
        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            mlp0 + mlp2 + ln_q
        };

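        // Patch embedding: 3D convolution with kernel
        // (temporal_patch_size, patch_size, patch_size) mapping in_chans -> hidden_size.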
        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

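        // Per-layer parameters of the vision transformer encoder: two norms, the MLP,
        // the fused QKV projection, and the output projection.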
        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;

            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
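        // Per-decoder-layer parameter elements for the text backbone
        // (norms, attention projections, gated MLP); identical for every layer.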
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim,
            v_head_dim: cfg.head_dim,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}