1use std::any::Any;
2use std::sync::Arc;
3use std::{fmt::Debug, str::FromStr};
4
5use anyhow::Result;
6use candle_core::{DType, Device, Tensor, D};
7use candle_nn::Conv2dConfig;
8use image::{ColorType, DynamicImage};
9use itertools::Itertools;
10use mistralrs_quant::log::once_log_info;
11use mistralrs_quant::ShardedVarBuilder;
12
13#[cfg(feature = "pyo3_macros")]
14use pyo3::pyclass;
15
16use regex::Regex;
17use serde::Deserialize;
18
19use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};
20
21use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
22use crate::amoe::AnyMoeBaseModelMixin;
23use crate::attention::ATTENTION_CHUNK_SIZE;
24use crate::device_map::DeviceMapper;
25use crate::layers::Conv3dConfig;
26use crate::matformer::MatformerSliceConfig;
27use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
28use crate::pipeline::isq::IsqModelLoader;
29use crate::pipeline::loaders::AutoDeviceMapParams;
30use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
31use crate::pipeline::{
32 EitherCache, IsqModel, Modalities, MultimodalPromptPrefixer, Processor, ProcessorCreator,
33 SupportedModality,
34};
35use crate::utils::varbuilder_utils::DeviceForLoadTensor;
36use crate::vision_models::clip::ClipConfig;
37use crate::vision_models::gemma3::config::Gemma3Config;
38use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
39use crate::vision_models::gemma3n::config::{Gemma3nConfig, IntermediateSize};
40use crate::vision_models::gemma3n::{Gemma3nModel, Gemma3nProcessor};
41use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
42use crate::vision_models::idefics2_input_processor::Idefics2Processor;
43use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
44use crate::vision_models::image_processor::ImagePreProcessor;
45use crate::vision_models::inputs_processor::Phi4MMProcessor;
46use crate::vision_models::llama4::{
47 self, Llama4Config, Llama4ImageProcessor, Llama4Model, Llama4Processor,
48};
49use crate::vision_models::llava::config::Config as LLaVAConfig;
50use crate::vision_models::llava15::Model as LLaVA;
51use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
52use crate::vision_models::llava_next::Model as LLaVANext;
53use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
54use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
55use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
56use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
57use crate::vision_models::phi3_inputs_processor::Phi3Processor;
58use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
59use crate::vision_models::preprocessor_config::PreProcessorConfig;
60use crate::vision_models::processor_config::ProcessorConfig;
61use crate::vision_models::qwen2_5_vl::{
62 Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
63};
64use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
65use crate::vision_models::qwen3_vl::{Config as Qwen3VLConfig, Qwen3VLModel, Qwen3VLProcessor};
66use crate::vision_models::qwen3_vl_moe::{
67 Config as Qwen3VLMoEConfig, Qwen3VLMoEModel, Qwen3VLMoEProcessor,
68};
69use crate::vision_models::{minicpmo, phi4};
70
71pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin {
72 #[allow(clippy::too_many_arguments)]
74 fn forward(
75 &self,
76 input_ids: &Tensor,
77 pixel_values: Option<Tensor>,
78 seqlen_offsets: &[usize],
79 context_lens: Vec<(usize, usize)>,
80 position_ids: Vec<usize>,
81 model_specific_args: Box<dyn Any>, metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
83 flash_params: &FlashParams,
84 ) -> candle_core::Result<Tensor>;
85 fn device(&self) -> &Device;
86 fn cache(&self) -> &EitherCache;
87 fn cache_mut(&mut self) -> &mut EitherCache;
88 fn max_seq_len(&self) -> usize;
89 fn config(&self) -> &ModelConfigMetadata;
90 fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
92}
93
94pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
95 fn load(
96 &self,
97 config: &str,
98 vb: ShardedVarBuilder,
99 normal_loading_metadata: NormalLoadingMetadata,
100 attention_mechanism: AttentionImplementation,
101 ) -> Result<Box<dyn VisionModel + Send + Sync>>;
102 fn is_gptx(&self, config: &str) -> bool;
103 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
104 fn get_processor(
105 &self,
106 model_config: &str,
107 processor_config: Option<ProcessorConfig>,
108 preprocessor_config: PreProcessorConfig,
109 max_edge: Option<u32>,
110 ) -> Arc<dyn Processor + Send + Sync>;
111 fn supports_paged_attention(&self, config: &str) -> bool;
112 fn supports_prefix_cacher(&self, _config: &str) -> bool {
113 false
115 }
116 fn modalities(&self, config: &str) -> Result<Modalities>;
117 fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer>;
118 fn get_device_for_tensor(
119 &self,
120 config: &str,
121 _mapper: &dyn DeviceMapper,
122 loading_isq: bool,
123 ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
124 if loading_isq {
125 Ok(Arc::new(|_| DeviceForLoadTensor::Base))
126 } else {
127 let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
128 let num_layers = self.model_config(config)?.num_layers();
129 let closure = move |name: String| {
130 if let Some(captures) = re.captures(&name) {
131 captures
132 .get(1)
133 .and_then(|m| m.as_str().parse::<usize>().ok())
134 .map(|l| l.min(num_layers))
135 .map(DeviceForLoadTensor::Idx)
136 .unwrap_or(DeviceForLoadTensor::Base)
137 } else {
138 DeviceForLoadTensor::Base
139 }
140 };
141
142 Ok(Arc::new(closure))
143 }
144 }
145}
146
147#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
148#[derive(Clone, Debug, Deserialize, PartialEq)]
149pub enum VisionLoaderType {
151 #[serde(rename = "phi3v")]
152 Phi3V,
153 #[serde(rename = "idefics2")]
154 Idefics2,
155 #[serde(rename = "llava_next")]
156 LLaVANext,
157 #[serde(rename = "llava")]
158 LLaVA,
159 #[serde(rename = "vllama")]
160 VLlama,
161 #[serde(rename = "qwen2vl")]
162 Qwen2VL,
163 #[serde(rename = "idefics3")]
164 Idefics3,
165 #[serde(rename = "minicpmo")]
166 MiniCpmO,
167 #[serde(rename = "phi4mm")]
168 Phi4MM,
169 #[serde(rename = "qwen2_5vl")]
170 Qwen2_5VL,
171 #[serde(rename = "gemma3")]
172 Gemma3,
173 #[serde(rename = "mistral3")]
174 Mistral3,
175 #[serde(rename = "llama4")]
176 Llama4,
177 #[serde(rename = "gemma3n")]
178 Gemma3n,
179 #[serde(rename = "qwen3vl")]
180 Qwen3VL,
181 #[serde(rename = "qwen3vlmoe")]
182 Qwen3VLMoE,
183}
184
185impl VisionLoaderType {
187 pub fn from_causal_lm_name(name: &str) -> Result<Self> {
188 match name {
189 "Phi3VForCausalLM" => Ok(Self::Phi3V),
190 "Idefics2ForConditionalGeneration" => Ok(Self::Idefics2),
191 "LlavaNextForConditionalGeneration" => Ok(Self::LLaVANext),
192 "LlavaForConditionalGeneration" => Ok(Self::LLaVA),
193 "MllamaForConditionalGeneration" => Ok(Self::VLlama),
194 "Qwen2VLForConditionalGeneration" => Ok(Self::Qwen2VL),
195 "Idefics3ForConditionalGeneration" => Ok(Self::Idefics3),
196 "MiniCPMO" => Ok(Self::MiniCpmO),
197 "Phi4MMForCausalLM" => Ok(Self::Phi4MM),
198 "Qwen2_5_VLForConditionalGeneration" => Ok(Self::Qwen2_5VL),
199 "Gemma3ForConditionalGeneration" | "Gemma3ForCausalLM" => Ok(Self::Gemma3),
200 "Mistral3ForConditionalGeneration" => Ok(Self::Mistral3),
201 "Llama4ForConditionalGeneration" => Ok(Self::Llama4),
202 "Gemma3nForConditionalGeneration" => Ok(Self::Gemma3n),
203 "Qwen3VLForConditionalGeneration" => Ok(Self::Qwen3VL),
204 "Qwen3VLMoeForConditionalGeneration" => Ok(Self::Qwen3VLMoE),
205 other => anyhow::bail!(
206 "Unsupported Hugging Face Transformers -CausalLM model class `{other}`. Please raise an issue."
207 ),
208 }
209 }
210}
211
212impl FromStr for VisionLoaderType {
213 type Err = String;
214 fn from_str(s: &str) -> Result<Self, Self::Err> {
215 match s {
216 "phi3v" => Ok(Self::Phi3V),
217 "idefics2" => Ok(Self::Idefics2),
218 "llava_next" => Ok(Self::LLaVANext),
219 "llava" => Ok(Self::LLaVA),
220 "vllama" => Ok(Self::VLlama),
221 "qwen2vl" => Ok(Self::Qwen2VL),
222 "idefics3" => Ok(Self::Idefics3),
223 "minicpmo" => Ok(Self::MiniCpmO),
224 "phi4mm" => Ok(Self::Phi4MM),
225 "qwen2_5vl" => Ok(Self::Qwen2_5VL),
226 "gemma3" => Ok(Self::Gemma3),
227 "mistral3" => Ok(Self::Mistral3),
228 "llama4" => Ok(Self::Llama4),
229 "gemma3n" => Ok(Self::Gemma3n),
230 "qwen3vl" => Ok(Self::Qwen3VL),
231 "qwen3vlmoe" => Ok(Self::Qwen3VLMoE),
232 a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`, `llama4`, `gemma3n`, `qwen3vl`, `qwen3vlmoe`.")),
233 }
234 }
235}
236
237impl std::fmt::Display for VisionLoaderType {
238 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239 let name = match self {
240 VisionLoaderType::Phi3V => "phi3v",
241 VisionLoaderType::Idefics2 => "idefics2",
242 VisionLoaderType::LLaVANext => "llava_next",
243 VisionLoaderType::LLaVA => "llava",
244 VisionLoaderType::VLlama => "vllama",
245 VisionLoaderType::Qwen2VL => "qwen2vl",
246 VisionLoaderType::Idefics3 => "idefics3",
247 VisionLoaderType::MiniCpmO => "minicpmo",
248 VisionLoaderType::Phi4MM => "phi4mm",
249 VisionLoaderType::Qwen2_5VL => "qwen2_5vl",
250 VisionLoaderType::Gemma3 => "gemma3",
251 VisionLoaderType::Mistral3 => "mistral3",
252 VisionLoaderType::Llama4 => "llama4",
253 VisionLoaderType::Gemma3n => "gemma3n",
254 VisionLoaderType::Qwen3VL => "qwen3vl",
255 VisionLoaderType::Qwen3VLMoE => "qwen3vlmoe",
256 };
257 write!(f, "{name}")
258 }
259}
260
261#[derive(Deserialize)]
262struct AutoVisionLoaderConfig {
263 architectures: Vec<String>,
264}
265
266pub struct AutoVisionLoader;
268
269impl AutoVisionLoader {
270 fn get_loader(config: &str) -> Result<Box<dyn VisionModelLoader>> {
271 let auto_cfg: AutoVisionLoaderConfig = serde_json::from_str(config)?;
272 if auto_cfg.architectures.len() != 1 {
273 anyhow::bail!("Expected exactly one architecture in config");
274 }
275
276 let name = &auto_cfg.architectures[0];
277 let tp = VisionLoaderType::from_causal_lm_name(name)?;
278
279 once_log_info(format!("Automatic loader type determined to be `{tp}`"));
280
281 Ok(match tp {
283 VisionLoaderType::Phi3V => Box::new(Phi3VLoader),
284 VisionLoaderType::Idefics2 => Box::new(Idefics2Loader),
285 VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader),
286 VisionLoaderType::LLaVA => Box::new(LLaVALoader),
287 VisionLoaderType::VLlama => Box::new(VLlamaLoader),
288 VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
289 VisionLoaderType::Idefics3 => Box::new(Idefics3Loader),
290 VisionLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
291 VisionLoaderType::Phi4MM => Box::new(Phi4MMLoader),
292 VisionLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
293 VisionLoaderType::Gemma3 => Box::new(Gemma3Loader),
294 VisionLoaderType::Mistral3 => Box::new(Mistral3Loader),
295 VisionLoaderType::Llama4 => Box::new(VLlama4Loader),
296 VisionLoaderType::Gemma3n => Box::new(Gemma3nLoader),
297 VisionLoaderType::Qwen3VL => Box::new(Qwen3VLLoader),
298 VisionLoaderType::Qwen3VLMoE => Box::new(Qwen3VLMoELoader),
299 })
300 }
301}
302
303impl VisionModelLoader for AutoVisionLoader {
304 fn load(
305 &self,
306 config: &str,
307 vb: ShardedVarBuilder,
308 normal_loading_metadata: NormalLoadingMetadata,
309 attention_mechanism: AttentionImplementation,
310 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
311 Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
312 }
313
314 fn is_gptx(&self, config: &str) -> bool {
315 Self::get_loader(config)
316 .expect("AutoVisionLoader get_loader")
317 .is_gptx(config)
318 }
319
320 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
321 Self::get_loader(config)?.get_config_repr(config)
322 }
323
324 fn get_processor(
325 &self,
326 model_config: &str,
327 proc_cfg: Option<ProcessorConfig>,
328 preproc_cfg: PreProcessorConfig,
329 max_edge: Option<u32>,
330 ) -> Arc<dyn Processor + Send + Sync> {
331 Self::get_loader(model_config)
332 .expect("AutoVisionLoader get_loader")
333 .get_processor(model_config, proc_cfg, preproc_cfg, max_edge)
334 }
335
336 fn supports_paged_attention(&self, config: &str) -> bool {
337 Self::get_loader(config)
338 .expect("AutoVisionLoader")
339 .supports_paged_attention(config)
340 }
341
342 fn modalities(&self, config: &str) -> Result<Modalities> {
343 Self::get_loader(config)?.modalities(config)
344 }
345
346 fn supports_prefix_cacher(&self, config: &str) -> bool {
347 Self::get_loader(config)
348 .expect("AutoVisionLoader")
349 .supports_prefix_cacher(config)
350 }
351
352 fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
353 Self::get_loader(config)
354 .expect("AutoVisionLoader")
355 .prefixer(config)
356 }
357
358 fn get_device_for_tensor(
359 &self,
360 config: &str,
361 mapper: &dyn DeviceMapper,
362 loading_isq: bool,
363 ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
364 Self::get_loader(config)?.get_device_for_tensor(config, mapper, loading_isq)
365 }
366}
367
368impl IsqModelLoader for AutoVisionLoader {
369 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
370 Self::get_loader(config)?.isq_layer_regexes(config)
371 }
372 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
373 Self::get_loader(config)?.immediate_isq_predicates(config)
374 }
375}
376
377impl DeviceMappedModelLoader for AutoVisionLoader {
378 fn mapped_max_act_size_elems(
379 &self,
380 config: &str,
381 params: &AutoDeviceMapParams,
382 ) -> Result<usize> {
383 Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
384 }
385 fn non_mapped_max_act_size_elems(
386 &self,
387 config: &str,
388 params: &AutoDeviceMapParams,
389 ) -> Result<usize> {
390 Self::get_loader(config)?.non_mapped_max_act_size_elems(config, params)
391 }
392 fn non_mapped_size_in_bytes(
393 &self,
394 config: &str,
395 dtype: DType,
396 weight_pack_factor: usize,
397 _matformer_config: Option<&MatformerSliceConfig>,
398 ) -> Result<usize> {
399 Self::get_loader(config)?.non_mapped_size_in_bytes(
400 config,
401 dtype,
402 weight_pack_factor,
403 _matformer_config,
404 )
405 }
406 fn layer_sizes_in_bytes(
407 &self,
408 config: &str,
409 dtype: DType,
410 weight_pack_factor: usize,
411 _matformer_config: Option<&MatformerSliceConfig>,
412 ) -> Result<Vec<usize>> {
413 Self::get_loader(config)?.layer_sizes_in_bytes(
414 config,
415 dtype,
416 weight_pack_factor,
417 _matformer_config,
418 )
419 }
420 fn num_layers(&self, config: &str) -> Result<usize> {
421 Self::get_loader(config)?.num_layers(config)
422 }
423 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
424 Self::get_loader(config)?.model_config(config)
425 }
426}
427
428macro_rules! bias_if {
429 ($cond:expr, $size:expr) => {
430 if $cond {
431 $size
432 } else {
433 0
434 }
435 };
436}
437
438fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
439 let pre_layer_norm = cfg.hidden_size;
440 let final_layer_norm = cfg.hidden_size;
441
442 let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
443 let num_positions = num_patches + 1;
444
445 let class_embedding = cfg.hidden_size;
446
447 let position_ids = num_positions;
448 let position_embedding = num_positions * cfg.hidden_size;
449
450 let conv2dconfig = Conv2dConfig {
451 stride: cfg.patch_size,
452 ..Default::default()
453 };
454 let patch_embedding =
455 cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;
456
457 let encoder_layer_elems = {
458 let layer_norm1 = cfg.hidden_size;
459 let layer_norm2 = cfg.hidden_size;
460
461 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
462 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
463 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
464 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
465
466 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
467 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
468
469 layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
470 };
471
472 pre_layer_norm
473 + final_layer_norm
474 + class_embedding
475 + position_ids
476 + position_embedding
477 + patch_embedding
478 + cfg.num_hidden_layers * encoder_layer_elems
479}
480
481pub struct Phi3VLoader;
487
488pub struct Phi3VPrefixer;
489
490impl MultimodalPromptPrefixer for Phi3VPrefixer {
491 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
492 format!(
494 "{}{prompt}",
495 image_indexes
496 .into_iter()
497 .map(|image_index| format!("<|image_{}|>", image_index + 1))
498 .join("")
499 )
500 }
501}
502
503impl VisionModelLoader for Phi3VLoader {
504 fn load(
505 &self,
506 config: &str,
507 vb: ShardedVarBuilder,
508 normal_loading_metadata: NormalLoadingMetadata,
509 attention_mechanism: AttentionImplementation,
510 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
511 let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
512 Ok(Box::new(Phi3::new(
513 &cfg,
514 vb,
515 self.is_gptx(config),
516 normal_loading_metadata,
517 attention_mechanism,
518 )?))
519 }
520 fn is_gptx(&self, _config: &str) -> bool {
521 true
522 }
523 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
524 let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
525 Ok(Box::new(cfg))
526 }
527 fn get_processor(
528 &self,
529 _model_config: &str,
530 processor_config: Option<ProcessorConfig>,
531 preprocessor_config: PreProcessorConfig,
532 _max_edge: Option<u32>,
533 ) -> Arc<dyn Processor + Send + Sync> {
534 Phi3Processor::new_processor(processor_config, preprocessor_config)
535 }
536 fn supports_paged_attention(&self, _config: &str) -> bool {
537 true
538 }
539 fn supports_prefix_cacher(&self, _config: &str) -> bool {
540 true
541 }
542 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
543 Arc::new(Phi3VPrefixer)
544 }
545 fn modalities(&self, _config: &str) -> Result<Modalities> {
546 Ok(Modalities {
547 input: vec![SupportedModality::Text, SupportedModality::Vision],
548 output: vec![SupportedModality::Text],
549 })
550 }
551}
552
553impl IsqModelLoader for Phi3VLoader {
554 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
555 Ok(vec![
556 Regex::new(r"lm_head\.(weight|bias)$")?,
557 Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
559 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
560 Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
562 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
563 ])
564 }
565 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
566 self.isq_layer_regexes(config)
567 }
568}
569
570impl DeviceMappedModelLoader for Phi3VLoader {
571 fn mapped_max_act_size_elems(
572 &self,
573 config: &str,
574 params: &AutoDeviceMapParams,
575 ) -> Result<usize> {
576 let AutoDeviceMapParams::Vision {
578 max_seq_len,
579 max_batch_size,
580 max_image_shape: _,
581 max_num_images,
582 } = params
583 else {
584 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
585 };
586
587 let cfg: Phi3Config = serde_json::from_str(config)?;
588
589 let vcfg = &PHI3V_CLIP_CONFIG;
590
591 let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
592 let img_seq_len = (num_patches + 1) * max_num_images;
593
594 let max_text_attn = {
595 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
597 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
598 };
599
600 Ok(max_text_attn)
601 }
602
603 fn non_mapped_max_act_size_elems(
604 &self,
605 config: &str,
606 params: &AutoDeviceMapParams,
607 ) -> Result<usize> {
608 let AutoDeviceMapParams::Vision {
610 max_seq_len: _,
611 max_batch_size,
612 max_image_shape: _,
613 max_num_images,
614 } = params
615 else {
616 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
617 };
618
619 let cfg: Phi3Config = serde_json::from_str(config)?;
620
621 let vcfg = &PHI3V_CLIP_CONFIG;
622
623 let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
624 let img_seq_len = num_patches + 1;
625
626 let max_vision_attn = {
627 (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
628 };
629
630 Ok(max_vision_attn)
631 }
632
633 fn non_mapped_size_in_bytes(
634 &self,
635 config: &str,
636 dtype: DType,
637 weight_pack_factor: usize,
638 _matformer_config: Option<&MatformerSliceConfig>,
639 ) -> Result<usize> {
640 let cfg: Phi3Config = serde_json::from_str(config)?;
641 let elems = {
642 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
643 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
645 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
646 } else {
647 0
648 };
649 let norm = cfg.hidden_size;
650
651 let image_embed = {
652 let projection_cls = cfg
653 .embd_layer
654 .projection_cls
655 .clone()
656 .unwrap_or("linear".to_string());
657 let with_learnable_separator =
658 cfg.embd_layer.with_learnable_separator.unwrap_or(false);
659 let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
660 let image_dim_out = cfg.img_processor.image_dim_out;
661
662 let proj = match (projection_cls.as_str(), use_hd_transform) {
663 ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
664 ("mlp", true) => {
665 let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
666 let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
667 a + b
668 }
669 ("mlp", false) => {
670 let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
671 let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
672 a + b
673 }
674 _ => {
675 anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
676 }
677 };
678
679 let (glb_gn, sub_gn) = if with_learnable_separator {
680 let glb_gn = image_dim_out * 4;
681 let sub_gn = image_dim_out * 4;
682 (glb_gn, sub_gn)
683 } else {
684 (0, 0)
685 };
686
687 let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);
688
689 proj + glb_gn + sub_gn + clip_vit
690 };
691
692 embed_tokens + lm_head + norm + image_embed
693 };
694
695 Ok(elems * dtype.size_in_bytes())
696 }
697
698 fn layer_sizes_in_bytes(
699 &self,
700 config: &str,
701 dtype: DType,
702 weight_pack_factor: usize,
703 _matformer_config: Option<&MatformerSliceConfig>,
704 ) -> Result<Vec<usize>> {
705 let cfg: Phi3Config = serde_json::from_str(config)?;
706 let per_layer_elems = {
707 let input_layernorm = cfg.hidden_size;
708 let post_attention_layernorm = cfg.hidden_size;
709
710 let size_in = cfg.hidden_size;
711 let head_dim = cfg.head_dim();
712 let op_size =
713 cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
714 let qkv_proj = size_in * op_size / weight_pack_factor;
715 let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;
716
717 let h_size = cfg.hidden_size;
718 let i_size = cfg.intermediate_size;
719 let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
720 let down_proj = h_size * i_size / weight_pack_factor;
721
722 input_layernorm
723 + post_attention_layernorm
724 + qkv_proj
725 + o_proj
726 + gate_up_proj
727 + down_proj
728 };
729 Ok(vec![
730 per_layer_elems * dtype.size_in_bytes();
731 cfg.num_hidden_layers
732 ])
733 }
734
735 fn num_layers(&self, config: &str) -> Result<usize> {
736 let cfg: Phi3Config = serde_json::from_str(config)?;
737 Ok(cfg.num_hidden_layers)
738 }
739
740 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
741 let cfg: Phi3Config = serde_json::from_str(config)?;
742
743 let cfg = ModelConfigMetadata {
744 max_seq_len: cfg.max_position_embeddings,
745 num_layers: cfg.num_hidden_layers,
746 hidden_size: cfg.hidden_size,
747 num_kv_heads: cfg.num_key_value_heads,
748 num_attn_heads: cfg.num_attention_heads,
749 sliding_window: cfg.sliding_window,
750 k_head_dim: cfg.head_dim(),
751 v_head_dim: cfg.head_dim(),
752 };
753
754 Ok(Box::new(cfg))
755 }
756
757 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
758 Some(vec![NonMappedSubModel::Vision])
759 }
760}
761
762pub struct Idefics2Loader;
768
769pub struct Idefics2Prefixer;
770
771impl MultimodalPromptPrefixer for Idefics2Prefixer {
772 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
773 prompt.to_string()
775 }
776}
777
778impl VisionModelLoader for Idefics2Loader {
779 fn load(
780 &self,
781 config: &str,
782 vb: ShardedVarBuilder,
783 normal_loading_metadata: NormalLoadingMetadata,
784 attention_mechanism: AttentionImplementation,
785 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
786 let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
787 Ok(Box::new(Idefics2::new(
788 &cfg,
789 vb,
790 self.is_gptx(config),
791 normal_loading_metadata,
792 attention_mechanism,
793 )?))
794 }
795 fn is_gptx(&self, _config: &str) -> bool {
796 true
797 }
798 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
799 let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
800 Ok(Box::new(cfg))
801 }
802 fn get_processor(
803 &self,
804 _model_config: &str,
805 processor_config: Option<ProcessorConfig>,
806 preprocessor_config: PreProcessorConfig,
807 max_edge: Option<u32>,
808 ) -> Arc<dyn Processor + Send + Sync> {
809 Arc::new(Idefics2Processor::new(
810 processor_config.unwrap(),
811 preprocessor_config,
812 max_edge,
813 ))
814 }
815 fn supports_paged_attention(&self, _config: &str) -> bool {
816 true
817 }
818 fn supports_prefix_cacher(&self, _config: &str) -> bool {
819 true
820 }
821 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
822 Arc::new(Idefics2Prefixer)
823 }
824 fn modalities(&self, _config: &str) -> Result<Modalities> {
825 Ok(Modalities {
826 input: vec![SupportedModality::Text, SupportedModality::Vision],
827 output: vec![SupportedModality::Text],
828 })
829 }
830}
831
832impl IsqModelLoader for Idefics2Loader {
833 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
834 Ok(vec![
835 Regex::new(r"lm_head\.(weight|bias)$")?,
836 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
838 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
839 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
840 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
841 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
843 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
844 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
845 ])
846 }
847 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
848 Ok(vec![
849 Regex::new(r"lm_head\.(weight|bias)$")?,
850 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
852 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
853 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
854 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
855 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
857 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
858 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
859 ])
860 }
861}
862
863impl DeviceMappedModelLoader for Idefics2Loader {
864 fn mapped_max_act_size_elems(
865 &self,
866 config: &str,
867 params: &AutoDeviceMapParams,
868 ) -> Result<usize> {
869 let AutoDeviceMapParams::Vision {
870 max_seq_len,
871 max_batch_size,
872 max_image_shape: _,
873 max_num_images,
874 } = params
875 else {
876 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
877 };
878
879 let cfg: Idefics2Config = serde_json::from_str(config)?;
880
881 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
882 let img_seq_len = (num_patches + 1) * max_num_images;
883
884 let max_text_attn = {
885 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
887 max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
888 };
889
890 Ok(max_text_attn)
891 }
892
893 fn non_mapped_max_act_size_elems(
894 &self,
895 config: &str,
896 params: &AutoDeviceMapParams,
897 ) -> Result<usize> {
898 let AutoDeviceMapParams::Vision {
899 max_seq_len: _,
900 max_batch_size,
901 max_image_shape: _,
902 max_num_images,
903 } = params
904 else {
905 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
906 };
907
908 let cfg: Idefics2Config = serde_json::from_str(config)?;
909
910 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
911 let img_seq_len = num_patches + 1;
912
913 let max_vision_attn = {
914 let images_factor = 5;
916
917 (max_batch_size * images_factor * max_num_images)
918 * cfg.vision_config.num_attention_heads
919 * img_seq_len
920 * img_seq_len
921 };
922
923 Ok(max_vision_attn)
924 }
925
926 fn non_mapped_size_in_bytes(
927 &self,
928 config: &str,
929 dtype: DType,
930 weight_pack_factor: usize,
931 _matformer_config: Option<&MatformerSliceConfig>,
932 ) -> Result<usize> {
933 let cfg: Idefics2Config = serde_json::from_str(config)?;
934 let text_elems = {
935 let tie_word_embeddings = cfg.tie_word_embeddings;
936 let cfg = &cfg.text_config;
937
938 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
939 let lm_head = if !tie_word_embeddings {
940 cfg.hidden_size * cfg.vocab_size
941 } else {
942 0
943 };
944 let norm = cfg.hidden_size;
945 embed_tokens + lm_head + norm
946 };
947
948 let connector_elems = {
949 let tcfg = &cfg.text_config;
950 let vcfg = &cfg.vision_config;
951 let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
952 let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
953 let down_proj = tcfg.intermediate_size * tcfg.hidden_size;
954
955 let perceiver_elems = {
956 let tcfg = &cfg.text_config;
957 let pcfg = &cfg.perceiver_config;
958
959 let n_latents = pcfg.resampler_n_latents;
960 let hidden_size = tcfg.hidden_size;
961 let depth = pcfg.resampler_depth;
962
963 let norm = tcfg.hidden_size;
964 let latents = n_latents * hidden_size;
965
966 let layer_elems = {
967 let input_latents_norm = hidden_size;
968 let input_context_norm = hidden_size;
969 let post_attn_norm = hidden_size;
970
971 let num_heads = pcfg.resampler_n_heads;
972 let head_dim = pcfg.resampler_head_dim;
973 let num_key_value_heads = pcfg.num_key_value_heads;
974
975 let q_proj = hidden_size * num_heads * head_dim;
976 let k_proj = hidden_size * num_key_value_heads * head_dim;
977 let v_proj = hidden_size * num_key_value_heads * head_dim;
978 let o_proj = num_heads * head_dim * hidden_size;
979
980 let gate_proj = hidden_size * hidden_size * 4;
981 let up_proj = hidden_size * hidden_size * 4;
982 let down_proj = hidden_size * 4 * hidden_size;
983
984 input_latents_norm
985 + input_context_norm
986 + post_attn_norm
987 + q_proj
988 + k_proj
989 + v_proj
990 + o_proj
991 + gate_proj
992 + up_proj
993 + down_proj
994 };
995
996 norm + latents + layer_elems * depth
997 };
998
999 gate_proj + up_proj + down_proj + perceiver_elems
1000 };
1001
1002 let vision_transformer = {
1003 let cfg = &cfg.vision_config;
1004
1005 let post_layernorm = cfg.hidden_size;
1006
1007 let conv_config = Conv2dConfig {
1008 stride: cfg.patch_size,
1009 ..Default::default()
1010 };
1011 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
1012 * cfg.patch_size
1013 * cfg.patch_size;
1014
1015 let num_patches_per_side = cfg.image_size / cfg.patch_size;
1016 let num_patches = num_patches_per_side.pow(2);
1017 let position_embedding = num_patches * cfg.hidden_size;
1018
1019 let layer_elems = {
1020 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
1021 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
1022
1023 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
1024 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
1025
1026 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1027 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1028 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1029 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1030
1031 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
1032 };
1033
1034 post_layernorm + patch_embedding + position_embedding + layer_elems
1035 };
1036
1037 let elems = text_elems + connector_elems + vision_transformer;
1038
1039 Ok(elems * dtype.size_in_bytes())
1040 }
1041
1042 fn layer_sizes_in_bytes(
1043 &self,
1044 config: &str,
1045 dtype: DType,
1046 weight_pack_factor: usize,
1047 _matformer_config: Option<&MatformerSliceConfig>,
1048 ) -> Result<Vec<usize>> {
1049 let cfg: Idefics2Config = serde_json::from_str(config)?;
1050 let cfg = cfg.text_config;
1051 let per_layer_elems = {
1052 let input_layernorm = cfg.hidden_size;
1053 let post_attention_layernorm = cfg.hidden_size;
1054
1055 let size_in = cfg.hidden_size;
1056 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1057 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1058 let q_proj = size_in * size_q / weight_pack_factor;
1059 let k_proj = size_in * size_kv / weight_pack_factor;
1060 let v_proj = size_in * size_kv / weight_pack_factor;
1061 let o_proj = size_q * size_in / weight_pack_factor;
1062
1063 let h_size = cfg.hidden_size;
1064 let i_size = cfg.intermediate_size;
1065 let gate_proj = h_size * i_size / weight_pack_factor;
1066 let up_proj = h_size * i_size / weight_pack_factor;
1067 let down_proj = i_size * h_size / weight_pack_factor;
1068
1069 input_layernorm
1070 + post_attention_layernorm
1071 + q_proj
1072 + k_proj
1073 + v_proj
1074 + o_proj
1075 + gate_proj
1076 + up_proj
1077 + down_proj
1078 };
1079 Ok(vec![
1080 per_layer_elems * dtype.size_in_bytes();
1081 cfg.num_hidden_layers
1082 ])
1083 }
1084
1085 fn num_layers(&self, config: &str) -> Result<usize> {
1086 let cfg: Idefics2Config = serde_json::from_str(config)?;
1087 Ok(cfg.text_config.num_hidden_layers)
1088 }
1089 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1090 let cfg: Idefics2Config = serde_json::from_str(config)?;
1091 let cfg = &cfg.text_config;
1092
1093 let cfg = ModelConfigMetadata {
1094 max_seq_len: cfg.max_position_embeddings,
1095 num_layers: cfg.num_hidden_layers,
1096 hidden_size: cfg.hidden_size,
1097 num_kv_heads: cfg.num_key_value_heads,
1098 num_attn_heads: cfg.num_attention_heads,
1099 sliding_window: cfg.sliding_window,
1100 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1101 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1102 };
1103
1104 Ok(Box::new(cfg))
1105 }
1106
1107 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1108 Some(vec![NonMappedSubModel::Vision])
1109 }
1110}
1111
1112pub struct LLaVANextLoader;
1118
1119pub struct LLaVANextPrefixer;
1120
1121impl MultimodalPromptPrefixer for LLaVANextPrefixer {
1122 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1123 format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
1124 }
1125}
1126
1127impl VisionModelLoader for LLaVANextLoader {
1128 fn load(
1129 &self,
1130 config: &str,
1131 vb: ShardedVarBuilder,
1132 normal_loading_metadata: NormalLoadingMetadata,
1133 attention_mechanism: AttentionImplementation,
1134 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1135 let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1136 Ok(Box::new(LLaVANext::new(
1137 &cfg,
1138 vb,
1139 self.is_gptx(config),
1140 normal_loading_metadata,
1141 attention_mechanism,
1142 )?))
1143 }
1144 fn is_gptx(&self, _config: &str) -> bool {
1145 false
1146 }
1147 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1148 let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1149 Ok(Box::new(cfg))
1150 }
1151 fn get_processor(
1152 &self,
1153 model_config: &str,
1154 _processor_config: Option<ProcessorConfig>,
1155 _preprocessor_config: PreProcessorConfig,
1156 _max_edge: Option<u32>,
1157 ) -> Arc<dyn Processor + Send + Sync> {
1158 Arc::new(LLaVANextProcessor::new(model_config))
1159 }
1160 fn supports_paged_attention(&self, _config: &str) -> bool {
1161 true
1162 }
1163 fn supports_prefix_cacher(&self, _config: &str) -> bool {
1164 true
1165 }
1166 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1167 Arc::new(LLaVANextPrefixer)
1168 }
1169 fn modalities(&self, _config: &str) -> Result<Modalities> {
1170 Ok(Modalities {
1171 input: vec![SupportedModality::Text, SupportedModality::Vision],
1172 output: vec![SupportedModality::Text],
1173 })
1174 }
1175}
1176
1177impl IsqModelLoader for LLaVANextLoader {
1178 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1179 Ok(vec![
1180 Regex::new(r"lm_head\.(weight|bias)$")?,
1181 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1183 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1184 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1185 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1186 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1188 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1189 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1190 ])
1191 }
1192 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1193 Ok(vec![
1194 Regex::new(r"lm_head\.(weight|bias)$")?,
1195 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1197 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1198 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1199 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1200 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1202 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1203 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1204 ])
1205 }
1206}
1207
1208impl DeviceMappedModelLoader for LLaVANextLoader {
1209 fn mapped_max_act_size_elems(
1210 &self,
1211 config: &str,
1212 params: &AutoDeviceMapParams,
1213 ) -> Result<usize> {
1214 let AutoDeviceMapParams::Vision {
1215 max_seq_len,
1216 max_batch_size,
1217 max_image_shape,
1218 max_num_images,
1219 } = params
1220 else {
1221 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1222 };
1223
1224 let config: LLaVAConfig = serde_json::from_str(config)?;
1225
1226 #[allow(clippy::cast_possible_truncation)]
1227 let img_seq_len =
1228 llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
1229 &config,
1230 (max_image_shape.0 as u32, max_image_shape.1 as u32),
1231 );
1232 let img_seq_len = img_seq_len * max_num_images;
1233
1234 let max_text_attn = {
1235 let cfg = &config.text_config;
1236 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
1238
1239 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1240 };
1241
1242 Ok(max_text_attn)
1243 }
1244
1245 fn non_mapped_max_act_size_elems(
1246 &self,
1247 config: &str,
1248 params: &AutoDeviceMapParams,
1249 ) -> Result<usize> {
1250 let AutoDeviceMapParams::Vision {
1251 max_seq_len: _,
1252 max_batch_size,
1253 max_image_shape,
1254 max_num_images,
1255 } = params
1256 else {
1257 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1258 };
1259
1260 let config: LLaVAConfig = serde_json::from_str(config)?;
1261
1262 #[allow(clippy::cast_possible_truncation)]
1263 let img_seq_len =
1264 llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
1265 &config,
1266 (max_image_shape.0 as u32, max_image_shape.1 as u32),
1267 );
1268
1269 let max_vision_attn = {
1270 (max_batch_size * max_num_images)
1271 * config.vision_config.num_attention_heads
1272 * img_seq_len
1273 * img_seq_len
1274 };
1275
1276 Ok(max_vision_attn)
1277 }
1278
1279 fn non_mapped_size_in_bytes(
1280 &self,
1281 config: &str,
1282 dtype: DType,
1283 weight_pack_factor: usize,
1284 _matformer_config: Option<&MatformerSliceConfig>,
1285 ) -> Result<usize> {
1286 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1287 let text_elems = {
1288 let cfg = &cfg.text_config;
1289 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1290 let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1291 let norm = cfg.hidden_size;
1292 embed_tokens + lm_head + norm
1293 };
1294
1295 let image_newline = cfg.text_config.hidden_size;
1296 let mmproj = {
1297 let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
1298 + cfg.text_config.hidden_size;
1299 let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
1300 + cfg.text_config.hidden_size;
1301
1302 linear_1 + linear_2
1303 };
1304 let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());
1305
1306 let elems = text_elems + image_newline + mmproj + vision_tower;
1307 Ok(elems * dtype.size_in_bytes())
1308 }
1309
1310 fn layer_sizes_in_bytes(
1311 &self,
1312 config: &str,
1313 dtype: DType,
1314 weight_pack_factor: usize,
1315 _matformer_config: Option<&MatformerSliceConfig>,
1316 ) -> Result<Vec<usize>> {
1317 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1318 let per_layer_elems = {
1319 let cfg = &cfg.text_config;
1320 let input_layernorm = cfg.hidden_size;
1321 let post_attention_layernorm = cfg.hidden_size;
1322
1323 let size_in = cfg.hidden_size;
1324 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1325 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1326 let q_proj = size_in * size_q / weight_pack_factor;
1327 let k_proj = size_in * size_kv / weight_pack_factor;
1328 let v_proj = size_in * size_kv / weight_pack_factor;
1329 let o_proj = size_q * size_in / weight_pack_factor;
1330
1331 let h_size = cfg.hidden_size;
1332 let i_size = cfg.intermediate_size;
1333 let gate_proj = h_size * i_size / weight_pack_factor;
1334 let up_proj = h_size * i_size / weight_pack_factor;
1335 let down_proj = i_size * h_size / weight_pack_factor;
1336
1337 input_layernorm
1338 + post_attention_layernorm
1339 + q_proj
1340 + k_proj
1341 + v_proj
1342 + o_proj
1343 + gate_proj
1344 + up_proj
1345 + down_proj
1346 };
1347 Ok(vec![
1348 per_layer_elems * dtype.size_in_bytes();
1349 cfg.text_config.num_hidden_layers
1350 ])
1351 }
1352
1353 fn num_layers(&self, config: &str) -> Result<usize> {
1354 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1355 Ok(cfg.text_config.num_hidden_layers)
1356 }
1357
1358 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1359 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1360 let cfg = &cfg.text_config;
1361
1362 let cfg = ModelConfigMetadata {
1363 max_seq_len: cfg.max_position_embeddings,
1364 num_layers: cfg.num_hidden_layers,
1365 hidden_size: cfg.hidden_size,
1366 num_kv_heads: cfg.num_key_value_heads,
1367 num_attn_heads: cfg.num_attention_heads,
1368 sliding_window: cfg.sliding_window,
1369 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1370 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1371 };
1372
1373 Ok(Box::new(cfg))
1374 }
1375
1376 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1377 Some(vec![NonMappedSubModel::Vision])
1378 }
1379}
1380
1381pub struct LLaVALoader;
1387
1388pub struct LLaVAPrefixer;
1389
1390impl MultimodalPromptPrefixer for LLaVAPrefixer {
1391 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1392 format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
1393 }
1394}
1395
1396impl VisionModelLoader for LLaVALoader {
1397 fn load(
1398 &self,
1399 config: &str,
1400 vb: ShardedVarBuilder,
1401 normal_loading_metadata: NormalLoadingMetadata,
1402 attention_mechanism: AttentionImplementation,
1403 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1404 let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1405 Ok(Box::new(LLaVA::new(
1406 &cfg,
1407 vb,
1408 self.is_gptx(config),
1409 normal_loading_metadata,
1410 attention_mechanism,
1411 )?))
1412 }
1413 fn is_gptx(&self, _config: &str) -> bool {
1414 false
1415 }
1416 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1417 let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1418 Ok(Box::new(cfg))
1419 }
1420 fn get_processor(
1421 &self,
1422 model_config: &str,
1423 _processor_config: Option<ProcessorConfig>,
1424 _preprocessor_config: PreProcessorConfig,
1425 _max_edge: Option<u32>,
1426 ) -> Arc<dyn Processor + Send + Sync> {
1427 Arc::new(LLaVAProcessor::new(model_config))
1428 }
1429 fn supports_paged_attention(&self, _config: &str) -> bool {
1430 true
1431 }
1432 fn supports_prefix_cacher(&self, _config: &str) -> bool {
1433 true
1434 }
1435 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1436 Arc::new(LLaVAPrefixer)
1437 }
1438 fn modalities(&self, _config: &str) -> Result<Modalities> {
1439 Ok(Modalities {
1440 input: vec![SupportedModality::Text, SupportedModality::Vision],
1441 output: vec![SupportedModality::Text],
1442 })
1443 }
1444}
1445
1446impl IsqModelLoader for LLaVALoader {
1447 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1448 Ok(vec![
1449 Regex::new(r"lm_head\.(weight|bias)$")?,
1450 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1452 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1453 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1454 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1455 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1457 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1458 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1459 ])
1460 }
1461 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1462 Ok(vec![
1463 Regex::new(r"lm_head\.(weight|bias)$")?,
1464 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1466 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1467 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1468 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1469 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1471 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1472 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1473 ])
1474 }
1475}
1476
1477impl DeviceMappedModelLoader for LLaVALoader {
1478 fn mapped_max_act_size_elems(
1479 &self,
1480 config: &str,
1481 params: &AutoDeviceMapParams,
1482 ) -> Result<usize> {
1483 let AutoDeviceMapParams::Vision {
1484 max_seq_len,
1485 max_batch_size,
1486 max_image_shape: _,
1487 max_num_images,
1488 } = params
1489 else {
1490 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1491 };
1492
1493 let config: LLaVAConfig = serde_json::from_str(config)?;
1494
1495 let img_seq_len =
1496 llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1497 let img_seq_len = img_seq_len * max_num_images;
1498
1499 let max_text_attn = {
1500 let cfg = &config.text_config;
1501 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
1503
1504 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1505 };
1506
1507 Ok(max_text_attn)
1508 }
1509
1510 fn non_mapped_max_act_size_elems(
1511 &self,
1512 config: &str,
1513 params: &AutoDeviceMapParams,
1514 ) -> Result<usize> {
1515 let AutoDeviceMapParams::Vision {
1516 max_seq_len: _,
1517 max_batch_size,
1518 max_image_shape: _,
1519 max_num_images,
1520 } = params
1521 else {
1522 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1523 };
1524
1525 let config: LLaVAConfig = serde_json::from_str(config)?;
1526
1527 let img_seq_len =
1528 llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1529
1530 let max_vision_attn = {
1531 (max_batch_size * max_num_images)
1532 * config.vision_config.num_attention_heads
1533 * img_seq_len
1534 * img_seq_len
1535 };
1536
1537 Ok(max_vision_attn)
1538 }
1539
1540 fn non_mapped_size_in_bytes(
1541 &self,
1542 config: &str,
1543 dtype: DType,
1544 weight_pack_factor: usize,
1545 _matformer_config: Option<&MatformerSliceConfig>,
1546 ) -> Result<usize> {
1547 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1548 let text_elems = {
1549 let cfg = &cfg.text_config;
1550 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1551 let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1552 let norm = cfg.hidden_size;
1553 embed_tokens + lm_head + norm
1554 };
1555
1556 let image_newline = cfg.text_config.hidden_size;
1557 let mmproj = {
1558 let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
1559 + cfg.text_config.hidden_size;
1560 let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
1561 + cfg.text_config.hidden_size;
1562
1563 linear_1 + linear_2
1564 };
1565 let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());
1566
1567 let elems = text_elems + image_newline + mmproj + vision_tower;
1568 Ok(elems * dtype.size_in_bytes())
1569 }
1570
1571 fn layer_sizes_in_bytes(
1572 &self,
1573 config: &str,
1574 dtype: DType,
1575 weight_pack_factor: usize,
1576 _matformer_config: Option<&MatformerSliceConfig>,
1577 ) -> Result<Vec<usize>> {
1578 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1579 let per_layer_elems = {
1580 let cfg = &cfg.text_config;
1581 let input_layernorm = cfg.hidden_size;
1582 let post_attention_layernorm = cfg.hidden_size;
1583
1584 let size_in = cfg.hidden_size;
1585 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1586 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1587 let q_proj = size_in * size_q / weight_pack_factor;
1588 let k_proj = size_in * size_kv / weight_pack_factor;
1589 let v_proj = size_in * size_kv / weight_pack_factor;
1590 let o_proj = size_q * size_in / weight_pack_factor;
1591
1592 let h_size = cfg.hidden_size;
1593 let i_size = cfg.intermediate_size;
1594 let gate_proj = h_size * i_size / weight_pack_factor;
1595 let up_proj = h_size * i_size / weight_pack_factor;
1596 let down_proj = i_size * h_size / weight_pack_factor;
1597
1598 input_layernorm
1599 + post_attention_layernorm
1600 + q_proj
1601 + k_proj
1602 + v_proj
1603 + o_proj
1604 + gate_proj
1605 + up_proj
1606 + down_proj
1607 };
1608 Ok(vec![
1609 per_layer_elems * dtype.size_in_bytes();
1610 cfg.text_config.num_hidden_layers
1611 ])
1612 }
1613
1614 fn num_layers(&self, config: &str) -> Result<usize> {
1615 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1616 Ok(cfg.text_config.num_hidden_layers)
1617 }
1618
1619 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1620 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1621 let cfg = &cfg.text_config;
1622
1623 let cfg = ModelConfigMetadata {
1624 max_seq_len: cfg.max_position_embeddings,
1625 num_layers: cfg.num_hidden_layers,
1626 hidden_size: cfg.hidden_size,
1627 num_kv_heads: cfg.num_key_value_heads,
1628 num_attn_heads: cfg.num_attention_heads,
1629 sliding_window: cfg.sliding_window,
1630 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1631 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1632 };
1633
1634 Ok(Box::new(cfg))
1635 }
1636
1637 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1638 Some(vec![NonMappedSubModel::Vision])
1639 }
1640}
1641
1642pub struct VLlamaLoader;
1648
1649pub struct VLlamaPrefixer;
1650
1651impl MultimodalPromptPrefixer for VLlamaPrefixer {
1652 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1653 format!("{}{prompt}", "<|image|>".repeat(image_indexes.len()))
1654 }
1655}
1656
1657impl VisionModelLoader for VLlamaLoader {
1658 fn load(
1659 &self,
1660 config: &str,
1661 vb: ShardedVarBuilder,
1662 normal_loading_metadata: NormalLoadingMetadata,
1663 attention_mechanism: AttentionImplementation,
1664 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1665 let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1666 Ok(Box::new(MLlamaModel::new(
1667 &cfg,
1668 vb,
1669 self.is_gptx(config),
1670 normal_loading_metadata,
1671 attention_mechanism,
1672 )?))
1673 }
1674 fn is_gptx(&self, _config: &str) -> bool {
1675 true
1676 }
1677 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1678 let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1679 Ok(Box::new(cfg))
1680 }
1681 fn get_processor(
1682 &self,
1683 _model_config: &str,
1684 _processor_config: Option<ProcessorConfig>,
1685 _preprocessor_config: PreProcessorConfig,
1686 _max_edge: Option<u32>,
1687 ) -> Arc<dyn Processor + Send + Sync> {
1688 Arc::new(MLlamaProcessor::new())
1689 }
1690 fn supports_paged_attention(&self, _config: &str) -> bool {
1691 false
1692 }
1693 fn supports_prefix_cacher(&self, _config: &str) -> bool {
1694 true
1695 }
1696 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1697 Arc::new(VLlamaPrefixer)
1698 }
1699 fn modalities(&self, _config: &str) -> Result<Modalities> {
1700 Ok(Modalities {
1701 input: vec![SupportedModality::Text, SupportedModality::Vision],
1702 output: vec![SupportedModality::Text],
1703 })
1704 }
1705}
1706
1707impl IsqModelLoader for VLlamaLoader {
1708 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
1709 let config: MLlamaConfig = serde_json::from_str(config)?;
1710 let cross_attn_layers = &config.text_config.cross_attention_layers;
1711 let transformer_layers =
1712 (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
1713 let mut text_regexes = Vec::new();
1714 for layer in transformer_layers {
1715 text_regexes.extend(vec![
1716 Regex::new(&format!(
1718 r"language_model.model.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
1719 ))?,
1720 Regex::new(&format!(
1721 r"language_model.model.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
1722 ))?,
1723 Regex::new(&format!(
1724 r"language_model.model.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
1725 ))?,
1726 Regex::new(&format!(
1727 r"language_model.model.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
1728 ))?,
1729 Regex::new(&format!(
1731 r"language_model.model.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
1732 ))?,
1733 Regex::new(&format!(
1734 r"language_model.model.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
1735 ))?,
1736 Regex::new(&format!(
1737 r"language_model.model.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
1738 ))?,
1739 ]);
1740 }
1741 let vision_regexes = vec![
1742 Regex::new(
1744 r"vision_model.transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1745 )?,
1746 Regex::new(
1747 r"vision_model.transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1748 )?,
1749 Regex::new(
1750 r"vision_model.transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1751 )?,
1752 Regex::new(
1753 r"vision_model.transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1754 )?,
1755 Regex::new(
1757 r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1758 )?,
1759 Regex::new(
1760 r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1761 )?,
1762 Regex::new(
1763 r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1764 )?,
1765 Regex::new(
1766 r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1767 )?,
1768 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1770 Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1771 ];
1772
1773 Ok([text_regexes, vision_regexes].concat())
1774 }
1775 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1776 self.isq_layer_regexes(config)
1777 }
1778}
1779
1780impl DeviceMappedModelLoader for VLlamaLoader {
1781 fn mapped_max_act_size_elems(
1782 &self,
1783 config: &str,
1784 params: &AutoDeviceMapParams,
1785 ) -> Result<usize> {
1786 let AutoDeviceMapParams::Vision {
1787 max_seq_len,
1788 max_batch_size,
1789 max_image_shape: _,
1790 max_num_images,
1791 } = params
1792 else {
1793 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1794 };
1795
1796 let config: MLlamaConfig = serde_json::from_str(config)?;
1797
1798 let img_seq_len = {
1799 let cfg = &config.vision_config;
1800 let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1801 let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1802 cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1803 };
1804 let img_seq_len = img_seq_len * max_num_images;
1805
1806 let max_cross_text_attn = {
1807 let cfg = &config.text_config;
1808 max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
1809 };
1810
1811 let max_self_text_attn = {
1812 let cfg = &config.text_config;
1813 max_batch_size * cfg.num_attention_heads * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)
1814 };
1815
1816 Ok(max_self_text_attn.max(max_cross_text_attn))
1817 }
1818
1819 fn non_mapped_max_act_size_elems(
1820 &self,
1821 config: &str,
1822 params: &AutoDeviceMapParams,
1823 ) -> Result<usize> {
1824 let AutoDeviceMapParams::Vision {
1825 max_seq_len: _,
1826 max_batch_size,
1827 max_image_shape: _,
1828 max_num_images,
1829 } = params
1830 else {
1831 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1832 };
1833
1834 let config: MLlamaConfig = serde_json::from_str(config)?;
1835
1836 let img_seq_len = {
1837 let cfg = &config.vision_config;
1838 let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1839 let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1840 cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1841 };
1842 let max_vision_attn = {
1843 let cfg = &config.vision_config;
1844 (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
1845 };
1846
1847 Ok(max_vision_attn)
1848 }
1849
1850 fn non_mapped_size_in_bytes(
1851 &self,
1852 config: &str,
1853 dtype: DType,
1854 weight_pack_factor: usize,
1855 _matformer_config: Option<&MatformerSliceConfig>,
1856 ) -> Result<usize> {
1857 let config: MLlamaConfig = serde_json::from_str(config)?;
1858 let text_elems = {
1859 let cfg = &config.text_config;
1860 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1861 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1863 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1864 } else {
1865 0
1866 };
1867 let norm = cfg.hidden_size;
1868 embed_tokens + lm_head + norm
1869 };
1870
1871 let vision_elems = {
1872 let cfg = &config.vision_config;
1873
1874 let conv_cfg = Conv2dConfig {
1875 stride: cfg.patch_size,
1876 ..Default::default()
1877 };
1878 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
1879 * cfg.patch_size
1880 * cfg.patch_size;
1881
1882 let class_embedding = cfg.hidden_size;
1883
1884 let gated_positional_embedding = {
1885 let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1886 let embedding = num_patches * cfg.hidden_size;
1887 let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
1888 * (cfg.max_num_tiles * num_patches * cfg.hidden_size);
1889
1890 embedding + tile_embedding
1891 };
1892
1893 let pre_tile_positional_embedding =
1894 (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1895 let post_tile_positional_embedding =
1896 (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1897
1898 let layernorm_pre = cfg.hidden_size;
1899 let layernorm_post = cfg.hidden_size;
1900
1901 let encoder_layer = {
1902 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1903 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
1904
1905 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1906 let q_proj =
1907 cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1908 let k_proj =
1909 cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1910 let v_proj =
1911 cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1912 let o_proj =
1913 cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1914
1915 let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
1916 + cfg.intermediate_size;
1917 let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
1918 + cfg.hidden_size;
1919
1920 input_layernorm
1921 + post_attention_layernorm
1922 + q_proj
1923 + k_proj
1924 + v_proj
1925 + o_proj
1926 + fc1
1927 + fc2
1928 };
1929
1930 patch_embedding
1931 + class_embedding
1932 + gated_positional_embedding
1933 + pre_tile_positional_embedding
1934 + post_tile_positional_embedding
1935 + layernorm_pre
1936 + layernorm_post
1937 + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
1938 };
1939
1940 let elems = text_elems + vision_elems;
1941 Ok(elems * dtype.size_in_bytes())
1942 }
1943
1944 fn layer_sizes_in_bytes(
1945 &self,
1946 config: &str,
1947 dtype: DType,
1948 weight_pack_factor: usize,
1949 _matformer_config: Option<&MatformerSliceConfig>,
1950 ) -> Result<Vec<usize>> {
1951 let config: MLlamaConfig = serde_json::from_str(config)?;
1952 let cfg = &config.text_config;
1953
1954 let mut layer_sizes = Vec::new();
1955
1956 for i in 0..cfg.num_hidden_layers {
1957 let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
1958 1
1960 } else {
1961 weight_pack_factor
1962 };
1963
1964 let per_layer_elems = {
1965 let input_layernorm = cfg.hidden_size;
1966 let post_attention_layernorm = cfg.hidden_size;
1967
1968 let size_in = cfg.hidden_size;
1969 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1970 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1971 let q_proj = size_in * size_q / weight_pack_factor;
1972 let k_proj = size_in * size_kv / weight_pack_factor;
1973 let v_proj = size_in * size_kv / weight_pack_factor;
1974 let o_proj = size_q * size_in / weight_pack_factor;
1975
1976 let h_size = cfg.hidden_size;
1977 let i_size = cfg.intermediate_size;
1978 let gate_proj = h_size * i_size / weight_pack_factor;
1979 let up_proj = h_size * i_size / weight_pack_factor;
1980 let down_proj = i_size * h_size / weight_pack_factor;
1981
1982 input_layernorm
1983 + post_attention_layernorm
1984 + q_proj
1985 + k_proj
1986 + v_proj
1987 + o_proj
1988 + gate_proj
1989 + up_proj
1990 + down_proj
1991 };
1992
1993 layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
1994 }
1995
1996 Ok(layer_sizes)
1997 }
1998
1999 fn num_layers(&self, config: &str) -> Result<usize> {
2000 let config: MLlamaConfig = serde_json::from_str(config)?;
2001 Ok(config.text_config.num_hidden_layers)
2002 }
2003
2004 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2005 let cfg: MLlamaConfig = serde_json::from_str(config)?;
2006 let cfg = &cfg.text_config;
2007
2008 let cfg = ModelConfigMetadata {
2009 max_seq_len: cfg.max_position_embeddings,
2010 num_layers: cfg.num_hidden_layers,
2011 hidden_size: cfg.hidden_size,
2012 num_kv_heads: cfg.num_key_value_heads,
2013 num_attn_heads: cfg.num_attention_heads,
2014 sliding_window: None,
2015 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2016 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2017 };
2018
2019 Ok(Box::new(cfg))
2020 }
2021
2022 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2023 Some(vec![NonMappedSubModel::Vision])
2024 }
2025}
2026
2027pub struct Qwen2VLLoader;
2033
2034pub struct Qwen2VLPrefixer;
2035
2036impl MultimodalPromptPrefixer for Qwen2VLPrefixer {
2037 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2038 format!(
2039 "{}{prompt}",
2040 format!(
2041 "{}{}{}",
2042 Qwen2VLProcessor::VISION_START,
2043 Qwen2VLProcessor::IMAGE_PAD,
2044 Qwen2VLProcessor::VISION_END
2045 )
2046 .repeat(image_indexes.len())
2047 )
2048 }
2049}
2050
2051impl VisionModelLoader for Qwen2VLLoader {
2052 fn load(
2053 &self,
2054 config: &str,
2055 vb: ShardedVarBuilder,
2056 normal_loading_metadata: NormalLoadingMetadata,
2057 attention_mechanism: AttentionImplementation,
2058 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2059 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2060 Ok(Box::new(Qwen2VLModel::new(
2061 &cfg,
2062 vb,
2063 self.is_gptx(config),
2064 normal_loading_metadata,
2065 attention_mechanism,
2066 )?))
2067 }
2068 fn is_gptx(&self, _config: &str) -> bool {
2069 true
2070 }
2071 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2072 let config: Qwen2VLConfig = serde_json::from_str(config)?;
2073 Ok(Box::new(config))
2074 }
2075 fn get_processor(
2076 &self,
2077 _model_config: &str,
2078 _processor_config: Option<ProcessorConfig>,
2079 _preprocessor_config: PreProcessorConfig,
2080 max_edge: Option<u32>,
2081 ) -> Arc<dyn Processor + Send + Sync> {
2082 Arc::new(Qwen2VLProcessor::new(max_edge))
2083 }
2084 fn supports_paged_attention(&self, _config: &str) -> bool {
2085 false
2086 }
2087 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2088 Arc::new(Qwen2VLPrefixer)
2089 }
2090 fn modalities(&self, _config: &str) -> Result<Modalities> {
2091 Ok(Modalities {
2092 input: vec![SupportedModality::Text, SupportedModality::Vision],
2093 output: vec![SupportedModality::Text],
2094 })
2095 }
2096}
2097
2098impl IsqModelLoader for Qwen2VLLoader {
2099 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2100 Ok(vec![
2101 Regex::new(r"lm_head\.(weight|bias)$")?,
2102 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2104 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2105 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2106 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2107 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2109 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2110 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2111 ])
2112 }
2113 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2114 self.isq_layer_regexes(config)
2115 }
2116}
2117
2118impl DeviceMappedModelLoader for Qwen2VLLoader {
2119 fn mapped_max_act_size_elems(
2120 &self,
2121 config: &str,
2122 params: &AutoDeviceMapParams,
2123 ) -> Result<usize> {
2124 let AutoDeviceMapParams::Vision {
2125 max_seq_len,
2126 max_batch_size,
2127 max_image_shape,
2128 max_num_images,
2129 } = params
2130 else {
2131 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2132 };
2133
2134 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2135
2136 let img_seq_len = {
2138 let cfg = &cfg.vision_config;
2139 let grid_t = 1;
2141 let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
2143 let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
2144 grid_t * grid_h * grid_w * max_num_images
2145 };
2146
2147 let max_text_attn = {
2148 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2150 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
2151 };
2152
2153 Ok(max_text_attn)
2154 }
2155
2156 fn non_mapped_max_act_size_elems(
2157 &self,
2158 config: &str,
2159 params: &AutoDeviceMapParams,
2160 ) -> Result<usize> {
2161 let AutoDeviceMapParams::Vision {
2162 max_seq_len: _,
2163 max_batch_size,
2164 max_image_shape,
2165 max_num_images,
2166 } = params
2167 else {
2168 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2169 };
2170
2171 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2172
2173 let img_seq_len = {
2175 let cfg = &cfg.vision_config;
2176 let grid_t = 1;
2178 let grid_h = max_image_shape.0 / cfg.patch_size;
2179 let grid_w = max_image_shape.1 / cfg.patch_size;
2180 grid_t * grid_h * grid_w
2181 };
2182
2183 let max_vision_attn = {
2184 let cfg = &cfg.vision_config;
2185 (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
2186 };
2187
2188 Ok(max_vision_attn)
2189 }
2190
2191 fn non_mapped_size_in_bytes(
2192 &self,
2193 config: &str,
2194 dtype: DType,
2195 weight_pack_factor: usize,
2196 _matformer_config: Option<&MatformerSliceConfig>,
2197 ) -> Result<usize> {
2198 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2199 let text_elems = {
2200 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2201 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2203 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2204 } else {
2205 0
2206 };
2207 let norm = cfg.hidden_size;
2208 embed_tokens + lm_head + norm
2209 };
2210
2211 let patch_merger = {
2212 let cfg = &cfg.vision_config;
2213 let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);
2214
2215 let mlp0 = hidden_size * hidden_size + hidden_size;
2216 let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
2217
2218 let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2219
2220 mlp0 + mlp2 + ln_q
2221 };
2222
2223 let patch_embed = {
2224 let cfg = &cfg.vision_config;
2225 let conv_cfg = Conv3dConfig {
2226 stride: cfg.patch_size,
2227 ..Default::default()
2228 };
2229 let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
2230 cfg.in_channels * cfg.embed_dim / conv_cfg.groups
2231 * kernel_sizes[0]
2232 * kernel_sizes[1]
2233 * kernel_sizes[2]
2234 };
2235
2236 let encoder_layer = {
2237 let cfg = &cfg.vision_config;
2238 let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2239 let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2240
2241 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2242 let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
2243 let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
2244 let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;
2245
2246 let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
2247 let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;
2248
2249 norm1 + norm2 + fc1 + fc2 + qkv + out
2250 };
2251
2252 let elems =
2253 text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
2254
2255 Ok(elems * dtype.size_in_bytes())
2256 }
2257
2258 fn layer_sizes_in_bytes(
2259 &self,
2260 config: &str,
2261 dtype: DType,
2262 weight_pack_factor: usize,
2263 _matformer_config: Option<&MatformerSliceConfig>,
2264 ) -> Result<Vec<usize>> {
2265 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2266 let per_layer_elems = {
2267 let input_layernorm = cfg.hidden_size;
2268 let post_attention_layernorm = cfg.hidden_size;
2269
2270 let size_in = cfg.hidden_size;
2271 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2272 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2273 let q_proj = size_in * size_q / weight_pack_factor + size_q;
2274 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
2275 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
2276 let o_proj = size_q * size_in / weight_pack_factor;
2277
2278 let h_size = cfg.hidden_size;
2279 let i_size = cfg.intermediate_size;
2280 let gate_proj = h_size * i_size / weight_pack_factor;
2281 let up_proj = h_size * i_size / weight_pack_factor;
2282 let down_proj = i_size * h_size / weight_pack_factor;
2283
2284 input_layernorm
2285 + post_attention_layernorm
2286 + q_proj
2287 + k_proj
2288 + v_proj
2289 + o_proj
2290 + gate_proj
2291 + up_proj
2292 + down_proj
2293 };
2294 Ok(vec![
2295 per_layer_elems * dtype.size_in_bytes();
2296 cfg.num_hidden_layers
2297 ])
2298 }
2299
2300 fn num_layers(&self, config: &str) -> Result<usize> {
2301 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2302 Ok(cfg.num_hidden_layers)
2303 }
2304
2305 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2306 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2307
2308 let cfg = ModelConfigMetadata {
2309 max_seq_len: cfg.max_position_embeddings,
2310 num_layers: cfg.num_hidden_layers,
2311 hidden_size: cfg.hidden_size,
2312 num_kv_heads: cfg.num_key_value_heads,
2313 num_attn_heads: cfg.num_attention_heads,
2314 sliding_window: cfg.sliding_window,
2315 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2316 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2317 };
2318
2319 Ok(Box::new(cfg))
2320 }
2321
2322 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2323 Some(vec![NonMappedSubModel::Vision])
2324 }
2325}
2326
2327pub struct Idefics3Loader;
2333
2334pub struct Idefics3Prefixer;
2335
2336impl MultimodalPromptPrefixer for Idefics3Prefixer {
2337 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
2338 prompt.to_string()
2340 }
2341}
2342
2343impl VisionModelLoader for Idefics3Loader {
2344 fn load(
2345 &self,
2346 config: &str,
2347 vb: ShardedVarBuilder,
2348 normal_loading_metadata: NormalLoadingMetadata,
2349 attention_mechanism: AttentionImplementation,
2350 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2351 let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2352 Ok(Box::new(Idefics3Model::new(
2353 &cfg,
2354 vb,
2355 self.is_gptx(config),
2356 normal_loading_metadata,
2357 attention_mechanism,
2358 )?))
2359 }
2360 fn is_gptx(&self, _config: &str) -> bool {
2361 true
2362 }
2363 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2364 let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2365 Ok(Box::new(cfg))
2366 }
2367 fn get_processor(
2368 &self,
2369 _model_config: &str,
2370 processor_config: Option<ProcessorConfig>,
2371 preprocessor_config: PreProcessorConfig,
2372 max_edge: Option<u32>,
2373 ) -> Arc<dyn Processor + Send + Sync> {
2374 Arc::new(Idefics3Processor::new(
2375 processor_config.unwrap_or_default(),
2376 preprocessor_config,
2377 max_edge,
2378 ))
2379 }
2380 fn supports_paged_attention(&self, _config: &str) -> bool {
2381 true
2382 }
2383 fn supports_prefix_cacher(&self, _config: &str) -> bool {
2384 true
2385 }
2386 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2387 Arc::new(Idefics3Prefixer)
2388 }
2389 fn modalities(&self, _config: &str) -> Result<Modalities> {
2390 Ok(Modalities {
2391 input: vec![SupportedModality::Text, SupportedModality::Vision],
2392 output: vec![SupportedModality::Text],
2393 })
2394 }
2395}
2396
2397impl IsqModelLoader for Idefics3Loader {
2398 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2399 Ok(vec![
2400 Regex::new(r"lm_head\.(weight|bias)$")?,
2401 Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2403 Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2404 Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2405 Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2406 Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2408 Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2409 Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2410 ])
2411 }
2412 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
2413 Ok(vec![
2414 Regex::new(r"lm_head\.(weight|bias)$")?,
2415 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2417 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2418 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2419 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2420 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2422 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2423 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2424 ])
2441 }
2442}
2443
2444impl DeviceMappedModelLoader for Idefics3Loader {
2445 fn mapped_max_act_size_elems(
2446 &self,
2447 config: &str,
2448 params: &AutoDeviceMapParams,
2449 ) -> Result<usize> {
2450 let AutoDeviceMapParams::Vision {
2451 max_seq_len,
2452 max_batch_size,
2453 max_image_shape: _,
2454 max_num_images,
2455 } = params
2456 else {
2457 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2458 };
2459
2460 let cfg: Idefics3Config = serde_json::from_str(config)?;
2461
2462 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2463 let img_seq_len = (num_patches + 1) * max_num_images;
2464
2465 let max_text_attn = {
2466 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2468 max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2469 };
2470
2471 Ok(max_text_attn)
2472 }
2473
2474 fn non_mapped_max_act_size_elems(
2475 &self,
2476 config: &str,
2477 params: &AutoDeviceMapParams,
2478 ) -> Result<usize> {
2479 let AutoDeviceMapParams::Vision {
2480 max_seq_len: _,
2481 max_batch_size,
2482 max_image_shape: _,
2483 max_num_images,
2484 } = params
2485 else {
2486 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2487 };
2488
2489 let cfg: Idefics3Config = serde_json::from_str(config)?;
2490
2491 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2492 let img_seq_len = num_patches + 1;
2493
2494 let max_vision_attn = {
2495 let images_factor = 5;
2497
2498 (max_batch_size * images_factor * max_num_images)
2499 * cfg.vision_config.num_attention_heads
2500 * img_seq_len
2501 * img_seq_len
2502 };
2503
2504 Ok(max_vision_attn)
2505 }
2506
2507 fn non_mapped_size_in_bytes(
2508 &self,
2509 config: &str,
2510 dtype: DType,
2511 weight_pack_factor: usize,
2512 _matformer_config: Option<&MatformerSliceConfig>,
2513 ) -> Result<usize> {
2514 let cfg: Idefics3Config = serde_json::from_str(config)?;
2515 let text_elems = {
2516 let cfg = &cfg.text_config;
2517
2518 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2519 let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2520 let norm = cfg.hidden_size;
2521 embed_tokens + lm_head + norm
2522 };
2523
2524 let connector_elems = {
2525 let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
2526 let out_dim = cfg.text_config.hidden_size;
2527
2528 in_dim * out_dim
2529 };
2530
2531 let vision_transformer = {
2532 let cfg = &cfg.vision_config;
2533
2534 let post_layernorm = cfg.hidden_size;
2535
2536 let conv_config = Conv2dConfig {
2537 stride: cfg.patch_size,
2538 ..Default::default()
2539 };
2540 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2541 * cfg.patch_size
2542 * cfg.patch_size;
2543
2544 let num_patches_per_side = cfg.image_size / cfg.patch_size;
2545 let num_patches = num_patches_per_side.pow(2);
2546 let position_embedding = num_patches * cfg.hidden_size;
2547
2548 let layer_elems = {
2549 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2550 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2551
2552 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2553 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2554
2555 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2556 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2557 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2558 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2559
2560 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2561 };
2562
2563 post_layernorm
2564 + patch_embedding
2565 + position_embedding
2566 + layer_elems * cfg.num_hidden_layers
2567 };
2568
2569 let elems = text_elems + connector_elems + vision_transformer;
2570
2571 Ok(elems * dtype.size_in_bytes())
2572 }
2573
2574 fn layer_sizes_in_bytes(
2575 &self,
2576 config: &str,
2577 dtype: DType,
2578 weight_pack_factor: usize,
2579 _matformer_config: Option<&MatformerSliceConfig>,
2580 ) -> Result<Vec<usize>> {
2581 let cfg: Idefics3Config = serde_json::from_str(config)?;
2582 let cfg = cfg.text_config;
2583 let per_layer_elems = {
2584 let input_layernorm = cfg.hidden_size;
2585 let post_attention_layernorm = cfg.hidden_size;
2586
2587 let size_in = cfg.hidden_size;
2588 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2589 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2590 let q_proj = size_in * size_q / weight_pack_factor;
2591 let k_proj = size_in * size_kv / weight_pack_factor;
2592 let v_proj = size_in * size_kv / weight_pack_factor;
2593 let o_proj = size_q * size_in / weight_pack_factor;
2594
2595 let h_size = cfg.hidden_size;
2596 let i_size = cfg.intermediate_size;
2597 let gate_proj = h_size * i_size / weight_pack_factor;
2598 let up_proj = h_size * i_size / weight_pack_factor;
2599 let down_proj = i_size * h_size / weight_pack_factor;
2600
2601 input_layernorm
2602 + post_attention_layernorm
2603 + q_proj
2604 + k_proj
2605 + v_proj
2606 + o_proj
2607 + gate_proj
2608 + up_proj
2609 + down_proj
2610 };
2611 Ok(vec![
2612 per_layer_elems * dtype.size_in_bytes();
2613 cfg.num_hidden_layers
2614 ])
2615 }
2616
2617 fn num_layers(&self, config: &str) -> Result<usize> {
2618 let cfg: Idefics3Config = serde_json::from_str(config)?;
2619 Ok(cfg.text_config.num_hidden_layers)
2620 }
2621 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2622 let cfg: Idefics3Config = serde_json::from_str(config)?;
2623 let cfg = &cfg.text_config;
2624
2625 let cfg = ModelConfigMetadata {
2626 max_seq_len: cfg.max_position_embeddings,
2627 num_layers: cfg.num_hidden_layers,
2628 hidden_size: cfg.hidden_size,
2629 num_kv_heads: cfg.num_key_value_heads,
2630 num_attn_heads: cfg.num_attention_heads,
2631 sliding_window: None,
2632 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2633 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2634 };
2635
2636 Ok(Box::new(cfg))
2637 }
2638
2639 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2640 Some(vec![NonMappedSubModel::Vision])
2641 }
2642}
2643
2644pub struct MiniCpmOLoader;
2650
2651pub struct MiniCpmOPrefixer;
2652
2653impl MultimodalPromptPrefixer for MiniCpmOPrefixer {
2654 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2655 format!(
2656 "{}{prompt}",
2657 "(<image>./</image>)".repeat(image_indexes.len())
2658 )
2659 }
2660}
2661
2662impl VisionModelLoader for MiniCpmOLoader {
2663 fn load(
2664 &self,
2665 config: &str,
2666 vb: ShardedVarBuilder,
2667 normal_loading_metadata: NormalLoadingMetadata,
2668 attention_mechanism: AttentionImplementation,
2669 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2670 let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2671 Ok(Box::new(MiniCpmOModel::new(
2672 &cfg,
2673 vb,
2674 self.is_gptx(config),
2675 normal_loading_metadata,
2676 attention_mechanism,
2677 )?))
2678 }
2679 fn is_gptx(&self, _config: &str) -> bool {
2680 true
2681 }
2682 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2683 let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2684 Ok(Box::new(cfg))
2685 }
2686 fn get_processor(
2687 &self,
2688 _model_config: &str,
2689 processor_config: Option<ProcessorConfig>,
2690 preprocessor_config: PreProcessorConfig,
2691 max_edge: Option<u32>,
2692 ) -> Arc<dyn Processor + Send + Sync> {
2693 Arc::new(MiniCpmOProcessor::new(
2694 processor_config.unwrap_or_default(),
2695 preprocessor_config,
2696 max_edge,
2697 ))
2698 }
2699 fn supports_paged_attention(&self, _config: &str) -> bool {
2700 true
2701 }
2702 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2703 Arc::new(MiniCpmOPrefixer)
2704 }
2705 fn modalities(&self, _config: &str) -> Result<Modalities> {
2706 Ok(Modalities {
2707 input: vec![SupportedModality::Text, SupportedModality::Vision],
2708 output: vec![SupportedModality::Text],
2709 })
2710 }
2711}
2712
2713impl IsqModelLoader for MiniCpmOLoader {
2714 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2715 Ok(vec![
2716 Regex::new(r"llm.lm_head\.(weight|bias)$")?,
2717 Regex::new(r"llm.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2719 Regex::new(r"llm.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2720 Regex::new(r"llm.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2721 Regex::new(r"llm.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2722 Regex::new(r"llm.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2724 Regex::new(r"llm.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2725 Regex::new(r"llm.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2726 ])
2727 }
2728 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2729 self.isq_layer_regexes(config)
2730 }
2731}
2732
2733impl DeviceMappedModelLoader for MiniCpmOLoader {
2734 fn mapped_max_act_size_elems(
2735 &self,
2736 config: &str,
2737 params: &AutoDeviceMapParams,
2738 ) -> Result<usize> {
2739 let AutoDeviceMapParams::Vision {
2740 max_seq_len,
2741 max_batch_size,
2742 max_image_shape: _,
2743 max_num_images,
2744 } = params
2745 else {
2746 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2747 };
2748
2749 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2750
2751 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2752 let img_seq_len = (num_patches + 1) * max_num_images;
2753
2754 let max_text_attn = {
2755 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2757 max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2758 };
2759
2760 Ok(max_text_attn)
2761 }
2762
2763 fn non_mapped_max_act_size_elems(
2764 &self,
2765 config: &str,
2766 params: &AutoDeviceMapParams,
2767 ) -> Result<usize> {
2768 let AutoDeviceMapParams::Vision {
2769 max_seq_len: _,
2770 max_batch_size,
2771 max_image_shape: _,
2772 max_num_images,
2773 } = params
2774 else {
2775 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2776 };
2777
2778 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2779
2780 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2781 let img_seq_len = num_patches + 1;
2782
2783 let max_vision_attn = {
2784 let images_factor = 5;
2786
2787 (max_batch_size * images_factor * max_num_images)
2788 * cfg.vision_config.num_attention_heads
2789 * img_seq_len
2790 * img_seq_len
2791 };
2792
2793 Ok(max_vision_attn)
2794 }
2795
2796 fn non_mapped_size_in_bytes(
2797 &self,
2798 config: &str,
2799 dtype: DType,
2800 weight_pack_factor: usize,
2801 _matformer_config: Option<&MatformerSliceConfig>,
2802 ) -> Result<usize> {
2803 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2804 let text_elems = {
2805 let cfg = &cfg.text_config;
2806
2807 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2808 let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2809 let norm = cfg.hidden_size;
2810 embed_tokens + lm_head + norm
2811 };
2812
2813 let vision_transformer = {
2814 let cfg = &cfg.vision_config;
2815
2816 let post_layernorm = cfg.hidden_size;
2817
2818 let conv_config = Conv2dConfig {
2819 stride: cfg.patch_size,
2820 ..Default::default()
2821 };
2822 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2823 * cfg.patch_size
2824 * cfg.patch_size;
2825
2826 let num_patches_per_side = cfg.image_size / cfg.patch_size;
2827 let num_patches = num_patches_per_side.pow(2);
2828 let position_embedding = num_patches * cfg.hidden_size;
2829
2830 let layer_elems = {
2831 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2832 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2833
2834 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2835 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2836
2837 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2838 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2839 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2840 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2841
2842 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2843 };
2844
2845 post_layernorm
2846 + patch_embedding
2847 + position_embedding
2848 + layer_elems * cfg.num_hidden_layers
2849 };
2850
2851 let elems = text_elems + vision_transformer;
2852
2853 Ok(elems * dtype.size_in_bytes())
2854 }
2855
2856 fn layer_sizes_in_bytes(
2857 &self,
2858 config: &str,
2859 dtype: DType,
2860 weight_pack_factor: usize,
2861 _matformer_config: Option<&MatformerSliceConfig>,
2862 ) -> Result<Vec<usize>> {
2863 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2864 let cfg = cfg.text_config;
2865 let per_layer_elems = {
2866 let input_layernorm = cfg.hidden_size;
2867 let post_attention_layernorm = cfg.hidden_size;
2868
2869 let size_in = cfg.hidden_size;
2870 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2871 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2872 let q_proj = size_in * size_q / weight_pack_factor;
2873 let k_proj = size_in * size_kv / weight_pack_factor;
2874 let v_proj = size_in * size_kv / weight_pack_factor;
2875 let o_proj = size_q * size_in / weight_pack_factor;
2876
2877 let h_size = cfg.hidden_size;
2878 let i_size = cfg.intermediate_size;
2879 let gate_proj = h_size * i_size / weight_pack_factor;
2880 let up_proj = h_size * i_size / weight_pack_factor;
2881 let down_proj = i_size * h_size / weight_pack_factor;
2882
2883 input_layernorm
2884 + post_attention_layernorm
2885 + q_proj
2886 + k_proj
2887 + v_proj
2888 + o_proj
2889 + gate_proj
2890 + up_proj
2891 + down_proj
2892 };
2893 Ok(vec![
2894 per_layer_elems * dtype.size_in_bytes();
2895 cfg.num_hidden_layers
2896 ])
2897 }
2898
2899 fn num_layers(&self, config: &str) -> Result<usize> {
2900 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2901 Ok(cfg.text_config.num_hidden_layers)
2902 }
2903 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2904 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2905 let cfg = &cfg.text_config;
2906
2907 let cfg = ModelConfigMetadata {
2908 max_seq_len: cfg.max_position_embeddings,
2909 num_layers: cfg.num_hidden_layers,
2910 hidden_size: cfg.hidden_size,
2911 num_kv_heads: cfg.num_key_value_heads,
2912 num_attn_heads: cfg.num_attention_heads,
2913 sliding_window: None,
2914 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2915 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2916 };
2917
2918 Ok(Box::new(cfg))
2919 }
2920}
2921
2922pub struct Phi4MMLoader;
2928
2929pub struct Phi4MMPrefixer;
2930
2931impl MultimodalPromptPrefixer for Phi4MMPrefixer {
2932 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2933 format!(
2936 "{}{prompt}",
2937 image_indexes
2938 .into_iter()
2939 .map(|image_index| format!("<|image_{}|>", image_index + 1))
2940 .join("")
2941 )
2942 }
2943 fn prefix_audio(&self, audio_indexes: Vec<usize>, prompt: &str) -> String {
2944 format!(
2947 "{}{prompt}",
2948 audio_indexes
2949 .into_iter()
2950 .map(|audio_index| format!("<|audio_{}|>", audio_index + 1))
2951 .join("")
2952 )
2953 }
2954}
2955
2956impl VisionModelLoader for Phi4MMLoader {
2957 fn load(
2958 &self,
2959 config: &str,
2960 vb: ShardedVarBuilder,
2961 normal_loading_metadata: NormalLoadingMetadata,
2962 attention_mechanism: AttentionImplementation,
2963 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2964 let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
2965 Ok(Box::new(Phi4MMModel::new(
2966 &cfg,
2967 vb,
2968 self.is_gptx(config),
2969 normal_loading_metadata,
2970 attention_mechanism,
2971 )?))
2972 }
2973 fn is_gptx(&self, _config: &str) -> bool {
2974 true
2975 }
2976 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2977 let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
2978 Ok(Box::new(cfg))
2979 }
2980 fn get_processor(
2981 &self,
2982 _model_config: &str,
2983 processor_config: Option<ProcessorConfig>,
2984 preprocessor_config: PreProcessorConfig,
2985 _max_edge: Option<u32>,
2986 ) -> Arc<dyn Processor + Send + Sync> {
2987 Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
2988 }
2989 fn supports_paged_attention(&self, _config: &str) -> bool {
2990 true
2991 }
2992 fn supports_prefix_cacher(&self, _config: &str) -> bool {
2993 true
2994 }
2995 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2996 Arc::new(Phi4MMPrefixer)
2997 }
2998 fn modalities(&self, _config: &str) -> Result<Modalities> {
2999 Ok(Modalities {
3000 input: vec![
3001 SupportedModality::Text,
3002 SupportedModality::Vision,
3003 SupportedModality::Audio,
3004 ],
3005 output: vec![SupportedModality::Text],
3006 })
3007 }
3008}
3009
3010impl IsqModelLoader for Phi4MMLoader {
3011 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3012 Ok(vec![
3013 Regex::new(r"lm_head\.(weight|bias)$")?,
3014 Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
3016 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3017 Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
3019 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3020 ])
3021 }
3022 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3023 self.isq_layer_regexes(config)
3024 }
3025}
3026
3027impl DeviceMappedModelLoader for Phi4MMLoader {
3028 fn mapped_max_act_size_elems(
3029 &self,
3030 config: &str,
3031 params: &AutoDeviceMapParams,
3032 ) -> Result<usize> {
3033 let AutoDeviceMapParams::Vision {
3035 max_seq_len,
3036 max_batch_size,
3037 max_image_shape: _,
3038 max_num_images,
3039 } = params
3040 else {
3041 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3042 };
3043
3044 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3045
3046 let vcfg = &PHI4_MM_VISION_CFG;
3047
3048 let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3049 let img_seq_len = (num_patches + 1) * max_num_images;
3050
3051 let max_text_attn = {
3052 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3054 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3055 };
3056
3057 Ok(max_text_attn)
3058 }
3059
3060 fn non_mapped_max_act_size_elems(
3061 &self,
3062 _config: &str,
3063 params: &AutoDeviceMapParams,
3064 ) -> Result<usize> {
3065 let AutoDeviceMapParams::Vision {
3066 max_seq_len: _,
3067 max_batch_size,
3068 max_image_shape,
3069 max_num_images,
3070 } = params
3071 else {
3072 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3073 };
3074
3075 let vcfg = &PHI4_MM_VISION_CFG;
3076
3077 let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3078 let img_seq_len = num_patches + 1;
3079
3080 let max_batch_size = max_batch_size
3081 * (max_image_shape
3082 .0
3083 .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3084 * max_image_shape
3085 .1
3086 .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3087 + 1);
3088
3089 let max_vision_attn = (max_batch_size * max_num_images)
3090 * vcfg.num_attention_heads
3091 * img_seq_len
3092 * img_seq_len;
3093 let max_qkv = 3
3094 * (max_batch_size
3095 * vcfg.num_attention_heads
3096 * img_seq_len
3097 * (vcfg.hidden_size / vcfg.num_attention_heads));
3098
3099 Ok(max_vision_attn + max_qkv)
3100 }
3101
3102 fn non_mapped_size_in_bytes(
3103 &self,
3104 config: &str,
3105 dtype: DType,
3106 weight_pack_factor: usize,
3107 _matformer_config: Option<&MatformerSliceConfig>,
3108 ) -> Result<usize> {
3109 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3110 let elems = {
3111 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3112 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3114 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3115 } else {
3116 0
3117 };
3118 let norm = cfg.hidden_size;
3119
3120 let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
3121 let projection_cls = img_embed
3122 .projection_cls
3123 .clone()
3124 .unwrap_or("linear".to_string());
3125 let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
3126 let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
3127 let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;
3128
3129 let proj = match (projection_cls.as_str(), use_hd_transform) {
3130 ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
3131 ("mlp", true) => {
3132 let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
3133 let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3134 a + b
3135 }
3136 ("mlp", false) => {
3137 let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
3138 let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3139 a + b
3140 }
3141 _ => {
3142 anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
3143 }
3144 };
3145
3146 let (glb_gn, sub_gn) = if with_learnable_separator {
3147 let glb_gn = image_dim_out * 4;
3148 let sub_gn = image_dim_out * 4;
3149 (glb_gn, sub_gn)
3150 } else {
3151 (0, 0)
3152 };
3153
3154 let vision_transformer = {
3155 let cfg = &PHI4_MM_VISION_CFG;
3156
3157 let post_layernorm = cfg.hidden_size;
3158
3159 let conv_config = Conv2dConfig {
3160 stride: cfg.patch_size,
3161 ..Default::default()
3162 };
3163 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3164 * cfg.patch_size
3165 * cfg.patch_size;
3166
3167 let num_patches_per_side = cfg.image_size / cfg.patch_size;
3168 let num_patches = num_patches_per_side.pow(2);
3169 let position_embedding = num_patches * cfg.hidden_size;
3170
3171 let layer_elems = {
3172 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3173 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3174
3175 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3176 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3177
3178 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3179 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3180 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3181 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3182
3183 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3184 };
3185
3186 post_layernorm
3187 + patch_embedding
3188 + position_embedding
3189 + layer_elems * cfg.num_hidden_layers
3190 };
3191
3192 proj + glb_gn + sub_gn + vision_transformer
3193 } else {
3194 0
3195 };
3196
3197 embed_tokens + lm_head + norm + image_embed
3198 };
3199
3200 Ok(elems * dtype.size_in_bytes())
3201 }
3202
3203 fn layer_sizes_in_bytes(
3204 &self,
3205 config: &str,
3206 dtype: DType,
3207 weight_pack_factor: usize,
3208 _matformer_config: Option<&MatformerSliceConfig>,
3209 ) -> Result<Vec<usize>> {
3210 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3211 let per_layer_elems = {
3212 let input_layernorm = cfg.hidden_size;
3213 let post_attention_layernorm = cfg.hidden_size;
3214
3215 let size_in = cfg.hidden_size;
3216 let head_dim = cfg.head_dim();
3217 let op_size =
3218 cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
3219 let qkv_proj = size_in * op_size / weight_pack_factor;
3220 let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;
3221
3222 let h_size = cfg.hidden_size;
3223 let i_size = cfg.intermediate_size;
3224 let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
3225 let down_proj = h_size * i_size / weight_pack_factor;
3226
3227 input_layernorm
3228 + post_attention_layernorm
3229 + qkv_proj
3230 + o_proj
3231 + gate_up_proj
3232 + down_proj
3233 };
3234 Ok(vec![
3235 per_layer_elems * dtype.size_in_bytes();
3236 cfg.num_hidden_layers
3237 ])
3238 }
3239
3240 fn num_layers(&self, config: &str) -> Result<usize> {
3241 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3242 Ok(cfg.num_hidden_layers)
3243 }
3244
3245 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3246 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3247
3248 let cfg = ModelConfigMetadata {
3249 max_seq_len: cfg.max_position_embeddings,
3250 num_layers: cfg.num_hidden_layers,
3251 hidden_size: cfg.hidden_size,
3252 num_kv_heads: cfg.num_key_value_heads(),
3253 num_attn_heads: cfg.num_attention_heads,
3254 sliding_window: cfg.sliding_window,
3255 k_head_dim: cfg.head_dim(),
3256 v_head_dim: cfg.head_dim(),
3257 };
3258
3259 Ok(Box::new(cfg))
3260 }
3261
3262 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3263 Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
3264 }
3265}
3266
3267pub struct Qwen2_5VLLoader;
3273
3274pub struct Qwen2_5VLPrefixer;
3275
3276impl MultimodalPromptPrefixer for Qwen2_5VLPrefixer {
3277 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
3278 format!(
3279 "{}{prompt}",
3280 format!(
3281 "{}{}{}",
3282 Qwen2_5VLProcessor::VISION_START,
3283 Qwen2_5VLProcessor::IMAGE_PAD,
3284 Qwen2_5VLProcessor::VISION_END
3285 )
3286 .repeat(image_indexes.len())
3287 )
3288 }
3289}
3290
3291impl VisionModelLoader for Qwen2_5VLLoader {
3292 fn load(
3293 &self,
3294 config: &str,
3295 vb: ShardedVarBuilder,
3296 normal_loading_metadata: NormalLoadingMetadata,
3297 attention_mechanism: AttentionImplementation,
3298 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3299 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3300 Ok(Box::new(Qwen2_5VLModel::new(
3301 &cfg,
3302 vb,
3303 self.is_gptx(config),
3304 normal_loading_metadata,
3305 attention_mechanism,
3306 )?))
3307 }
3308 fn is_gptx(&self, _config: &str) -> bool {
3309 true
3310 }
3311 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3312 let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
3313 Ok(Box::new(config))
3314 }
3315 fn get_processor(
3316 &self,
3317 _model_config: &str,
3318 _processor_config: Option<ProcessorConfig>,
3319 _preprocessor_config: PreProcessorConfig,
3320 max_edge: Option<u32>,
3321 ) -> Arc<dyn Processor + Send + Sync> {
3322 Arc::new(Qwen2_5VLProcessor::new(max_edge))
3323 }
3324 fn supports_paged_attention(&self, _config: &str) -> bool {
3325 false
3326 }
3327 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3328 Arc::new(Qwen2_5VLPrefixer)
3329 }
3330 fn modalities(&self, _config: &str) -> Result<Modalities> {
3331 Ok(Modalities {
3332 input: vec![SupportedModality::Text, SupportedModality::Vision],
3333 output: vec![SupportedModality::Text],
3334 })
3335 }
3336}
3337
3338impl IsqModelLoader for Qwen2_5VLLoader {
3339 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3340 Ok(vec![
3341 Regex::new(r"lm_head\.(weight|bias)$")?,
3342 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3344 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3345 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3346 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3347 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3349 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3350 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3351 ])
3352 }
3353 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3354 self.isq_layer_regexes(config)
3355 }
3356}
3357
3358impl DeviceMappedModelLoader for Qwen2_5VLLoader {
3359 fn mapped_max_act_size_elems(
3360 &self,
3361 config: &str,
3362 params: &AutoDeviceMapParams,
3363 ) -> Result<usize> {
3364 let AutoDeviceMapParams::Vision {
3365 max_seq_len,
3366 max_batch_size,
3367 max_image_shape,
3368 max_num_images,
3369 } = params
3370 else {
3371 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3372 };
3373
3374 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3375
3376 let img_seq_len = {
3377 let cfg = &cfg.vision_config;
3378 let grid_t = max_num_images / cfg.temporal_patch_size;
3379 let grid_h = max_image_shape.0 / cfg.patch_size;
3380 let grid_w = max_image_shape.1 / cfg.patch_size;
3381 grid_t * grid_h * grid_w
3382 };
3383 let img_seq_len = img_seq_len * max_num_images;
3384
3385 let max_text_attn = {
3386 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3388 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3389 };
3390
3391 Ok(max_text_attn)
3392 }
3393
3394 fn non_mapped_max_act_size_elems(
3395 &self,
3396 config: &str,
3397 params: &AutoDeviceMapParams,
3398 ) -> Result<usize> {
3399 let AutoDeviceMapParams::Vision {
3400 max_seq_len: _,
3401 max_batch_size,
3402 max_image_shape,
3403 max_num_images,
3404 } = params
3405 else {
3406 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3407 };
3408
3409 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3410
3411 let img_seq_len = {
3412 let cfg = &cfg.vision_config;
3413 let grid_t = max_num_images / cfg.temporal_patch_size;
3414 let grid_h = max_image_shape.0 / cfg.patch_size;
3415 let grid_w = max_image_shape.1 / cfg.patch_size;
3416 grid_t * grid_h * grid_w
3417 };
3418
3419 let max_vision_attn = {
3420 let cfg = &cfg.vision_config;
3421 (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
3422 };
3423
3424 Ok(max_vision_attn)
3425 }
3426
3427 fn non_mapped_size_in_bytes(
3428 &self,
3429 config: &str,
3430 dtype: DType,
3431 weight_pack_factor: usize,
3432 _matformer_config: Option<&MatformerSliceConfig>,
3433 ) -> Result<usize> {
3434 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3435 let text_elems = {
3436 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3437 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3439 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3440 } else {
3441 0
3442 };
3443 let norm = cfg.hidden_size;
3444 embed_tokens + lm_head + norm
3445 };
3446
3447 let patch_merger = {
3448 let cfg = &cfg.vision_config;
3449 let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
3450
3451 let mlp0 = hidden_size * hidden_size + hidden_size;
3452 let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
3453
3454 let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3455
3456 mlp0 + mlp2 + ln_q
3457 };
3458
3459 let patch_embed = {
3460 let cfg = &cfg.vision_config;
3461 let conv_cfg = Conv3dConfig {
3462 stride: cfg.patch_size,
3463 ..Default::default()
3464 };
3465 let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
3466 cfg.in_chans * cfg.hidden_size / conv_cfg.groups
3467 * kernel_sizes[0]
3468 * kernel_sizes[1]
3469 * kernel_sizes[2]
3470 };
3471
3472 let encoder_layer = {
3473 let cfg = &cfg.vision_config;
3474 let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3475 let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3476
3477 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
3478 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3479 let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
3480
3481 let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
3482 let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3483
3484 norm1 + norm2 + fc1 + fc2 + qkv + out
3485 };
3486
3487 let elems =
3488 text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
3489
3490 Ok(elems * dtype.size_in_bytes())
3491 }
3492
3493 fn layer_sizes_in_bytes(
3494 &self,
3495 config: &str,
3496 dtype: DType,
3497 weight_pack_factor: usize,
3498 _matformer_config: Option<&MatformerSliceConfig>,
3499 ) -> Result<Vec<usize>> {
3500 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3501 let per_layer_elems = {
3502 let input_layernorm = cfg.hidden_size;
3503 let post_attention_layernorm = cfg.hidden_size;
3504
3505 let size_in = cfg.hidden_size;
3506 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3507 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3508 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3509 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3510 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3511 let o_proj = size_q * size_in / weight_pack_factor;
3512
3513 let h_size = cfg.hidden_size;
3514 let i_size = cfg.intermediate_size;
3515 let gate_proj = h_size * i_size / weight_pack_factor;
3516 let up_proj = h_size * i_size / weight_pack_factor;
3517 let down_proj = i_size * h_size / weight_pack_factor;
3518
3519 input_layernorm
3520 + post_attention_layernorm
3521 + q_proj
3522 + k_proj
3523 + v_proj
3524 + o_proj
3525 + gate_proj
3526 + up_proj
3527 + down_proj
3528 };
3529 Ok(vec![
3530 per_layer_elems * dtype.size_in_bytes();
3531 cfg.num_hidden_layers
3532 ])
3533 }
3534
3535 fn num_layers(&self, config: &str) -> Result<usize> {
3536 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3537 Ok(cfg.num_hidden_layers)
3538 }
3539
3540 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3541 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3542
3543 let cfg = ModelConfigMetadata {
3544 max_seq_len: cfg.max_position_embeddings,
3545 num_layers: cfg.num_hidden_layers,
3546 hidden_size: cfg.hidden_size,
3547 num_kv_heads: cfg.num_key_value_heads,
3548 num_attn_heads: cfg.num_attention_heads,
3549 sliding_window: cfg.sliding_window,
3550 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3551 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3552 };
3553
3554 Ok(Box::new(cfg))
3555 }
3556
3557 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3558 Some(vec![NonMappedSubModel::Vision])
3559 }
3560}
3561
3562pub struct Gemma3Loader;
3568
3569pub struct Gemma3Prefixer;
3570
3571impl MultimodalPromptPrefixer for Gemma3Prefixer {
3572 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
3573 prompt.to_string()
3574 }
3575}
3576
3577impl VisionModelLoader for Gemma3Loader {
3578 fn load(
3579 &self,
3580 config: &str,
3581 vb: ShardedVarBuilder,
3582 normal_loading_metadata: NormalLoadingMetadata,
3583 attention_mechanism: AttentionImplementation,
3584 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3585 let cfg: Gemma3Config = serde_json::from_str(config)?;
3586 Ok(Box::new(Gemma3Model::new(
3587 &cfg,
3588 vb,
3589 self.is_gptx(config),
3590 normal_loading_metadata,
3591 attention_mechanism,
3592 )?))
3593 }
3594 fn is_gptx(&self, _config: &str) -> bool {
3595 true
3596 }
3597 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3598 let config: Gemma3Config = serde_json::from_str(config)?;
3599 Ok(Box::new(config))
3600 }
3601 fn get_processor(
3602 &self,
3603 config: &str,
3604 processor_config: Option<ProcessorConfig>,
3605 _preprocessor_config: PreProcessorConfig,
3606 _max_edge: Option<u32>,
3607 ) -> Arc<dyn Processor + Send + Sync> {
3608 let config: Gemma3Config = serde_json::from_str(config).unwrap();
3609 Arc::new(Gemma3Processor::new(
3611 processor_config.unwrap_or_default(),
3612 matches!(config, Gemma3Config::WithVision { .. }),
3613 ))
3614 }
3615 fn supports_paged_attention(&self, _config: &str) -> bool {
3616 true
3617 }
3618 fn supports_prefix_cacher(&self, _config: &str) -> bool {
3619 true
3620 }
3621 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3622 Arc::new(Gemma3Prefixer)
3623 }
3624 fn modalities(&self, _config: &str) -> Result<Modalities> {
3625 Ok(Modalities {
3626 input: vec![SupportedModality::Text, SupportedModality::Vision],
3627 output: vec![SupportedModality::Text],
3628 })
3629 }
3630}
3631
3632impl IsqModelLoader for Gemma3Loader {
3633 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3634 Ok(vec![
3635 Regex::new(r"lm_head\.(weight|bias)$")?,
3636 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3638 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3639 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3640 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3641 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3643 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3644 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3645 ])
3646 }
3647 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
3648 Ok(vec![
3649 Regex::new(r"lm_head\.(weight|bias)$")?,
3650 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3652 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3653 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3654 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3655 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3657 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3658 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3659 ])
3660 }
3661}
3662
3663impl DeviceMappedModelLoader for Gemma3Loader {
3664 fn mapped_max_act_size_elems(
3665 &self,
3666 config: &str,
3667 params: &AutoDeviceMapParams,
3668 ) -> Result<usize> {
3669 let AutoDeviceMapParams::Vision {
3670 max_seq_len,
3671 max_batch_size,
3672 max_image_shape: _,
3673 max_num_images,
3674 } = params
3675 else {
3676 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3677 };
3678
3679 let cfg: Gemma3Config = serde_json::from_str(config)?;
3680
3681 match cfg {
3682 Gemma3Config::Text(text_config) => Ok(max_batch_size
3683 * text_config.num_attention_heads
3684 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)),
3685 Gemma3Config::WithVision {
3686 text_config,
3687 vision_config,
3688 ..
3689 } => {
3690 let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3691 let img_seq_len = (num_patches + 1) * max_num_images;
3692
3693 let max_text_attn = {
3694 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3696 max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
3697 };
3698 Ok(max_text_attn)
3699 }
3700 }
3701 }
3702
3703 fn non_mapped_max_act_size_elems(
3704 &self,
3705 config: &str,
3706 params: &AutoDeviceMapParams,
3707 ) -> Result<usize> {
3708 let AutoDeviceMapParams::Vision {
3709 max_seq_len: _,
3710 max_batch_size,
3711 max_image_shape: _,
3712 max_num_images,
3713 } = params
3714 else {
3715 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3716 };
3717
3718 let cfg: Gemma3Config = serde_json::from_str(config)?;
3719
3720 match cfg {
3721 Gemma3Config::WithVision { vision_config, .. } => {
3722 let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3723 let img_seq_len = num_patches + 1;
3724
3725 let max_vision_attn = {
3726 (max_batch_size * max_num_images)
3727 * vision_config.num_attention_heads
3728 * img_seq_len
3729 * img_seq_len
3730 };
3731
3732 Ok(max_vision_attn)
3733 }
3734 Gemma3Config::Text(_) => Ok(0),
3735 }
3736 }
3737
3738 fn non_mapped_size_in_bytes(
3739 &self,
3740 config: &str,
3741 dtype: DType,
3742 weight_pack_factor: usize,
3743 _matformer_config: Option<&MatformerSliceConfig>,
3744 ) -> Result<usize> {
3745 let cfg: Gemma3Config = serde_json::from_str(config)?;
3746
3747 let text_elems = {
3748 let cfg = match &cfg {
3749 Gemma3Config::Text(cfg) => cfg,
3750 Gemma3Config::WithVision { text_config, .. } => text_config,
3751 };
3752 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3753 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3755 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3756 } else {
3757 0
3758 };
3759 let norm = cfg.hidden_size;
3760 embed_tokens + lm_head + norm
3761 };
3762
3763 let vision_transformer = if let Gemma3Config::WithVision {
3764 vision_config: cfg, ..
3765 } = &cfg
3766 {
3767 let post_layernorm = cfg.hidden_size;
3768
3769 let conv_config = Conv2dConfig {
3770 stride: cfg.patch_size,
3771 ..Default::default()
3772 };
3773 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3774 * cfg.patch_size
3775 * cfg.patch_size;
3776
3777 let num_patches_per_side = cfg.image_size / cfg.patch_size;
3778 let num_patches = num_patches_per_side.pow(2);
3779 let position_embedding = num_patches * cfg.hidden_size;
3780
3781 let layer_elems = {
3782 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3783 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3784
3785 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3786 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3787
3788 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3789 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3790 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3791 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3792
3793 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3794 };
3795
3796 post_layernorm
3797 + patch_embedding
3798 + position_embedding
3799 + layer_elems * cfg.num_hidden_layers
3800 } else {
3801 0
3802 };
3803
3804 let elems = text_elems + vision_transformer;
3805
3806 Ok(elems * dtype.size_in_bytes())
3807 }
3808
3809 fn layer_sizes_in_bytes(
3810 &self,
3811 config: &str,
3812 dtype: DType,
3813 weight_pack_factor: usize,
3814 _matformer_config: Option<&MatformerSliceConfig>,
3815 ) -> Result<Vec<usize>> {
3816 let cfg: Gemma3Config = serde_json::from_str(config)?;
3817
3818 let txt_cfg = match &cfg {
3819 Gemma3Config::Text(cfg) => cfg,
3820 Gemma3Config::WithVision { text_config, .. } => text_config,
3821 };
3822 let per_layer_elems = {
3823 let cfg = txt_cfg;
3824
3825 let input_layernorm = cfg.hidden_size;
3826 let post_attention_layernorm = cfg.hidden_size;
3827
3828 let size_in = cfg.hidden_size;
3829 let size_q = cfg.head_dim * cfg.num_attention_heads;
3830 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
3831 let q_proj =
3832 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
3833 let k_proj =
3834 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3835 let v_proj =
3836 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3837 let o_proj =
3838 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
3839
3840 let h_size = cfg.hidden_size;
3841 let i_size = cfg.intermediate_size;
3842 let gate_proj = h_size * i_size / weight_pack_factor;
3843 let up_proj = h_size * i_size / weight_pack_factor;
3844 let down_proj = i_size * h_size / weight_pack_factor;
3845
3846 input_layernorm
3847 + post_attention_layernorm
3848 + q_proj
3849 + k_proj
3850 + v_proj
3851 + o_proj
3852 + gate_proj
3853 + up_proj
3854 + down_proj
3855 };
3856 Ok(vec![
3857 per_layer_elems * dtype.size_in_bytes();
3858 txt_cfg.num_hidden_layers
3859 ])
3860 }
3861
3862 fn num_layers(&self, config: &str) -> Result<usize> {
3863 let cfg: Gemma3Config = serde_json::from_str(config)?;
3864
3865 let txt_cfg = match &cfg {
3866 Gemma3Config::Text(cfg) => cfg,
3867 Gemma3Config::WithVision { text_config, .. } => text_config,
3868 };
3869
3870 Ok(txt_cfg.num_hidden_layers)
3871 }
3872
3873 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3874 let cfg: Gemma3Config = serde_json::from_str(config)?;
3875
3876 let cfg = match &cfg {
3877 Gemma3Config::Text(cfg) => cfg,
3878 Gemma3Config::WithVision { text_config, .. } => text_config,
3879 };
3880
3881 let cfg = ModelConfigMetadata {
3882 max_seq_len: cfg.max_position_embeddings,
3883 num_layers: cfg.num_hidden_layers,
3884 hidden_size: cfg.hidden_size,
3885 num_kv_heads: cfg.num_key_value_heads,
3886 num_attn_heads: cfg.num_attention_heads,
3887 sliding_window: None, k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3889 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3890 };
3891
3892 Ok(Box::new(cfg))
3893 }
3894
3895 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3896 Some(vec![NonMappedSubModel::Vision])
3897 }
3898}
3899
3900pub struct Mistral3Loader;
3906
3907pub struct Mistral3Prefixer;
3908
3909impl MultimodalPromptPrefixer for Mistral3Prefixer {
3910 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
3911 prompt.to_string()
3912 }
3913}
3914
3915impl VisionModelLoader for Mistral3Loader {
3916 fn load(
3917 &self,
3918 config: &str,
3919 vb: ShardedVarBuilder,
3920 normal_loading_metadata: NormalLoadingMetadata,
3921 attention_mechanism: AttentionImplementation,
3922 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3923 let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
3924 Ok(Box::new(Mistral3Model::new(
3925 &cfg,
3926 vb,
3927 self.is_gptx(config),
3928 normal_loading_metadata,
3929 attention_mechanism,
3930 )?))
3931 }
3932 fn is_gptx(&self, _config: &str) -> bool {
3933 true
3934 }
3935 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3936 let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
3937 Ok(Box::new(cfg))
3938 }
3939 fn get_processor(
3940 &self,
3941 _model_config: &str,
3942 processor_config: Option<ProcessorConfig>,
3943 _preprocessor_config: PreProcessorConfig,
3944 _max_edge: Option<u32>,
3945 ) -> Arc<dyn Processor + Send + Sync> {
3946 Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
3947 }
3948 fn supports_paged_attention(&self, _config: &str) -> bool {
3949 true
3950 }
3951 fn supports_prefix_cacher(&self, _config: &str) -> bool {
3952 true
3953 }
3954 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3955 Arc::new(Mistral3Prefixer)
3956 }
3957 fn modalities(&self, _config: &str) -> Result<Modalities> {
3958 Ok(Modalities {
3959 input: vec![SupportedModality::Text, SupportedModality::Vision],
3960 output: vec![SupportedModality::Text],
3961 })
3962 }
3963}
3964
3965impl IsqModelLoader for Mistral3Loader {
3966 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3967 Ok(vec![
3968 Regex::new(r"lm_head\.(weight|bias)$")?,
3969 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3971 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3972 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3973 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3974 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3976 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3977 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3978 ])
3979 }
3980 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
3981 Ok(vec![
3982 Regex::new(r"lm_head\.(weight|bias)$")?,
3983 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3985 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3986 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3987 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3988 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3990 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3991 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3992 ])
3993 }
3994}
3995
3996#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
3997impl DeviceMappedModelLoader for Mistral3Loader {
3998 fn mapped_max_act_size_elems(
3999 &self,
4000 config: &str,
4001 params: &AutoDeviceMapParams,
4002 ) -> Result<usize> {
4003 let cfg: Mistral3Config = serde_json::from_str(config)?;
4004 let vcfg = &cfg.vision_config;
4005 let tcfg = &cfg.text_config;
4006
4007 let AutoDeviceMapParams::Vision {
4008 max_seq_len,
4009 max_batch_size,
4010 max_image_shape: (mut height, mut width),
4011 max_num_images,
4012 } = params
4013 else {
4014 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4015 };
4016
4017 let img_seq_len = {
4018 let (max_height, max_width) = (1540, 1540);
4022 let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4023 if ratio > 1. {
4024 height = (height as f64 / ratio).floor() as usize;
4025 width = (width as f64 / ratio).floor() as usize;
4026 }
4027
4028 let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
4029 let num_width_tokens = (width - 1) / vcfg.patch_size + 1;
4030
4031 height = num_height_tokens * vcfg.patch_size;
4032 width = num_width_tokens * vcfg.patch_size;
4033
4034 let num_height_tokens = height / vcfg.patch_size;
4035 let num_width_tokens = width / vcfg.patch_size;
4036
4037 (num_width_tokens + 1) * num_height_tokens
4038 };
4039
4040 let max_seq_len = img_seq_len * max_num_images + *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
4042 Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
4043 }
4044
4045 fn non_mapped_max_act_size_elems(
4046 &self,
4047 config: &str,
4048 params: &AutoDeviceMapParams,
4049 ) -> Result<usize> {
4050 let cfg: Mistral3Config = serde_json::from_str(config)?;
4051 let cfg = &cfg.vision_config;
4052
4053 let AutoDeviceMapParams::Vision {
4054 max_seq_len: _,
4055 max_batch_size,
4056 max_image_shape: (mut height, mut width),
4057 max_num_images,
4058 } = params
4059 else {
4060 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4061 };
4062
4063 let img_seq_len = {
4064 let (max_height, max_width) = (1540, 1540);
4068 let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4069 if ratio > 1. {
4070 height = (height as f64 / ratio).floor() as usize;
4071 width = (width as f64 / ratio).floor() as usize;
4072 }
4073
4074 let num_height_tokens = (height - 1) / cfg.patch_size + 1;
4075 let num_width_tokens = (width - 1) / cfg.patch_size + 1;
4076
4077 height = num_height_tokens * cfg.patch_size;
4078 width = num_width_tokens * cfg.patch_size;
4079
4080 let num_height_tokens = height / cfg.patch_size;
4081 let num_width_tokens = width / cfg.patch_size;
4082
4083 (num_width_tokens + 1) * num_height_tokens
4084 };
4085
4086 Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
4087 }
4088
4089 fn non_mapped_size_in_bytes(
4090 &self,
4091 config: &str,
4092 dtype: DType,
4093 weight_pack_factor: usize,
4094 _matformer_config: Option<&MatformerSliceConfig>,
4095 ) -> Result<usize> {
4096 let cfg: Mistral3Config = serde_json::from_str(config)?;
4097
4098 let text_elems = {
4099 let cfg = &cfg.text_config;
4100
4101 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4102 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4104 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4105 } else {
4106 0
4107 };
4108 let norm = cfg.hidden_size;
4109 embed_tokens + lm_head + norm
4110 };
4111
4112 let vision_elems = {
4113 let cfg = &cfg.vision_config;
4114
4115 let patch_embed = {
4116 let conv_cfg = Conv2dConfig {
4117 stride: cfg.patch_size,
4118 ..Default::default()
4119 };
4120 cfg.num_channels * cfg.hidden_size / conv_cfg.groups
4121 * cfg.patch_size
4122 * cfg.patch_size
4123 * cfg.patch_size
4124 };
4125 let ln_pre = cfg.hidden_size;
4126 let vision_layer = {
4127 let attn_norm = cfg.hidden_size;
4128 let ffn_norm = cfg.hidden_size;
4129
4130 let gate = cfg.hidden_size * cfg.intermediate_size;
4131 let up = cfg.hidden_size * cfg.intermediate_size;
4132 let down = cfg.hidden_size * cfg.intermediate_size;
4133
4134 let q = cfg.hidden_size * cfg.hidden_size;
4135 let k = cfg.hidden_size * cfg.hidden_size;
4136 let v = cfg.hidden_size * cfg.hidden_size;
4137 let o = cfg.hidden_size * cfg.hidden_size;
4138
4139 attn_norm + ffn_norm + gate + up + down + q + k + v + o
4140 };
4141
4142 patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
4143 };
4144
4145 let elems = text_elems + vision_elems;
4146
4147 Ok(elems * dtype.size_in_bytes())
4148 }
4149
4150 fn layer_sizes_in_bytes(
4151 &self,
4152 config: &str,
4153 dtype: DType,
4154 weight_pack_factor: usize,
4155 _matformer_config: Option<&MatformerSliceConfig>,
4156 ) -> Result<Vec<usize>> {
4157 let cfg: Mistral3Config = serde_json::from_str(config)?;
4158 let cfg = &cfg.text_config;
4159
4160 let per_layer_elems = {
4161 let input_layernorm = cfg.hidden_size;
4162 let post_attention_layernorm = cfg.hidden_size;
4163
4164 let size_in = cfg.hidden_size;
4165 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
4166 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
4167 let q_proj = size_in * size_q / weight_pack_factor;
4168 let k_proj = size_in * size_kv / weight_pack_factor;
4169 let v_proj = size_in * size_kv / weight_pack_factor;
4170 let o_proj = size_q * size_in / weight_pack_factor;
4171
4172 let h_size = cfg.hidden_size;
4173 let i_size = cfg.intermediate_size;
4174 let gate_proj = h_size * i_size / weight_pack_factor;
4175 let up_proj = h_size * i_size / weight_pack_factor;
4176 let down_proj = i_size * h_size / weight_pack_factor;
4177
4178 input_layernorm
4179 + post_attention_layernorm
4180 + q_proj
4181 + k_proj
4182 + v_proj
4183 + o_proj
4184 + gate_proj
4185 + up_proj
4186 + down_proj
4187 };
4188 Ok(vec![
4189 per_layer_elems * dtype.size_in_bytes();
4190 cfg.num_hidden_layers
4191 ])
4192 }
4193
4194 fn num_layers(&self, config: &str) -> Result<usize> {
4195 let cfg: Mistral3Config = serde_json::from_str(config)?;
4196 let cfg = &cfg.text_config;
4197 Ok(cfg.num_hidden_layers)
4198 }
4199
4200 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4201 let cfg: Mistral3Config = serde_json::from_str(config)?;
4202 let cfg = &cfg.text_config;
4203
4204 let cfg = ModelConfigMetadata {
4205 max_seq_len: cfg.max_position_embeddings,
4206 num_layers: cfg.num_hidden_layers,
4207 hidden_size: cfg.hidden_size,
4208 num_kv_heads: cfg.num_key_value_heads,
4209 num_attn_heads: cfg.num_attention_heads,
4210 sliding_window: cfg.sliding_window,
4211 k_head_dim: cfg.head_dim(),
4212 v_head_dim: cfg.head_dim(),
4213 };
4214
4215 Ok(Box::new(cfg))
4216 }
4217
4218 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4219 Some(vec![NonMappedSubModel::Vision])
4220 }
4221}
4222
4223pub struct VLlama4Loader;
4229
4230pub struct VLlama4Prefixer;
4231
4232impl MultimodalPromptPrefixer for VLlama4Prefixer {
4233 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
4234 format!(
4235 "{}{prompt}",
4236 llama4::IMAGE_TOKEN.repeat(image_indexes.len())
4237 )
4238 }
4239}
4240
4241impl VisionModelLoader for VLlama4Loader {
4242 fn load(
4243 &self,
4244 config: &str,
4245 vb: ShardedVarBuilder,
4246 normal_loading_metadata: NormalLoadingMetadata,
4247 attention_mechanism: AttentionImplementation,
4248 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
4249 let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4250 Ok(Box::new(Llama4Model::new(
4251 &cfg,
4252 vb,
4253 self.is_gptx(config),
4254 normal_loading_metadata,
4255 attention_mechanism,
4256 )?))
4257 }
4258 fn is_gptx(&self, _config: &str) -> bool {
4259 false
4260 }
4261 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4262 let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4263 Ok(Box::new(cfg))
4264 }
4265 fn get_processor(
4266 &self,
4267 _model_config: &str,
4268 processor_config: Option<ProcessorConfig>,
4269 _preprocessor_config: PreProcessorConfig,
4270 _max_edge: Option<u32>,
4271 ) -> Arc<dyn Processor + Send + Sync> {
4272 Arc::new(Llama4Processor::new(&processor_config.unwrap()))
4273 }
4274 fn supports_paged_attention(&self, _config: &str) -> bool {
4275 true
4276 }
4277 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4278 Arc::new(VLlama4Prefixer)
4279 }
4280 fn modalities(&self, _config: &str) -> Result<Modalities> {
4281 Ok(Modalities {
4282 input: vec![SupportedModality::Text, SupportedModality::Vision],
4283 output: vec![SupportedModality::Text],
4284 })
4285 }
4286}
4287
4288impl IsqModelLoader for VLlama4Loader {
4289 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4290 Ok(vec![
4291 Regex::new(r"lm_head\.(weight|bias)$")?,
4292 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4294 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4295 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4296 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4297 Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$")?,
4299 Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$")?,
4300 Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$")?,
4301 Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$")?,
4302 Regex::new(r"layers\.(\d+)\.feed_forward\.router\.(weight|bias)$")?,
4303 Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4304 Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4305 Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4306 Regex::new(r"layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$")?,
4308 Regex::new(r"layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$")?,
4309 Regex::new(r"layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$")?,
4310 ])
4311 }
4312 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4313 Ok(vec![
4314 Regex::new(r"lm_head\.(weight|bias)$")?,
4315 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4317 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4318 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4319 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4320 Regex::new(
4322 r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_up_proj\.(weight|bias)$",
4323 )?,
4324 Regex::new(
4325 r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
4326 )?,
4327 Regex::new(
4328 r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.up_proj\.(weight|bias)$",
4329 )?,
4330 Regex::new(
4331 r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.down_proj\.(weight|bias)$",
4332 )?,
4333 Regex::new(
4334 r"language_model\.model\.layers\.(\d+)\.feed_forward\.router\.(weight|bias)$",
4335 )?,
4336 Regex::new(
4337 r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$",
4338 )?,
4339 Regex::new(
4340 r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$",
4341 )?,
4342 Regex::new(
4343 r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$",
4344 )?,
4345 Regex::new(
4347 r"language_model\.model\.layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$",
4348 )?,
4349 Regex::new(
4350 r"language_model\.model\.layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$",
4351 )?,
4352 Regex::new(
4353 r"language_model\.model\.layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$",
4354 )?,
4355 ])
4356 }
4357}
4358
4359impl VLlama4Loader {
4360 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
4363 fn run_dummy_processing(
4364 &self,
4365 cfg: &Llama4Config,
4366 height: usize,
4367 width: usize,
4368 max_num_images: usize,
4369 max_batch_size: usize,
4370 ) -> Result<(usize, usize)> {
4371 let cfg = &cfg.vision_config;
4372
4373 let img_processor =
4374 Llama4ImageProcessor::new(Some(cfg.patch_size), Some(cfg.pixel_shuffle_ratio));
4375 let image = DynamicImage::new(width as u32, height as u32, ColorType::Rgb8);
4376 let res = img_processor.preprocess(
4377 vec![image; max_num_images],
4378 vec![],
4379 &PreProcessorConfig::default(),
4380 &Device::Cpu,
4381 (max_batch_size, max_num_images),
4382 )?;
4383
4384 let pixels_batch_size = res.pixel_values.dim(0)?;
4385 let pixels_max_batch_size = pixels_batch_size * max_batch_size;
4386
4387 let (image_h, image_w) = (
4388 res.pixel_values.dim(D::Minus2).unwrap(),
4389 res.pixel_values.dim(D::Minus1).unwrap(),
4390 );
4391 let num_patches_per_chunk = (image_h / img_processor.patch_size)
4392 * (image_w / img_processor.patch_size)
4393 / img_processor.downsample_ratio;
4394
4395 Ok((
4396 pixels_max_batch_size,
4397 num_patches_per_chunk * pixels_max_batch_size,
4398 ))
4399 }
4400}
4401
4402impl DeviceMappedModelLoader for VLlama4Loader {
4403 fn mapped_max_act_size_elems(
4404 &self,
4405 config: &str,
4406 params: &AutoDeviceMapParams,
4407 ) -> Result<usize> {
4408 let AutoDeviceMapParams::Vision {
4409 max_seq_len,
4410 max_batch_size,
4411 max_image_shape: (height, width),
4412 max_num_images,
4413 } = params
4414 else {
4415 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4416 };
4417
4418 let cfg: Llama4Config = serde_json::from_str(config)?;
4419
4420 let (_pixels_batch_size, num_text_image_toks) =
4421 self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4422
4423 let max_seq_len = max_seq_len.min(&ATTENTION_CHUNK_SIZE) + num_text_image_toks;
4424
4425 Ok(max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len)
4426 }
4427 fn non_mapped_max_act_size_elems(
4428 &self,
4429 config: &str,
4430 params: &AutoDeviceMapParams,
4431 ) -> Result<usize> {
4432 let AutoDeviceMapParams::Vision {
4433 max_seq_len: _,
4434 max_batch_size,
4435 max_image_shape: (height, width),
4436 max_num_images,
4437 } = params
4438 else {
4439 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4440 };
4441
4442 let cfg: Llama4Config = serde_json::from_str(config)?;
4443
4444 let (pixels_batch_size, _num_text_image_toks) =
4445 self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4446 let max_seq_len = cfg.vision_config.num_patches();
4447
4448 Ok((max_batch_size * pixels_batch_size)
4449 * cfg.vision_config.num_attention_heads
4450 * max_seq_len
4451 * max_seq_len)
4452 }
4453
4454 fn non_mapped_size_in_bytes(
4455 &self,
4456 config: &str,
4457 dtype: DType,
4458 weight_pack_factor: usize,
4459 _matformer_config: Option<&MatformerSliceConfig>,
4460 ) -> Result<usize> {
4461 let cfg: Llama4Config = serde_json::from_str(config)?;
4462 let tcfg = &cfg.text_config;
4463
4464 let text_elems = {
4465 let embed_tokens = tcfg.hidden_size * tcfg.vocab_size / weight_pack_factor;
4466 let lm_head = if !tcfg.tie_word_embeddings {
4467 tcfg.hidden_size * tcfg.vocab_size
4468 } else {
4469 0
4470 };
4471 let norm = tcfg.hidden_size;
4472 embed_tokens + lm_head + norm
4473 };
4474
4475 let vision_elems = {
4476 let cfg = &cfg.vision_config;
4477
4478 let num_patches = cfg.num_patches();
4479
4480 let unfold_elems =
4481 (cfg.num_channels * cfg.patch_size * cfg.patch_size) * cfg.hidden_size;
4482 let class_embeddng_elems = cfg.hidden_size;
4483 let positional_embedding_vlm_elems = num_patches * cfg.hidden_size;
4484 let layernorm_pre_elems = cfg.hidden_size;
4485 let layernorm_post_elems = cfg.hidden_size;
4486
4487 let pixel_shuffle_elems = cfg.intermediate_size * cfg.projector_input_dim
4488 / weight_pack_factor
4489 + cfg.projector_input_dim * cfg.projector_output_dim / weight_pack_factor;
4490
4491 let encoder_layer = {
4492 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
4493 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
4494
4495 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
4496 let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4497 / weight_pack_factor
4498 + cfg.num_attention_heads * head_dim;
4499 let k_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4500 / weight_pack_factor
4501 + cfg.num_attention_heads * head_dim;
4502 let v_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4503 / weight_pack_factor
4504 + cfg.num_attention_heads * head_dim;
4505 let o_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4506 / weight_pack_factor
4507 + cfg.num_attention_heads * head_dim;
4508
4509 let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
4510 + cfg.intermediate_size;
4511 let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
4512 + cfg.hidden_size;
4513
4514 input_layernorm
4515 + post_attention_layernorm
4516 + q_proj
4517 + k_proj
4518 + v_proj
4519 + o_proj
4520 + fc1
4521 + fc2
4522 };
4523
4524 unfold_elems
4525 + class_embeddng_elems
4526 + positional_embedding_vlm_elems
4527 + layernorm_post_elems
4528 + layernorm_pre_elems
4529 + pixel_shuffle_elems
4530 + encoder_layer * cfg.num_hidden_layers
4531 };
4532
4533 let elems = text_elems + vision_elems;
4534
4535 Ok(elems * dtype.size_in_bytes())
4536 }
4537
4538 fn layer_sizes_in_bytes(
4539 &self,
4540 config: &str,
4541 dtype: DType,
4542 weight_pack_factor: usize,
4543 _matformer_config: Option<&MatformerSliceConfig>,
4544 ) -> Result<Vec<usize>> {
4545 let cfg: Llama4Config = serde_json::from_str(config)?;
4546 let tcfg = &cfg.text_config;
4547
4548 let mut per_layer_elems = Vec::new();
4549
4550 for layer_idx in 0..tcfg.num_hidden_layers {
4551 let input_layernorm = tcfg.hidden_size;
4552 let post_attention_layernorm = tcfg.hidden_size;
4553
4554 let size_in = tcfg.hidden_size;
4555 let size_q = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_attention_heads;
4556 let size_kv = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_key_value_heads;
4557 let q_proj = size_in * size_q / weight_pack_factor;
4558 let k_proj = size_in * size_kv / weight_pack_factor;
4559 let v_proj = size_in * size_kv / weight_pack_factor;
4560 let o_proj = size_q * size_in / weight_pack_factor;
4561
4562 let use_moe = tcfg.moe_layers().contains(&layer_idx);
4563 let moe_block = if use_moe {
4564 let h_size = tcfg.hidden_size;
4565 let i_size = tcfg.intermediate_size;
4566 let gate_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4567 let up_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4568 let down_proj = tcfg.num_local_experts * i_size * h_size / weight_pack_factor;
4569
4570 gate_proj + up_proj + down_proj
4571 } else {
4572 let h_size = tcfg.hidden_size;
4573 let i_size = tcfg.intermediate_size_mlp;
4574 let gate_proj = h_size * i_size / weight_pack_factor;
4575 let up_proj = h_size * i_size / weight_pack_factor;
4576 let down_proj = i_size * h_size / weight_pack_factor;
4577
4578 gate_proj + up_proj + down_proj
4579 };
4580
4581 per_layer_elems.push(
4582 input_layernorm
4583 + post_attention_layernorm
4584 + q_proj
4585 + k_proj
4586 + v_proj
4587 + o_proj
4588 + moe_block,
4589 );
4590 }
4591
4592 Ok(per_layer_elems
4593 .into_iter()
4594 .map(|x| x * dtype.size_in_bytes())
4595 .collect())
4596 }
4597
4598 fn num_layers(&self, config: &str) -> Result<usize> {
4599 let cfg: Llama4Config = serde_json::from_str(config)?;
4600 Ok(cfg.text_config.num_hidden_layers)
4601 }
4602
4603 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4604 let cfg: Llama4Config = serde_json::from_str(config)?;
4605 let cfg = &cfg.text_config;
4606
4607 let cfg = ModelConfigMetadata {
4608 max_seq_len: cfg.max_position_embeddings,
4609 num_layers: cfg.num_hidden_layers,
4610 hidden_size: cfg.hidden_size,
4611 num_kv_heads: cfg.num_attention_heads,
4612 num_attn_heads: cfg.num_attention_heads,
4613 sliding_window: None,
4614 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4615 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4616 };
4617
4618 Ok(Box::new(cfg))
4619 }
4620
4621 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4622 Some(vec![NonMappedSubModel::Vision])
4623 }
4624}
4625
4626pub struct Gemma3nLoader;
4632
4633#[allow(dead_code)]
4634pub struct Gemma3nPrefixer;
4635
4636impl MultimodalPromptPrefixer for Gemma3nPrefixer {
4637 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
4638 prompt.to_string()
4639 }
4640}
4641
4642impl VisionModelLoader for Gemma3nLoader {
4643 fn load(
4644 &self,
4645 config: &str,
4646 vb: ShardedVarBuilder,
4647 normal_loading_metadata: NormalLoadingMetadata,
4648 attention_mechanism: AttentionImplementation,
4649 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
4650 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4651 Ok(Box::new(Gemma3nModel::new(
4652 &cfg,
4653 vb,
4654 self.is_gptx(config),
4655 normal_loading_metadata,
4656 attention_mechanism,
4657 )?))
4658 }
4659 fn is_gptx(&self, _config: &str) -> bool {
4660 true
4661 }
4662 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4663 let config: Gemma3nConfig = serde_json::from_str(config)?;
4664 Ok(Box::new(config))
4665 }
4666 fn get_processor(
4667 &self,
4668 _config: &str,
4669 processor_config: Option<ProcessorConfig>,
4670 _preprocessor_config: PreProcessorConfig,
4671 _max_edge: Option<u32>,
4672 ) -> Arc<dyn Processor + Send + Sync> {
4673 Arc::new(Gemma3nProcessor::new(
4675 processor_config.unwrap_or_default(),
4676 true,
4677 ))
4678 }
4679 fn supports_paged_attention(&self, _config: &str) -> bool {
4680 false
4681 }
4682 fn supports_prefix_cacher(&self, _config: &str) -> bool {
4683 true
4684 }
4685 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4686 Arc::new(Gemma3Prefixer)
4687 }
4688 fn modalities(&self, _config: &str) -> Result<Modalities> {
4689 Ok(Modalities {
4690 input: vec![
4691 SupportedModality::Text,
4692 SupportedModality::Vision,
4693 SupportedModality::Audio,
4694 ],
4695 output: vec![SupportedModality::Text],
4696 })
4697 }
4698}
4699
4700impl IsqModelLoader for Gemma3nLoader {
4701 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4702 Ok(vec![
4703 Regex::new(r"lm_head\.(weight|bias)$")?,
4704 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4706 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4707 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4708 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4709 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4711 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4712 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4713 Regex::new(r"conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$")?,
4715 Regex::new(r"conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$")?,
4716 Regex::new(r"conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$")?,
4717 Regex::new(
4718 r"conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4719 )?,
4720 Regex::new(r"conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4721 Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$")?,
4723 Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$")?,
4724 Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$")?,
4725 Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$")?,
4726 Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$")?,
4728 Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$")?,
4729 Regex::new(r"subsample_conv_projection\.input_proj_linear\.(weight|bias)$")?,
4731 Regex::new(r"embed_vision\.embedding_projection\.(weight|bias)$")?,
4733 Regex::new(r"embed_audio\.embedding_projection\.(weight|bias)$")?,
4734 ])
4735 }
4736 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4737 Ok(vec![
4738 Regex::new(r"lm_head\.(weight|bias)$")?,
4739 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4741 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4742 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4743 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4744 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4746 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4747 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4748 Regex::new(r"model\.language_model\.per_layer_model_projection\.(weight|bias)$")?,
4750 Regex::new(r"model\.language_model\.altup_projections\.(\d+)\.(weight|bias)$")?,
4751 Regex::new(r"model\.language_model\.altup_unembed_projections\.(\d+)\.(weight|bias)$")?,
4752 Regex::new(
4754 r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$",
4755 )?,
4756 Regex::new(
4757 r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$",
4758 )?,
4759 Regex::new(
4760 r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$",
4761 )?,
4762 Regex::new(
4763 r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4764 )?,
4765 Regex::new(r"model\.audio_tower\.conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4766 Regex::new(
4768 r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$",
4769 )?,
4770 Regex::new(
4771 r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$",
4772 )?,
4773 Regex::new(
4774 r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$",
4775 )?,
4776 Regex::new(
4777 r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$",
4778 )?,
4779 Regex::new(
4781 r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$",
4782 )?,
4783 Regex::new(
4784 r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$",
4785 )?,
4786 Regex::new(
4788 r"model\.audio_tower\.subsample_conv_projection\.input_proj_linear\.(weight|bias)$",
4789 )?,
4790 Regex::new(r"model\.embed_vision\.embedding_projection\.(weight|bias)$")?,
4792 Regex::new(r"model\.embed_audio\.embedding_projection\.(weight|bias)$")?,
4793 ])
4794 }
4795}
4796
4797impl DeviceMappedModelLoader for Gemma3nLoader {
4798 fn mapped_max_act_size_elems(
4799 &self,
4800 config: &str,
4801 params: &AutoDeviceMapParams,
4802 ) -> Result<usize> {
4803 let AutoDeviceMapParams::Vision {
4804 max_seq_len,
4805 max_batch_size,
4806 max_image_shape: _,
4807 max_num_images,
4808 } = params
4809 else {
4810 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4811 };
4812
4813 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4814 let text_cfg = &cfg.text_config;
4815
4816 let mut total_seq_len = *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
4820
4821 {
4823 let msfa_spatial_size = 16; let vision_tokens_per_image = msfa_spatial_size * msfa_spatial_size; total_seq_len += vision_tokens_per_image * max_num_images;
4828 }
4829
4830 {
4832 let audio_tokens = cfg.audio_soft_tokens_per_image;
4835 total_seq_len += audio_tokens;
4836 }
4837
4838 let max_text_attn =
4840 max_batch_size * text_cfg.num_attention_heads * total_seq_len * total_seq_len;
4841
4842 Ok(max_text_attn)
4843 }
4844
4845 fn non_mapped_max_act_size_elems(
4846 &self,
4847 config: &str,
4848 params: &AutoDeviceMapParams,
4849 ) -> Result<usize> {
4850 let AutoDeviceMapParams::Vision {
4851 max_seq_len: _,
4852 max_batch_size,
4853 max_image_shape: _,
4854 max_num_images,
4855 } = params
4856 else {
4857 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4858 };
4859
4860 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4861
4862 let mut max_activation = 0;
4864
4865 {
4867 let vision_tower_act = {
4877 let num_heads = 16; let spatial_size = 24; let seq_len = spatial_size * spatial_size;
4883
4884 max_batch_size * max_num_images * num_heads * seq_len * seq_len
4886 };
4887
4888 let vision_embed_act = {
4890 let msfa_channels = 2048; let spatial_size = 16; let vision_features =
4894 max_batch_size * max_num_images * msfa_channels * spatial_size * spatial_size;
4895
4896 let projected = max_batch_size
4898 * max_num_images
4899 * spatial_size
4900 * spatial_size
4901 * cfg.text_config.hidden_size;
4902
4903 vision_features.max(projected)
4904 };
4905
4906 max_activation = max_activation.max(vision_tower_act).max(vision_embed_act);
4907 }
4908
4909 {
4911 let audio_cfg = &cfg.audio_config;
4912
4913 let max_audio_frames = 1280;
4918
4919 let subsample_factor: usize = audio_cfg
4920 .sscp_conv_stride_size
4921 .iter()
4922 .map(|stride| stride[0]) .product();
4924 let audio_seq_after_subsample = max_audio_frames / subsample_factor;
4925
4926 let audio_encoder_act = {
4928 let intermediate_size = audio_cfg.hidden_size * 4; max_batch_size * audio_seq_after_subsample * intermediate_size
4933 };
4934
4935 let audio_attn_act = {
4937 let chunk_size = audio_cfg.conf_attention_chunk_size;
4939 let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
4940 + audio_cfg.conf_attention_context_right;
4941
4942 let num_chunks = audio_seq_after_subsample.div_ceil(chunk_size);
4944
4945 max_batch_size
4946 * audio_cfg.conf_num_attention_heads
4947 * num_chunks
4948 * chunk_size
4949 * context_size
4950 };
4951
4952 max_activation = max_activation.max(audio_encoder_act).max(audio_attn_act);
4953 }
4954
4955 Ok(max_activation)
4956 }
4957
4958 fn non_mapped_size_in_bytes(
4959 &self,
4960 config: &str,
4961 dtype: DType,
4962 weight_pack_factor: usize,
4963 matformer_config: Option<&MatformerSliceConfig>,
4964 ) -> Result<usize> {
4965 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4966
4967 let text_cfg = if let Some(matformer_cfg) = matformer_config {
4969 use crate::device_map::DummyDeviceMapper;
4970 use crate::vision_models::gemma3n::text::handle_matformer_slicing;
4971
4972 let dummy_mapper = DummyDeviceMapper {
4973 nm_device: Device::Cpu,
4974 };
4975 let (adjusted_cfg, _, _, _, _) = handle_matformer_slicing(
4976 &cfg.text_config,
4977 &Some(matformer_cfg.clone()),
4978 &dummy_mapper,
4979 )?;
4980 adjusted_cfg
4981 } else {
4982 cfg.text_config.clone()
4983 };
4984
4985 let text_cfg = &text_cfg;
4986
4987 let text_elems = {
4989 let embed_tokens = text_cfg.hidden_size * text_cfg.vocab_size;
4991 let embed_tokens_per_layer = text_cfg.num_hidden_layers
4992 * text_cfg.hidden_size_per_layer_input
4993 * text_cfg.vocab_size_per_layer_input;
4994
4995 let lm_head = if !text_cfg.tie_word_embeddings || weight_pack_factor != 1 {
4997 text_cfg.hidden_size * text_cfg.vocab_size / weight_pack_factor
4998 } else {
4999 0
5000 };
5001
5002 let norm = text_cfg.hidden_size;
5004
5005 let altup_projections =
5007 (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
5008 / weight_pack_factor;
5009 let altup_unembed_projections =
5010 (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
5011 / weight_pack_factor;
5012
5013 let per_layer_model_projection = text_cfg.num_hidden_layers
5015 * text_cfg.hidden_size
5016 * text_cfg.hidden_size_per_layer_input
5017 / weight_pack_factor;
5018 let per_layer_projection_norm = text_cfg.hidden_size;
5019
5020 embed_tokens
5021 + embed_tokens_per_layer
5022 + lm_head
5023 + norm
5024 + altup_projections
5025 + altup_unembed_projections
5026 + per_layer_model_projection
5027 + per_layer_projection_norm
5028 };
5029
5030 let vision_elems = {
5032 let vision_cfg = &cfg.vision_config;
5033 let vision_tower_elems = {
5037 use crate::vision_models::gemma3n::vision::{
5038 gemma3n_mobilenet_def, make_divisible, BlockType, INPUT_CHANNELS,
5039 MSFA_EXPANSION_RATIO, MSFA_IN_CHANNELS, MSFA_OUT_CHANNELS, STEM_KERNEL_SIZE,
5040 STEM_OUT_CHANNELS,
5041 };
5042
5043 let stem_conv =
5045 INPUT_CHANNELS * STEM_OUT_CHANNELS * STEM_KERNEL_SIZE * STEM_KERNEL_SIZE;
5046 let stem_norm = STEM_OUT_CHANNELS; let mut in_chs = STEM_OUT_CHANNELS;
5050 let mut total_elems = stem_conv + stem_norm;
5051
5052 let block_defs = gemma3n_mobilenet_def();
5054
5055 for stage_blocks in block_defs.iter() {
5056 for block_type in stage_blocks.iter() {
5057 match block_type {
5058 BlockType::EdgeResidual {
5059 out_channels,
5060 kernel_size,
5061 stride: _,
5062 expand_ratio,
5063 ..
5064 } => {
5065 #[allow(clippy::cast_precision_loss)]
5066 let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
5067 total_elems += in_chs * mid_chs * kernel_size * kernel_size; total_elems += mid_chs; total_elems += mid_chs * out_channels; total_elems += out_channels; in_chs = *out_channels;
5073 }
5074 BlockType::UniversalInvertedResidual {
5075 out_channels,
5076 start_kernel_size,
5077 mid_kernel_size,
5078 stride: _,
5079 expand_ratio,
5080 ..
5081 } => {
5082 #[allow(clippy::cast_precision_loss)]
5083 let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
5084 if *expand_ratio != 1.0 {
5086 total_elems += in_chs * mid_chs; total_elems += mid_chs; }
5089 if *start_kernel_size > 0 {
5090 total_elems += mid_chs * start_kernel_size * start_kernel_size; total_elems += mid_chs; }
5093 if *mid_kernel_size > 0 {
5094 total_elems += mid_chs * mid_kernel_size * mid_kernel_size; total_elems += mid_chs; }
5097 total_elems += mid_chs * out_channels; total_elems += out_channels; total_elems += out_channels; in_chs = *out_channels;
5101 }
5102 BlockType::MultiQueryAttention {
5103 num_heads,
5104 kv_dim,
5105 kv_stride: _,
5106 ..
5107 } => {
5108 let dw_kernel_size = 3; total_elems += in_chs; total_elems += in_chs * num_heads * kv_dim; total_elems += in_chs * kv_dim; total_elems += in_chs * dw_kernel_size * dw_kernel_size; total_elems += *kv_dim; total_elems += 1; total_elems += *kv_dim; total_elems += num_heads * kv_dim * in_chs; total_elems += in_chs; }
5120 }
5121 }
5122 }
5123
5124 let msfa_in = MSFA_IN_CHANNELS.iter().sum::<usize>();
5126 let msfa_out = MSFA_OUT_CHANNELS;
5127 #[allow(clippy::cast_precision_loss)]
5128 let msfa_mid = make_divisible(msfa_in as f64 * MSFA_EXPANSION_RATIO, 8);
5129
5130 total_elems += msfa_in * msfa_mid; total_elems += msfa_mid; total_elems += msfa_mid * msfa_out; total_elems += msfa_out; total_elems += msfa_out; total_elems
5138 };
5139
5140 let embed_vision_elems = {
5142 let embedding = vision_cfg.vocab_size * vision_cfg.hidden_size;
5144
5145 let hard_norm = vision_cfg.hidden_size;
5147 let soft_norm = vision_cfg.hidden_size;
5148
5149 let projection = vision_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5151
5152 let post_norm = text_cfg.hidden_size;
5154
5155 embedding + hard_norm + soft_norm + projection + post_norm
5156 };
5157
5158 vision_tower_elems + embed_vision_elems
5159 };
5160
5161 let audio_elems = {
5163 let audio_cfg = &cfg.audio_config;
5164
5165 let subsample_conv_projection_elems = {
5167 let mut conv_elems = 0;
5169
5170 let in_ch_0 = 1;
5172 let out_ch_0 = audio_cfg.sscp_conv_channel_size[0];
5173 let kernel_0 = &audio_cfg.sscp_conv_kernel_size[0];
5174 conv_elems += in_ch_0 * out_ch_0 * kernel_0[0] * kernel_0[1];
5175
5176 let in_ch_1 = out_ch_0;
5178 let out_ch_1 = audio_cfg.sscp_conv_channel_size[1];
5179 let kernel_1 = &audio_cfg.sscp_conv_kernel_size[1];
5180 conv_elems += in_ch_1 * out_ch_1 * kernel_1[0] * kernel_1[1];
5181
5182 let norm_0 = out_ch_0; let norm_1 = out_ch_1; let mut f_out = audio_cfg.input_feat_size;
5188 for i in 0..2 {
5189 let kernel_w = audio_cfg.sscp_conv_kernel_size[i][1];
5190 let stride_w = audio_cfg.sscp_conv_stride_size[i][1];
5191 let pad_left = 1;
5192 let pad_right = 1;
5193 f_out = (f_out + pad_left + pad_right + stride_w - kernel_w) / stride_w;
5194 }
5195 let input_proj_in_features = out_ch_1 * f_out;
5196 let input_proj_linear =
5197 input_proj_in_features * audio_cfg.hidden_size / weight_pack_factor;
5198
5199 conv_elems + norm_0 + norm_1 + input_proj_linear
5200 };
5201
5202 let conformer_elems = {
5204 let mut total = 0;
5205
5206 for _ in 0..audio_cfg.conf_num_hidden_layers {
5207 let attention_elems = {
5209 let pre_attn_norm = audio_cfg.hidden_size;
5211 let post_norm = audio_cfg.hidden_size;
5212
5213 let q_proj =
5215 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5216 let k_proj =
5217 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5218 let v_proj =
5219 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5220 let post =
5221 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5222
5223 let pos_proj =
5225 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5226 let per_dim_scale =
5227 audio_cfg.hidden_size / audio_cfg.conf_num_attention_heads; let inv_timescales = audio_cfg.hidden_size / 2; let pos_indices = audio_cfg.conf_attention_context_left
5230 + audio_cfg.conf_attention_context_right
5231 + 1;
5232
5233 let chunk_size = audio_cfg.conf_attention_chunk_size;
5235 let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
5236 + audio_cfg.conf_attention_context_right;
5237 let local_causal_valid_mask = chunk_size * context_size; let invalid_logits_tensor = 1; pre_attn_norm
5241 + post_norm
5242 + q_proj
5243 + k_proj
5244 + v_proj
5245 + post
5246 + pos_proj
5247 + per_dim_scale
5248 + inv_timescales
5249 + pos_indices
5250 + local_causal_valid_mask
5251 + invalid_logits_tensor
5252 };
5253
5254 let ffw_elems = {
5256 let intermediate_size = audio_cfg.hidden_size * 4;
5262
5263 let ffw_start = {
5264 let pre_norm = audio_cfg.hidden_size;
5265 let layer_1 =
5266 audio_cfg.hidden_size * intermediate_size / weight_pack_factor;
5267 let layer_2 =
5268 intermediate_size * audio_cfg.hidden_size / weight_pack_factor;
5269 let post_norm = audio_cfg.hidden_size;
5270 pre_norm + layer_1 + layer_2 + post_norm
5271 };
5272
5273 let ffw_end = ffw_start; ffw_start + ffw_end
5276 };
5277
5278 let lconv1d_elems = {
5280 let pre_layer_norm = audio_cfg.hidden_size;
5282 let conv_norm = audio_cfg.hidden_size;
5283
5284 let linear_start = audio_cfg.hidden_size * (audio_cfg.hidden_size * 2)
5286 / weight_pack_factor;
5287 let linear_end =
5288 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5289
5290 let depthwise = audio_cfg.hidden_size * audio_cfg.conf_conv_kernel_size;
5292
5293 pre_layer_norm + conv_norm + linear_start + linear_end + depthwise
5294 };
5295
5296 let block_norm = audio_cfg.hidden_size;
5298
5299 total += attention_elems + ffw_elems + lconv1d_elems + block_norm;
5300 }
5301
5302 total
5303 };
5304
5305 let embed_audio_elems = {
5307 let embedding = audio_cfg.vocab_size * audio_cfg.hidden_size;
5309
5310 let hard_embedding_norm = audio_cfg.hidden_size; let soft_embedding_norm = audio_cfg.hidden_size; let embedding_post_projection_norm = text_cfg.hidden_size; let embedding_projection =
5317 audio_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5318
5319 embedding
5320 + hard_embedding_norm
5321 + soft_embedding_norm
5322 + embedding_post_projection_norm
5323 + embedding_projection
5324 };
5325
5326 subsample_conv_projection_elems + conformer_elems + embed_audio_elems
5327 };
5328
5329 let vision_dtype = if dtype == DType::F16 {
5330 DType::F32
5332 } else {
5333 dtype
5334 };
5335
5336 let total_elems = text_elems * dtype.size_in_bytes()
5337 + vision_elems * vision_dtype.size_in_bytes()
5338 + audio_elems * dtype.size_in_bytes();
5339
5340 Ok(total_elems)
5341 }
5342
5343 fn layer_sizes_in_bytes(
5344 &self,
5345 config: &str,
5346 dtype: DType,
5347 weight_pack_factor: usize,
5348 matformer_config: Option<&MatformerSliceConfig>,
5349 ) -> Result<Vec<usize>> {
5350 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5351
5352 let (text_cfg, _layer_rename_map, _layers_skipped) = if let Some(matformer_cfg) =
5354 matformer_config
5355 {
5356 use crate::device_map::DummyDeviceMapper;
5357 use crate::vision_models::gemma3n::text::handle_matformer_slicing;
5358
5359 let dummy_mapper = DummyDeviceMapper {
5360 nm_device: Device::Cpu,
5361 };
5362 let (adjusted_cfg, _, _, layer_rename_map, layers_skipped) = handle_matformer_slicing(
5363 &cfg.text_config,
5364 &Some(matformer_cfg.clone()),
5365 &dummy_mapper,
5366 )?;
5367 (adjusted_cfg, layer_rename_map, layers_skipped)
5368 } else {
5369 (cfg.text_config.clone(), None, None)
5370 };
5371
5372 let text_cfg = &text_cfg;
5373
5374 let mut layer_sizes = Vec::new();
5376
5377 for layer_idx in 0..text_cfg.num_hidden_layers {
5381 let per_layer_elems = {
5382 let input_layernorm = text_cfg.hidden_size;
5384 let post_attention_layernorm = text_cfg.hidden_size;
5385 let pre_feedforward_layernorm = text_cfg.hidden_size;
5386 let post_feedforward_layernorm = text_cfg.hidden_size;
5387 let post_per_layer_input_norm = text_cfg.hidden_size;
5388
5389 let size_in = text_cfg.hidden_size;
5391 let size_q = text_cfg.num_attention_heads * text_cfg.head_dim;
5392 let size_kv = text_cfg.num_key_value_heads * text_cfg.head_dim;
5393
5394 let q_proj = size_in * size_q / weight_pack_factor;
5395 let k_proj = size_in * size_kv / weight_pack_factor;
5396 let v_proj = size_in * size_kv / weight_pack_factor;
5397 let o_proj = size_q * size_in / weight_pack_factor;
5398
5399 let q_norm = text_cfg.head_dim;
5401 let k_norm = text_cfg.head_dim;
5402 let v_norm = text_cfg.head_dim; let intermediate_size = match &text_cfg.intermediate_size {
5406 IntermediateSize::Single(size) => *size,
5407 IntermediateSize::PerLayer(sizes) => sizes[layer_idx],
5408 IntermediateSize::Matformer(sizes, _) => sizes[layer_idx],
5409 };
5410 let gate_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
5411 let up_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
5412 let down_proj = intermediate_size * text_cfg.hidden_size / weight_pack_factor;
5413
5414 let altup_elems = {
5416 let correct_output_scale = text_cfg.hidden_size;
5417 let correction_coefs = text_cfg.altup_num_inputs * text_cfg.altup_num_inputs;
5418 let prediction_coefs =
5419 text_cfg.altup_num_inputs * text_cfg.altup_num_inputs.pow(2);
5420 let modality_router = text_cfg.hidden_size * text_cfg.altup_num_inputs;
5421 let router_norm = text_cfg.hidden_size;
5422
5423 correct_output_scale
5424 + correction_coefs
5425 + prediction_coefs
5426 + modality_router
5427 + router_norm
5428 };
5429
5430 let laurel_elems = {
5432 let left = text_cfg.hidden_size * text_cfg.laurel_rank;
5433 let right = text_cfg.laurel_rank * text_cfg.hidden_size;
5434 let post_norm = text_cfg.hidden_size;
5435
5436 left + right + post_norm
5437 };
5438
5439 let per_layer_input_gate =
5441 text_cfg.hidden_size * text_cfg.hidden_size_per_layer_input;
5442 let per_layer_projection =
5443 text_cfg.hidden_size_per_layer_input * text_cfg.hidden_size;
5444
5445 input_layernorm
5446 + post_attention_layernorm
5447 + pre_feedforward_layernorm
5448 + post_feedforward_layernorm
5449 + post_per_layer_input_norm
5450 + q_proj
5451 + k_proj
5452 + v_proj
5453 + o_proj
5454 + q_norm
5455 + k_norm
5456 + v_norm
5457 + gate_proj
5458 + up_proj
5459 + down_proj
5460 + altup_elems
5461 + laurel_elems
5462 + per_layer_input_gate
5463 + per_layer_projection
5464 };
5465
5466 layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
5467 }
5468
5469 Ok(layer_sizes)
5470 }
5471
5472 fn num_layers(&self, config: &str) -> Result<usize> {
5473 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5474 Ok(cfg.text_config.num_hidden_layers)
5475 }
5476
5477 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
5478 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5479 let cfg = cfg.text_config;
5480
5481 let cfg = ModelConfigMetadata {
5482 max_seq_len: cfg.max_position_embeddings,
5483 num_layers: cfg.num_hidden_layers,
5484 hidden_size: cfg.hidden_size,
5485 num_kv_heads: cfg.num_key_value_heads,
5486 num_attn_heads: cfg.num_attention_heads,
5487 sliding_window: None, k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5489 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5490 };
5491
5492 Ok(Box::new(cfg))
5493 }
5494
5495 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
5496 Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
5497 }
5498}
5499
5500pub struct Qwen3VLLoader;
5506
5507pub struct Qwen3VLPrefixer;
5508
5509impl MultimodalPromptPrefixer for Qwen3VLPrefixer {
5510 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
5511 format!(
5512 "{}{prompt}",
5513 format!(
5514 "{}{}{}",
5515 Qwen3VLProcessor::VISION_START,
5516 Qwen3VLProcessor::IMAGE_PAD,
5517 Qwen3VLProcessor::VISION_END
5518 )
5519 .repeat(image_indexes.len())
5520 )
5521 }
5522}
5523
5524impl VisionModelLoader for Qwen3VLLoader {
5525 fn load(
5526 &self,
5527 config: &str,
5528 vb: ShardedVarBuilder,
5529 normal_loading_metadata: NormalLoadingMetadata,
5530 attention_mechanism: AttentionImplementation,
5531 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
5532 let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5533 Ok(Box::new(Qwen3VLModel::new(
5534 &cfg,
5535 vb,
5536 self.is_gptx(config),
5537 normal_loading_metadata,
5538 attention_mechanism,
5539 )?))
5540 }
5541 fn is_gptx(&self, _config: &str) -> bool {
5542 true
5543 }
5544 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
5545 let config: Qwen3VLConfig = serde_json::from_str(config)?;
5546 Ok(Box::new(config))
5547 }
5548 fn get_processor(
5549 &self,
5550 _model_config: &str,
5551 _processor_config: Option<ProcessorConfig>,
5552 _preprocessor_config: PreProcessorConfig,
5553 max_edge: Option<u32>,
5554 ) -> Arc<dyn Processor + Send + Sync> {
5555 Arc::new(Qwen3VLProcessor::new(max_edge))
5556 }
5557 fn supports_paged_attention(&self, _config: &str) -> bool {
5558 true
5559 }
5560 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
5561 Arc::new(Qwen3VLPrefixer)
5562 }
5563 fn modalities(&self, _config: &str) -> Result<Modalities> {
5564 Ok(Modalities {
5565 input: vec![SupportedModality::Text, SupportedModality::Vision],
5566 output: vec![SupportedModality::Text],
5567 })
5568 }
5569}
5570
5571impl IsqModelLoader for Qwen3VLLoader {
5572 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
5573 Ok(vec![
5574 Regex::new(r"lm_head\.(weight|bias)$")?,
5575 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
5577 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
5578 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
5579 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
5580 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
5582 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
5583 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
5584 ])
5585 }
5586 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
5587 self.isq_layer_regexes(config)
5588 }
5589}
5590
5591impl DeviceMappedModelLoader for Qwen3VLLoader {
5592 fn mapped_max_act_size_elems(
5593 &self,
5594 config: &str,
5595 params: &AutoDeviceMapParams,
5596 ) -> Result<usize> {
5597 let AutoDeviceMapParams::Vision {
5598 max_seq_len,
5599 max_batch_size,
5600 max_image_shape,
5601 max_num_images,
5602 } = params
5603 else {
5604 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
5605 };
5606
5607 let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5608
5609 let img_seq_len = {
5611 let cfg = &cfg.vision_config;
5612 let grid_t = 1;
5614 let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
5616 let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
5617 grid_t * grid_h * grid_w * max_num_images
5618 };
5619
5620 let max_text_attn = {
5621 let cfg = &cfg.text_config;
5622 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
5624 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
5625 };
5626
5627 Ok(max_text_attn)
5628 }
5629
5630 fn non_mapped_max_act_size_elems(
5631 &self,
5632 config: &str,
5633 params: &AutoDeviceMapParams,
5634 ) -> Result<usize> {
5635 let AutoDeviceMapParams::Vision {
5636 max_seq_len: _,
5637 max_batch_size,
5638 max_image_shape,
5639 max_num_images,
5640 } = params
5641 else {
5642 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
5643 };
5644
5645 let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5646
5647 let img_seq_len = {
5649 let cfg = &cfg.vision_config;
5650 let grid_t = 1;
5652 let grid_h = max_image_shape.0 / cfg.patch_size;
5653 let grid_w = max_image_shape.1 / cfg.patch_size;
5654 grid_t * grid_h * grid_w
5655 };
5656
5657 let max_vision_attn = {
5658 let cfg = &cfg.vision_config;
5659 (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
5660 };
5661
5662 Ok(max_vision_attn)
5663 }
5664
5665 fn non_mapped_size_in_bytes(
5666 &self,
5667 config: &str,
5668 dtype: DType,
5669 weight_pack_factor: usize,
5670 _matformer_config: Option<&MatformerSliceConfig>,
5671 ) -> Result<usize> {
5672 let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5673 let tie = cfg.tie_word_embeddings;
5674 let text_elems = {
5675 let cfg = &cfg.text_config;
5676 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
5677 let lm_head = if !tie || weight_pack_factor != 1 {
5679 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
5680 } else {
5681 0
5682 };
5683 let norm = cfg.hidden_size;
5684 embed_tokens + lm_head + norm
5685 };
5686
5687 let patch_merger = {
5688 let cfg = &cfg.vision_config;
5689 let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
5690
5691 let mlp0 = hidden_size * hidden_size + hidden_size;
5692 let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
5693
5694 let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
5695
5696 mlp0 + mlp2 + ln_q
5697 };
5698
5699 let patch_embed = {
5700 let cfg = &cfg.vision_config;
5701 let conv_cfg = Conv3dConfig {
5702 stride: cfg.patch_size,
5703 ..Default::default()
5704 };
5705 let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
5706 cfg.in_chans * cfg.hidden_size / conv_cfg.groups
5707 * kernel_sizes[0]
5708 * kernel_sizes[1]
5709 * kernel_sizes[2]
5710 };
5711
5712 let encoder_layer = {
5713 let cfg = &cfg.vision_config;
5714 let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
5715 let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
5716
5717 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
5718 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
5719 let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
5720
5721 let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
5722 let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
5723
5724 norm1 + norm2 + fc1 + fc2 + qkv + out
5725 };
5726
5727 let elems =
5728 text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
5729
5730 Ok(elems * dtype.size_in_bytes())
5731 }
5732
5733 fn layer_sizes_in_bytes(
5734 &self,
5735 config: &str,
5736 dtype: DType,
5737 weight_pack_factor: usize,
5738 _matformer_config: Option<&MatformerSliceConfig>,
5739 ) -> Result<Vec<usize>> {
5740 let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5741 let per_layer_elems = {
5742 let cfg = &cfg.text_config;
5743 let input_layernorm = cfg.hidden_size;
5744 let post_attention_layernorm = cfg.hidden_size;
5745
5746 let size_in = cfg.hidden_size;
5747 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
5748 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
5749 let q_proj = size_in * size_q / weight_pack_factor + size_q;
5750 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
5751 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
5752 let o_proj = size_q * size_in / weight_pack_factor;
5753
5754 let h_size = cfg.hidden_size;
5755 let i_size = cfg.intermediate_size;
5756 let gate_proj = h_size * i_size / weight_pack_factor;
5757 let up_proj = h_size * i_size / weight_pack_factor;
5758 let down_proj = i_size * h_size / weight_pack_factor;
5759
5760 input_layernorm
5761 + post_attention_layernorm
5762 + q_proj
5763 + k_proj
5764 + v_proj
5765 + o_proj
5766 + gate_proj
5767 + up_proj
5768 + down_proj
5769 };
5770 Ok(vec![
5771 per_layer_elems * dtype.size_in_bytes();
5772 cfg.text_config.num_hidden_layers
5773 ])
5774 }
5775
5776 fn num_layers(&self, config: &str) -> Result<usize> {
5777 let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5778 let cfg = &cfg.text_config;
5779 Ok(cfg.num_hidden_layers)
5780 }
5781
5782 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
5783 let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5784 let cfg = &cfg.text_config;
5785
5786 let cfg = ModelConfigMetadata {
5787 max_seq_len: cfg.max_position_embeddings,
5788 num_layers: cfg.num_hidden_layers,
5789 hidden_size: cfg.hidden_size,
5790 num_kv_heads: cfg.num_key_value_heads,
5791 num_attn_heads: cfg.num_attention_heads,
5792 sliding_window: cfg.sliding_window,
5793 k_head_dim: cfg.head_dim,
5794 v_head_dim: cfg.head_dim,
5795 };
5796
5797 Ok(Box::new(cfg))
5798 }
5799
5800 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
5801 Some(vec![NonMappedSubModel::Vision])
5802 }
5803}
5804
5805pub struct Qwen3VLMoELoader;
5811
5812pub struct Qwen3VLMoEPrefixer;
5813
5814impl MultimodalPromptPrefixer for Qwen3VLMoEPrefixer {
5815 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
5816 format!(
5817 "{}{prompt}",
5818 format!(
5819 "{}{}{}",
5820 Qwen3VLMoEProcessor::VISION_START,
5821 Qwen3VLMoEProcessor::IMAGE_PAD,
5822 Qwen3VLMoEProcessor::VISION_END
5823 )
5824 .repeat(image_indexes.len())
5825 )
5826 }
5827}
5828
5829impl VisionModelLoader for Qwen3VLMoELoader {
5830 fn load(
5831 &self,
5832 config: &str,
5833 vb: ShardedVarBuilder,
5834 normal_loading_metadata: NormalLoadingMetadata,
5835 attention_mechanism: AttentionImplementation,
5836 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
5837 let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
5838 Ok(Box::new(Qwen3VLMoEModel::new(
5839 &cfg,
5840 vb,
5841 self.is_gptx(config),
5842 normal_loading_metadata,
5843 attention_mechanism,
5844 )?))
5845 }
5846 fn is_gptx(&self, _config: &str) -> bool {
5847 true
5848 }
5849 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
5850 let config: Qwen3VLMoEConfig = serde_json::from_str(config)?;
5851 Ok(Box::new(config))
5852 }
5853 fn get_processor(
5854 &self,
5855 _model_config: &str,
5856 _processor_config: Option<ProcessorConfig>,
5857 _preprocessor_config: PreProcessorConfig,
5858 max_edge: Option<u32>,
5859 ) -> Arc<dyn Processor + Send + Sync> {
5860 Arc::new(Qwen3VLMoEProcessor::new(max_edge))
5861 }
5862 fn supports_paged_attention(&self, _config: &str) -> bool {
5863 true
5864 }
5865 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
5866 Arc::new(Qwen3VLMoEPrefixer)
5867 }
5868 fn modalities(&self, _config: &str) -> Result<Modalities> {
5869 Ok(Modalities {
5870 input: vec![SupportedModality::Text, SupportedModality::Vision],
5871 output: vec![SupportedModality::Text],
5872 })
5873 }
5874}
5875
5876impl IsqModelLoader for Qwen3VLMoELoader {
5877 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
5878 Ok(vec![
5879 Regex::new(r"lm_head\.(weight|bias)$")?,
5880 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
5882 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
5883 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
5884 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
5885 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
5887 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
5888 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
5889 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate\.(weight|bias)$")?,
5891 Regex::new(
5893 r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
5894 )?,
5895 Regex::new(
5896 r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$",
5897 )?,
5898 Regex::new(
5899 r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$",
5900 )?,
5901 ])
5902 }
5903 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
5904 self.isq_layer_regexes(config)
5905 }
5906}
5907
5908impl DeviceMappedModelLoader for Qwen3VLMoELoader {
5909 fn mapped_max_act_size_elems(
5910 &self,
5911 config: &str,
5912 params: &AutoDeviceMapParams,
5913 ) -> Result<usize> {
5914 let AutoDeviceMapParams::Vision {
5915 max_seq_len,
5916 max_batch_size,
5917 max_image_shape,
5918 max_num_images,
5919 } = params
5920 else {
5921 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
5922 };
5923
5924 let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
5925
5926 let img_seq_len = {
5928 let cfg = &cfg.vision_config;
5929 let grid_t = 1;
5931 let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
5933 let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
5934 grid_t * grid_h * grid_w * max_num_images
5935 };
5936
5937 let max_text_attn = {
5938 let cfg = &cfg.text_config;
5939 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
5941 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
5942 };
5943
5944 Ok(max_text_attn)
5945 }
5946
5947 fn non_mapped_max_act_size_elems(
5948 &self,
5949 config: &str,
5950 params: &AutoDeviceMapParams,
5951 ) -> Result<usize> {
5952 let AutoDeviceMapParams::Vision {
5953 max_seq_len: _,
5954 max_batch_size,
5955 max_image_shape,
5956 max_num_images,
5957 } = params
5958 else {
5959 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
5960 };
5961
5962 let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
5963
5964 let img_seq_len = {
5966 let cfg = &cfg.vision_config;
5967 let grid_t = 1;
5969 let grid_h = max_image_shape.0 / cfg.patch_size;
5970 let grid_w = max_image_shape.1 / cfg.patch_size;
5971 grid_t * grid_h * grid_w
5972 };
5973
5974 let max_vision_attn = {
5975 let cfg = &cfg.vision_config;
5976 (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
5977 };
5978
5979 Ok(max_vision_attn)
5980 }
5981
5982 fn non_mapped_size_in_bytes(
5983 &self,
5984 config: &str,
5985 dtype: DType,
5986 weight_pack_factor: usize,
5987 _matformer_config: Option<&MatformerSliceConfig>,
5988 ) -> Result<usize> {
5989 let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
5990 let tie = cfg.tie_word_embeddings;
5991 let text_elems = {
5992 let cfg = &cfg.text_config;
5993 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
5994 let lm_head = if !tie || weight_pack_factor != 1 {
5996 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
5997 } else {
5998 0
5999 };
6000 let norm = cfg.hidden_size;
6001 embed_tokens + lm_head + norm
6002 };
6003
6004 let patch_merger = {
6005 let cfg = &cfg.vision_config;
6006 let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
6007
6008 let mlp0 = hidden_size * hidden_size + hidden_size;
6009 let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
6010
6011 let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6012
6013 mlp0 + mlp2 + ln_q
6014 };
6015
6016 let patch_embed = {
6017 let cfg = &cfg.vision_config;
6018 let conv_cfg = Conv3dConfig {
6019 stride: cfg.patch_size,
6020 ..Default::default()
6021 };
6022 let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
6023 cfg.in_chans * cfg.hidden_size / conv_cfg.groups
6024 * kernel_sizes[0]
6025 * kernel_sizes[1]
6026 * kernel_sizes[2]
6027 };
6028
6029 let encoder_layer = {
6030 let cfg = &cfg.vision_config;
6031 let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6032 let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6033
6034 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
6035 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
6036 let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
6037
6038 let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
6039 let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
6040
6041 norm1 + norm2 + fc1 + fc2 + qkv + out
6042 };
6043
6044 let elems =
6045 text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
6046
6047 Ok(elems * dtype.size_in_bytes())
6048 }
6049
6050 fn layer_sizes_in_bytes(
6051 &self,
6052 config: &str,
6053 dtype: DType,
6054 weight_pack_factor: usize,
6055 _matformer_config: Option<&MatformerSliceConfig>,
6056 ) -> Result<Vec<usize>> {
6057 let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6058 let text_cfg = &cfg.text_config;
6059
6060 let mut layer_sizes = Vec::with_capacity(text_cfg.num_hidden_layers);
6061
6062 for layer_idx in 0..text_cfg.num_hidden_layers {
6063 let input_layernorm = text_cfg.hidden_size;
6064 let post_attention_layernorm = text_cfg.hidden_size;
6065
6066 let size_in = text_cfg.hidden_size;
6067 let size_q = (text_cfg.hidden_size / text_cfg.num_attention_heads)
6068 * text_cfg.num_attention_heads;
6069 let size_kv = (text_cfg.hidden_size / text_cfg.num_attention_heads)
6070 * text_cfg.num_key_value_heads;
6071 let q_proj = size_in * size_q / weight_pack_factor + size_q;
6072 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
6073 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
6074 let o_proj = size_q * size_in / weight_pack_factor;
6075
6076 let is_moe = !text_cfg.mlp_only_layers.contains(&layer_idx)
6078 && (text_cfg.num_experts > 0
6079 && (layer_idx + 1) % text_cfg.decoder_sparse_step == 0);
6080
6081 let mlp_elems = if is_moe {
6082 let gate = text_cfg.hidden_size * text_cfg.num_experts;
6084 let per_expert = {
6085 let h_size = text_cfg.hidden_size;
6086 let i_size = text_cfg.moe_intermediate_size;
6087 let gate_proj = h_size * i_size / weight_pack_factor;
6088 let up_proj = h_size * i_size / weight_pack_factor;
6089 let down_proj = i_size * h_size / weight_pack_factor;
6090 gate_proj + up_proj + down_proj
6091 };
6092 gate + per_expert * text_cfg.num_experts
6093 } else {
6094 let h_size = text_cfg.hidden_size;
6096 let i_size = text_cfg.intermediate_size;
6097 let gate_proj = h_size * i_size / weight_pack_factor;
6098 let up_proj = h_size * i_size / weight_pack_factor;
6099 let down_proj = i_size * h_size / weight_pack_factor;
6100 gate_proj + up_proj + down_proj
6101 };
6102
6103 let per_layer_elems = input_layernorm
6104 + post_attention_layernorm
6105 + q_proj
6106 + k_proj
6107 + v_proj
6108 + o_proj
6109 + mlp_elems;
6110
6111 layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
6112 }
6113
6114 Ok(layer_sizes)
6115 }
6116
6117 fn num_layers(&self, config: &str) -> Result<usize> {
6118 let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6119 let cfg = &cfg.text_config;
6120 Ok(cfg.num_hidden_layers)
6121 }
6122
6123 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
6124 let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6125 let cfg = &cfg.text_config;
6126
6127 let cfg = ModelConfigMetadata {
6128 max_seq_len: cfg.max_position_embeddings,
6129 num_layers: cfg.num_hidden_layers,
6130 hidden_size: cfg.hidden_size,
6131 num_kv_heads: cfg.num_key_value_heads,
6132 num_attn_heads: cfg.num_attention_heads,
6133 sliding_window: cfg.sliding_window,
6134 k_head_dim: cfg.head_dim,
6135 v_head_dim: cfg.head_dim,
6136 };
6137
6138 Ok(Box::new(cfg))
6139 }
6140
6141 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
6142 Some(vec![NonMappedSubModel::Vision])
6143 }
6144}