use std::any::Any;
use std::sync::Arc;
use std::{fmt::Debug, str::FromStr};

use anyhow::Result;
use candle_core::{DType, Device, Tensor, D};
use candle_nn::Conv2dConfig;
use image::{ColorType, DynamicImage};
use itertools::Itertools;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::ShardedVarBuilder;

#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};

use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
use crate::amoe::AnyMoeBaseModelMixin;
use crate::device_map::DeviceMapper;
use crate::layers::Conv3dConfig;
use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
use crate::pipeline::isq::IsqModelLoader;
use crate::pipeline::loaders::AutoDeviceMapParams;
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::{
    EitherCache, IsqModel, Modalities, MultimodalPromptPrefixer, Processor, ProcessorCreator,
    SupportedModality,
};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::vision_models::clip::ClipConfig;
use crate::vision_models::gemma3::config::Gemma3Config;
use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
use crate::vision_models::idefics2_input_processor::Idefics2Processor;
use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
use crate::vision_models::image_processor::ImagePreProcessor;
use crate::vision_models::inputs_processor::Phi4MMProcessor;
use crate::vision_models::llama4::{
    self, Llama4Config, Llama4ImageProcessor, Llama4Model, Llama4Processor,
};
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::llava15::Model as LLaVA;
use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
use crate::vision_models::llava_next::Model as LLaVANext;
use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
use crate::vision_models::phi3_inputs_processor::Phi3Processor;
use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::qwen2_5_vl::{
    Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
};
use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
use crate::vision_models::{minicpmo, phi4};

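/// A vision model the pipeline dispatches to at runtime. `forward` takes the token ids
/// plus optional `pixel_values`; model-specific extras (e.g. image sizes or a pixel
/// attention mask) travel through the type-erased `model_specific_args`.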
pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
}

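/// Builds a concrete vision model (and its processor) from the raw `config.json`
/// string. Everything is `&str`-based so one loader interface can serve several config
/// layouts without committing to a single serde type up front.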
pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> bool;
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_processor(
        &self,
        model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync>;
    fn supports_paged_attention(&self, config: &str) -> bool;
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        false
    }
    fn modalities(&self, config: &str) -> Result<Modalities>;
    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
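            // Route tensors whose names contain `.layers.<n>.` to the device that owns
            // layer `n` (clamped via `min(num_layers)`); everything else (embeddings,
            // norms, lm_head) loads on the base device.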
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}

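/// The architecture to load the vision model as. The `serde` rename strings double as
/// the values accepted by the `FromStr`/`Display` pair below.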
#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub enum VisionLoaderType {
    #[serde(rename = "phi3v")]
    Phi3V,
    #[serde(rename = "idefics2")]
    Idefics2,
    #[serde(rename = "llava_next")]
    LLaVANext,
    #[serde(rename = "llava")]
    LLaVA,
    #[serde(rename = "vllama")]
    VLlama,
    #[serde(rename = "qwen2vl")]
    Qwen2VL,
    #[serde(rename = "idefics3")]
    Idefics3,
    #[serde(rename = "minicpmo")]
    MiniCpmO,
    #[serde(rename = "phi4mm")]
    Phi4MM,
    #[serde(rename = "qwen2_5vl")]
    Qwen2_5VL,
    #[serde(rename = "gemma3")]
    Gemma3,
    #[serde(rename = "mistral3")]
    Mistral3,
    #[serde(rename = "llama4")]
    Llama4,
}

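// Maps a Hugging Face `architectures` entry from `config.json` (e.g.
// `"LlavaForConditionalGeneration"`) to a loader type; `AutoVisionLoader` relies on this.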
impl VisionLoaderType {
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "Phi3VForCausalLM" => Ok(Self::Phi3V),
            "Idefics2ForConditionalGeneration" => Ok(Self::Idefics2),
            "LlavaNextForConditionalGeneration" => Ok(Self::LLaVANext),
            "LlavaForConditionalGeneration" => Ok(Self::LLaVA),
            "MllamaForConditionalGeneration" => Ok(Self::VLlama),
            "Qwen2VLForConditionalGeneration" => Ok(Self::Qwen2VL),
            "Idefics3ForConditionalGeneration" => Ok(Self::Idefics3),
            "MiniCPMO" => Ok(Self::MiniCpmO),
            "Phi4MMForCausalLM" => Ok(Self::Phi4MM),
            "Qwen2_5_VLForConditionalGeneration" => Ok(Self::Qwen2_5VL),
            "Gemma3ForConditionalGeneration" | "Gemma3ForCausalLM" => Ok(Self::Gemma3),
            "Mistral3ForConditionalGeneration" => Ok(Self::Mistral3),
            "Llama4ForConditionalGeneration" => Ok(Self::Llama4),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for VisionLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "phi3v" => Ok(Self::Phi3V),
            "idefics2" => Ok(Self::Idefics2),
            "llava_next" => Ok(Self::LLaVANext),
            "llava" => Ok(Self::LLaVA),
            "vllama" => Ok(Self::VLlama),
            "qwen2vl" => Ok(Self::Qwen2VL),
            "idefics3" => Ok(Self::Idefics3),
            "minicpmo" => Ok(Self::MiniCpmO),
            "phi4mm" => Ok(Self::Phi4MM),
            "qwen2_5vl" => Ok(Self::Qwen2_5VL),
            "gemma3" => Ok(Self::Gemma3),
            "mistral3" => Ok(Self::Mistral3),
            "llama4" => Ok(Self::Llama4),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`, `llama4`.")),
        }
    }
}

impl std::fmt::Display for VisionLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            VisionLoaderType::Phi3V => "phi3v",
            VisionLoaderType::Idefics2 => "idefics2",
            VisionLoaderType::LLaVANext => "llava_next",
            VisionLoaderType::LLaVA => "llava",
            VisionLoaderType::VLlama => "vllama",
            VisionLoaderType::Qwen2VL => "qwen2vl",
            VisionLoaderType::Idefics3 => "idefics3",
            VisionLoaderType::MiniCpmO => "minicpmo",
            VisionLoaderType::Phi4MM => "phi4mm",
            VisionLoaderType::Qwen2_5VL => "qwen2_5vl",
            VisionLoaderType::Gemma3 => "gemma3",
            VisionLoaderType::Mistral3 => "mistral3",
            VisionLoaderType::Llama4 => "llama4",
        };
        write!(f, "{name}")
    }
}

#[derive(Deserialize)]
struct AutoVisionLoaderConfig {
    architectures: Vec<String>,
}

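/// Selects a concrete `VisionModelLoader` by reading the single `architectures` entry in
/// the model's `config.json`, then forwards every trait call to it.
///
/// A minimal usage sketch (the JSON below is a hypothetical, heavily truncated config;
/// real configs carry many more fields):
/// ```ignore
/// let config = r#"{"architectures": ["LlavaForConditionalGeneration"]}"#;
/// // Dispatches to `LLaVALoader` internally.
/// let modalities = AutoVisionLoader.modalities(config)?;
/// ```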
pub struct AutoVisionLoader;

impl AutoVisionLoader {
    fn get_loader(config: &str) -> Result<Box<dyn VisionModelLoader>> {
        let auto_cfg: AutoVisionLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one architecture in config");
        }

        let name = &auto_cfg.architectures[0];
        let tp = VisionLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        Ok(match tp {
            VisionLoaderType::Phi3V => Box::new(Phi3VLoader),
            VisionLoaderType::Idefics2 => Box::new(Idefics2Loader),
            VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader),
            VisionLoaderType::LLaVA => Box::new(LLaVALoader),
            VisionLoaderType::VLlama => Box::new(VLlamaLoader),
            VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
            VisionLoaderType::Idefics3 => Box::new(Idefics3Loader),
            VisionLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
            VisionLoaderType::Phi4MM => Box::new(Phi4MMLoader),
            VisionLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
            VisionLoaderType::Gemma3 => Box::new(Gemma3Loader),
            VisionLoaderType::Mistral3 => Box::new(Mistral3Loader),
            VisionLoaderType::Llama4 => Box::new(VLlama4Loader),
        })
    }
}

impl VisionModelLoader for AutoVisionLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }

    fn is_gptx(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader get_loader")
            .is_gptx(config)
    }

    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }

    fn get_processor(
        &self,
        model_config: &str,
        proc_cfg: Option<ProcessorConfig>,
        preproc_cfg: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Self::get_loader(model_config)
            .expect("AutoVisionLoader get_loader")
            .get_processor(model_config, proc_cfg, preproc_cfg, max_edge)
    }

    fn supports_paged_attention(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_paged_attention(config)
    }

    fn modalities(&self, config: &str) -> Result<Modalities> {
        Self::get_loader(config)?.modalities(config)
    }

    fn supports_prefix_cacher(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_prefix_cacher(config)
    }

    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .prefixer(config)
    }

    fn get_device_for_tensor(
        &self,
        config: &str,
        mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        Self::get_loader(config)?.get_device_for_tensor(config, mapper, loading_isq)
    }
}

impl IsqModelLoader for AutoVisionLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
}

impl DeviceMappedModelLoader for AutoVisionLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params, prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(config, dtype, weight_pack_factor)
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(config, dtype, weight_pack_factor)
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

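// Counts `$size` extra elements only when `$cond` holds; used below to account for
// optional bias tensors, e.g. `bias_if!(true, cfg.hidden_size)` in the Idefics2 tower.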
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}

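/// Parameter count of a CLIP ViT tower: patch/class/position embeddings, the pre and
/// final layer norms, and `num_hidden_layers` encoder blocks. With
/// `p = image_size / patch_size` there are `p^2` patches plus one class token.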
fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
    let pre_layer_norm = cfg.hidden_size;
    let final_layer_norm = cfg.hidden_size;

    let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
    let num_positions = num_patches + 1;

    let class_embedding = cfg.hidden_size;

    let position_ids = num_positions;
    let position_embedding = num_positions * cfg.hidden_size;

    let conv2dconfig = Conv2dConfig {
        stride: cfg.patch_size,
        ..Default::default()
    };
    let patch_embedding =
        cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;

    let encoder_layer_elems = {
        let layer_norm1 = cfg.hidden_size;
        let layer_norm2 = cfg.hidden_size;

        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

        layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
    };

    pre_layer_norm
        + final_layer_norm
        + class_embedding
        + position_ids
        + position_embedding
        + patch_embedding
        + cfg.num_hidden_layers * encoder_layer_elems
}

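/// Loader for a Phi 3 Vision model.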
pub struct Phi3VLoader;

pub struct Phi3VPrefixer;

impl MultimodalPromptPrefixer for Phi3VPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            image_indexes
                .into_iter()
                .map(|image_index| format!("<|image_{}|>", image_index + 1))
                .join("")
        )
    }
}

impl VisionModelLoader for Phi3VLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(Phi3::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi3Processor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Phi3VPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Phi3VLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi3VLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

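        // Worst case for the attention score matrix: batch x heads x seq x seq, where
        // the sequence holds the text tokens plus every image's patch tokens.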
        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = {
                let projection_cls = cfg
                    .embd_layer
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator =
                    cfg.embd_layer.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
                let image_dim_out = cfg.img_processor.image_dim_out;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);

                proj + glb_gn + sub_gn + clip_vit
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

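/// Loader for an Idefics 2 model.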
pub struct Idefics2Loader;

pub struct Idefics2Prefixer;

impl MultimodalPromptPrefixer for Idefics2Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        // The chat template inserts the image tokens for Idefics 2.
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(Idefics2::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics2Processor::new(
            processor_config.unwrap(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Idefics2Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Idefics2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Idefics2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            // do_image_splitting = true
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let text_elems = {
            let tie_word_embeddings = cfg.tie_word_embeddings;
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let tcfg = &cfg.text_config;
            let vcfg = &cfg.vision_config;
            let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let down_proj = tcfg.intermediate_size * tcfg.hidden_size;

            let perceiver_elems = {
                let tcfg = &cfg.text_config;
                let pcfg = &cfg.perceiver_config;

                let n_latents = pcfg.resampler_n_latents;
                let hidden_size = tcfg.hidden_size;
                let depth = pcfg.resampler_depth;

                let norm = tcfg.hidden_size;
                let latents = n_latents * hidden_size;

                let layer_elems = {
                    let input_latents_norm = hidden_size;
                    let input_context_norm = hidden_size;
                    let post_attn_norm = hidden_size;

                    let num_heads = pcfg.resampler_n_heads;
                    let head_dim = pcfg.resampler_head_dim;
                    let num_key_value_heads = pcfg.num_key_value_heads;

                    let q_proj = hidden_size * num_heads * head_dim;
                    let k_proj = hidden_size * num_key_value_heads * head_dim;
                    let v_proj = hidden_size * num_key_value_heads * head_dim;
                    let o_proj = num_heads * head_dim * hidden_size;

                    let gate_proj = hidden_size * hidden_size * 4;
                    let up_proj = hidden_size * hidden_size * 4;
                    let down_proj = hidden_size * 4 * hidden_size;

                    input_latents_norm
                        + input_context_norm
                        + post_attn_norm
                        + q_proj
                        + k_proj
                        + v_proj
                        + o_proj
                        + gate_proj
                        + up_proj
                        + down_proj
                };

                norm + latents + layer_elems * depth
            };

            gate_proj + up_proj + down_proj + perceiver_elems
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

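/// Loader for a LLaVA-NeXT (LLaVA 1.6) model.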
pub struct LLaVANextLoader;

pub struct LLaVANextPrefixer;

impl MultimodalPromptPrefixer for LLaVANextPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVANextLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(LLaVANext::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVANextProcessor::new(model_config))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(LLaVANextPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for LLaVANextLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVANextLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            let max_seq_len = img_seq_len + max_seq_len;

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

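/// Loader for a LLaVA 1.5 model.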
pub struct LLaVALoader;

pub struct LLaVAPrefixer;

impl MultimodalPromptPrefixer for LLaVAPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVALoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(LLaVA::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVAProcessor::new(model_config))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(LLaVAPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for LLaVALoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVALoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        let img_seq_len =
            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            let max_seq_len = img_seq_len + max_seq_len;

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        let img_seq_len =
            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

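/// Loader for an MLlama (Llama 3.2 Vision) model.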
pub struct VLlamaLoader;

pub struct VLlamaPrefixer;

impl MultimodalPromptPrefixer for VLlamaPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<|image|>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for VLlamaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
        Ok(Box::new(MLlamaModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(MLlamaProcessor::new())
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        false
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(VLlamaPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for VLlamaLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let cross_attn_layers = &config.text_config.cross_attention_layers;
        let transformer_layers =
            (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
        let mut text_regexes = Vec::new();
        for layer in transformer_layers {
            text_regexes.extend(vec![
                // Attention
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
                ))?,
                // MLP
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
                ))?,
            ]);
        }
        let vision_regexes = vec![
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
            )?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
        ];

        Ok([text_regexes, vision_regexes].concat())
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for VLlamaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: MLlamaConfig = serde_json::from_str(config)?;

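        // Per-image vision sequence length: patches plus the class token, padded up to a
        // multiple of 8 and repeated for each tile.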
1758 let img_seq_len = {
1759 let cfg = &config.vision_config;
1760 let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1761 let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1762 cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1763 };
1764 let img_seq_len = img_seq_len * max_num_images;
1765
1766 let max_cross_text_attn = {
1767 let cfg = &config.text_config;
1768 max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
1769 };
1770
1771 let max_self_text_attn = {
1772 let cfg = &config.text_config;
1773 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1774 };
1775
1776 Ok(max_self_text_attn.max(max_cross_text_attn))
1777 }
1778
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: MLlamaConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &config.vision_config;
            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
        };
        let max_vision_attn = {
            let cfg = &config.vision_config;
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &config.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
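            // With tied word embeddings and no quantization packing (pack factor 1),
            // lm_head reuses the embed_tokens storage, so it adds no extra bytes.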
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &config.vision_config;

            let conv_cfg = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
                * cfg.patch_size
                * cfg.patch_size;

            let class_embedding = cfg.hidden_size;

            let gated_positional_embedding = {
                let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
                let embedding = num_patches * cfg.hidden_size;
                let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
                    * (cfg.max_num_tiles * num_patches * cfg.hidden_size);

                embedding + tile_embedding
            };

            let pre_tile_positional_embedding =
                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
            let post_tile_positional_embedding =
                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);

            let layernorm_pre = cfg.hidden_size;
            let layernorm_post = cfg.hidden_size;

            let encoder_layer = {
                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
                let q_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let k_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let v_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let o_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;

                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
                    + cfg.intermediate_size;
                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
                    + cfg.hidden_size;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + fc1
                    + fc2
            };

            patch_embedding
                + class_embedding
                + gated_positional_embedding
                + pre_tile_positional_embedding
                + post_tile_positional_embedding
                + layernorm_pre
                + layernorm_post
                + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
        };

        let elems = text_elems + vision_elems;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let cfg = &config.text_config;

        let mut layer_sizes = Vec::new();

        for i in 0..cfg.num_hidden_layers {
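            // Cross-attention layers are excluded from ISQ above, so they are assumed
            // to stay unquantized and are counted at full size (pack factor 1).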
            let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
                1
            } else {
                weight_pack_factor
            };

            let per_layer_elems = {
                let input_layernorm = cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size;

                let size_in = cfg.hidden_size;
                let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
                let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
                let q_proj = size_in * size_q / weight_pack_factor;
                let k_proj = size_in * size_kv / weight_pack_factor;
                let v_proj = size_in * size_kv / weight_pack_factor;
                let o_proj = size_q * size_in / weight_pack_factor;

                let h_size = cfg.hidden_size;
                let i_size = cfg.intermediate_size;
                let gate_proj = h_size * i_size / weight_pack_factor;
                let up_proj = h_size * i_size / weight_pack_factor;
                let down_proj = i_size * h_size / weight_pack_factor;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + gate_proj
                    + up_proj
                    + down_proj
            };

            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
        }

        Ok(layer_sizes)
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        Ok(config.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: MLlamaConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

pub struct Qwen2VLLoader;

pub struct Qwen2VLPrefixer;

impl MultimodalPromptPrefixer for Qwen2VLPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
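        // One VISION_START/IMAGE_PAD/VISION_END marker triple is prepended per image,
        // so two images yield the triple twice in front of the prompt.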
        format!(
            "{}{prompt}",
            format!(
                "{}{}{}",
                Qwen2VLProcessor::VISION_START,
                Qwen2VLProcessor::IMAGE_PAD,
                Qwen2VLProcessor::VISION_END
            )
            .repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for Qwen2VLLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen2VLModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Qwen2VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen2VLProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        false
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Qwen2VLPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Qwen2VLLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Qwen2VLLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;

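        // Illustrative grid math (assumed values): with temporal_patch_size 2, patch_size
        // 14 and a 1024x1024 max image, grid_h = grid_w = 1024 / 14 = 73, so each image
        // contributes grid_t * 73 * 73 visual positions before any merging.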
        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        let text_elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

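        // The merger folds each spatial_merge_size^2 block of vision tokens into one,
        // so the MLP input width is embed_dim * spatial_merge_size^2 (e.g. 1280 * 4 =
        // 5120 under the assumed defaults embed_dim = 1280, spatial_merge_size = 2).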
        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_channels * cfg.embed_dim / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
            let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);

            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
            let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
            let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
            let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;

            let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
            let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

pub struct Idefics3Loader;

pub struct Idefics3Prefixer;

impl MultimodalPromptPrefixer for Idefics3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics3Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
        Ok(Box::new(Idefics3Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics3Processor::new(
            processor_config.unwrap_or_default(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Idefics3Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Idefics3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Idefics3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics3Config = serde_json::from_str(config)?;

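        // Tokens per image: (image_size / patch_size)^2 + 1. Under assumed SigLIP-style
        // defaults (image_size 364, patch_size 14) that is 26^2 + 1 = 677.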
        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics3Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
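            // Heuristic: the preprocessor may split each image into several sub-images
            // plus a global view, so budget roughly five crops per input image here.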
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
            let out_dim = cfg.text_config.hidden_size;

            in_dim * out_dim
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

pub struct MiniCpmOLoader;

pub struct MiniCpmOPrefixer;

impl MultimodalPromptPrefixer for MiniCpmOPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
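        // E.g. two images produce "(<image>./</image>)(<image>./</image>)" before the prompt.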
        format!(
            "{}{prompt}",
            "(<image>./</image>)".repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for MiniCpmOLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
        Ok(Box::new(MiniCpmOModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(MiniCpmOProcessor::new(
            processor_config.unwrap_or_default(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(MiniCpmOPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for MiniCpmOLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"llm\.lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"llm\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"llm\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"llm\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"llm\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"llm\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"llm\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"llm\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for MiniCpmOLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

pub struct Phi4MMLoader;

pub struct Phi4MMPrefixer;

impl MultimodalPromptPrefixer for Phi4MMPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
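        // Phi-4 multimodal image tags are 1-based: indexes [0, 1] become
        // "<|image_1|><|image_2|>" prepended to the prompt.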
        format!(
            "{}{prompt}",
            image_indexes
                .into_iter()
                .map(|image_index| format!("<|image_{}|>", image_index + 1))
                .join("")
        )
    }
    fn prefix_audio(&self, audio_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            audio_indexes
                .into_iter()
                .map(|audio_index| format!("<|audio_{}|>", audio_index + 1))
                .join("")
        )
    }
}

impl VisionModelLoader for Phi4MMLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
        Ok(Box::new(Phi4MMModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Phi4MMPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![
                SupportedModality::Text,
                SupportedModality::Vision,
                SupportedModality::Audio,
            ],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Phi4MMLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi4MMLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi4MMConfig = serde_json::from_str(config)?;

        let vcfg = &PHI4_MM_VISION_CFG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let vcfg = &PHI4_MM_VISION_CFG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

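        // Dynamic-HD preprocessing tiles each image into base-resolution crops plus one
        // global view, so the effective batch scales with ceil(h/base) * ceil(w/base) + 1.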
        let max_batch_size = max_batch_size
            * (max_image_shape
                .0
                .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
                * max_image_shape
                    .1
                    .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
                + 1);

        let max_vision_attn = (max_batch_size * max_num_images)
            * vcfg.num_attention_heads
            * img_seq_len
            * img_seq_len;
        let max_qkv = 3
            * (max_batch_size
                * vcfg.num_attention_heads
                * img_seq_len
                * (vcfg.hidden_size / vcfg.num_attention_heads));

        Ok(max_vision_attn + max_qkv)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
                let projection_cls = img_embed
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
                let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;

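                // The projector is a linear layer or a two-layer MLP; with the HD
                // transform the MLP input is assumed to be 4 * image_dim_out because
                // 2x2 patch groups are concatenated before projection.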
                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let vision_transformer = {
                    let cfg = &PHI4_MM_VISION_CFG;

                    let post_layernorm = cfg.hidden_size;

                    let conv_config = Conv2dConfig {
                        stride: cfg.patch_size,
                        ..Default::default()
                    };
                    let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                        * cfg.patch_size
                        * cfg.patch_size;

                    let num_patches_per_side = cfg.image_size / cfg.patch_size;
                    let num_patches = num_patches_per_side.pow(2);
                    let position_embedding = num_patches * cfg.hidden_size;

                    let layer_elems = {
                        let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                        let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                        layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
                    };

                    post_layernorm
                        + patch_embedding
                        + position_embedding
                        + layer_elems * cfg.num_hidden_layers
                };

                proj + glb_gn + sub_gn + vision_transformer
            } else {
                0
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

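            // Phi-4-MM fuses the attention projections into a single qkv_proj and the
            // MLP gate/up matrices into one gate_up_proj, so each is one packed weight.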
            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads(),
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

pub struct Qwen2_5VLLoader;

pub struct Qwen2_5VLPrefixer;

impl MultimodalPromptPrefixer for Qwen2_5VLPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            format!(
                "{}{}{}",
                Qwen2_5VLProcessor::VISION_START,
                Qwen2_5VLProcessor::IMAGE_PAD,
                Qwen2_5VLProcessor::VISION_END
            )
            .repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for Qwen2_5VLLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen2_5VLModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen2_5VLProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        false
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Qwen2_5VLPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Qwen2_5VLLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Qwen2_5VLLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        let text_elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;

            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

pub struct Gemma3Loader;

pub struct Gemma3Prefixer;

impl MultimodalPromptPrefixer for Gemma3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Gemma3Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;
        Ok(Box::new(Gemma3Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Gemma3Config = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        let config: Gemma3Config = serde_json::from_str(config).unwrap();
        Arc::new(Gemma3Processor::new(
            processor_config.unwrap_or_default(),
            matches!(config, Gemma3Config::WithVision { .. }),
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Gemma3Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Gemma3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Gemma3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3Config = serde_json::from_str(config)?;

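        // Gemma 3 ships text-only and vision variants behind one config enum: the
        // text-only arm sizes activations by the prompt chunk, the vision arm by
        // image tokens plus the text sequence.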
        match cfg {
            Gemma3Config::Text(text_config) => Ok(max_batch_size
                * text_config.num_attention_heads
                * prompt_chunksize
                * prompt_chunksize),
            Gemma3Config::WithVision {
                text_config,
                vision_config,
                ..
            } => {
                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
                let img_seq_len = (num_patches + 1) * max_num_images;

                let max_text_attn = {
                    let max_seq_len = img_seq_len + *max_seq_len;
                    max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
                };
                Ok(max_text_attn)
            }
        }
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3Config = serde_json::from_str(config)?;

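        // The vision tower is never device-mapped, so its activation ceiling is the
        // per-image attention over the patch sequence (num_patches + 1 tokens).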
        match cfg {
            Gemma3Config::WithVision { vision_config, .. } => {
                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
                let img_seq_len = num_patches + 1;

                let max_vision_attn = {
                    (max_batch_size * max_num_images)
                        * vision_config.num_attention_heads
                        * img_seq_len
                        * img_seq_len
                };

                Ok(max_vision_attn)
            }
            Gemma3Config::Text(_) => Ok(0),
        }
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

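        // Weights that stay on the primary device: token embeddings, lm_head (when
        // untied or separately quantized), the final norm, and the vision transformer.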
        let text_elems = {
            let cfg = match &cfg {
                Gemma3Config::Text(cfg) => cfg,
                Gemma3Config::WithVision { text_config, .. } => text_config,
            };
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_transformer = if let Gemma3Config::WithVision {
            vision_config: cfg, ..
        } = &cfg
        {
            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        } else {
            0
        };

        let elems = text_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

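        // Every decoder layer is identically shaped, so size one layer and replicate
        // it num_hidden_layers times.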
        let txt_cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };
        let per_layer_elems = {
            let cfg = txt_cfg;

            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            txt_cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let txt_cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };

        Ok(txt_cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

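/// Loader for Mistral 3 vision models (a Pixtral-style vision encoder paired with a
/// Mistral language model).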
pub struct Mistral3Loader;

pub struct Mistral3Prefixer;

impl MultimodalPromptPrefixer for Mistral3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Mistral3Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
        Ok(Box::new(Mistral3Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Mistral3Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Mistral3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
impl DeviceMappedModelLoader for Mistral3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let vcfg = &cfg.vision_config;
        let tcfg = &cfg.text_config;

        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: (mut height, mut width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

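        // Mirror the processor's resize rule: scale the image so neither side exceeds
        // 1540 px, round each side up to whole patches, then count one token per patch
        // plus one extra token per patch row (the Pixtral-style image-break token).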
        let img_seq_len = {
            let (max_height, max_width) = (1540, 1540);
            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
            if ratio > 1. {
                height = (height as f64 / ratio).floor() as usize;
                width = (width as f64 / ratio).floor() as usize;
            }

            let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
            let num_width_tokens = (width - 1) / vcfg.patch_size + 1;

            height = num_height_tokens * vcfg.patch_size;
            width = num_width_tokens * vcfg.patch_size;

            let num_height_tokens = height / vcfg.patch_size;
            let num_width_tokens = width / vcfg.patch_size;

            (num_width_tokens + 1) * num_height_tokens
        };

        let max_seq_len = img_seq_len * max_num_images + *max_seq_len;
        Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.vision_config;

        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: (mut height, mut width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

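        // Same resize and token-count logic as `mapped_max_act_size_elems`, but the
        // unmapped vision tower attends over one image's tokens at a time.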
        let img_seq_len = {
            let (max_height, max_width) = (1540, 1540);
            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
            if ratio > 1. {
                height = (height as f64 / ratio).floor() as usize;
                width = (width as f64 / ratio).floor() as usize;
            }

            let num_height_tokens = (height - 1) / cfg.patch_size + 1;
            let num_width_tokens = (width - 1) / cfg.patch_size + 1;

            height = num_height_tokens * cfg.patch_size;
            width = num_width_tokens * cfg.patch_size;

            let num_height_tokens = height / cfg.patch_size;
            let num_width_tokens = width / cfg.patch_size;

            (num_width_tokens + 1) * num_height_tokens
        };

        Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;

        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

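        // Vision tower weights: Conv2d patch embedding, pre-norm, and the encoder
        // stack (attention + gated MLP per layer); biases are omitted from this
        // estimate.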
        let vision_elems = {
            let cfg = &cfg.vision_config;

            let patch_embed = {
                let conv_cfg = Conv2dConfig {
                    stride: cfg.patch_size,
                    ..Default::default()
                };
                // Conv2d patch embedding: out_channels * (in_channels / groups) * k_h * k_w.
                cfg.num_channels * cfg.hidden_size / conv_cfg.groups
                    * cfg.patch_size
                    * cfg.patch_size
            };
            let ln_pre = cfg.hidden_size;
            let vision_layer = {
                let attn_norm = cfg.hidden_size;
                let ffn_norm = cfg.hidden_size;

                let gate = cfg.hidden_size * cfg.intermediate_size;
                let up = cfg.hidden_size * cfg.intermediate_size;
                let down = cfg.hidden_size * cfg.intermediate_size;

                let q = cfg.hidden_size * cfg.hidden_size;
                let k = cfg.hidden_size * cfg.hidden_size;
                let v = cfg.hidden_size * cfg.hidden_size;
                let o = cfg.hidden_size * cfg.hidden_size;

                attn_norm + ffn_norm + gate + up + down + q + k + v + o
            };

            patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_elems;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

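/// Loader for Llama 4 multimodal (vision + text) models.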
pub struct VLlama4Loader;

pub struct VLlama4Prefixer;

impl MultimodalPromptPrefixer for VLlama4Prefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            llama4::IMAGE_TOKEN.repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for VLlama4Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
        Ok(Box::new(Llama4Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Llama4Processor::new(&processor_config.unwrap_or_default()))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(VLlama4Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for VLlama4Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.router\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.down_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.down_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.router\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.gate_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.down_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$",
            )?,
        ])
    }
}

impl VLlama4Loader {
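    /// Estimate image-token counts by running the real `Llama4ImageProcessor` on a
    /// dummy image of the maximum shape. Returns `(total image chunks across the
    /// batch, total image tokens injected into the text sequence)`.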
    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
    fn run_dummy_processing(
        &self,
        cfg: &Llama4Config,
        height: usize,
        width: usize,
        max_num_images: usize,
        max_batch_size: usize,
    ) -> Result<(usize, usize)> {
        let cfg = &cfg.vision_config;

        let img_processor =
            Llama4ImageProcessor::new(Some(cfg.patch_size), Some(cfg.pixel_shuffle_ratio));
        let image = DynamicImage::new(width as u32, height as u32, ColorType::Rgb8);
        let res = img_processor.preprocess(
            vec![image; max_num_images],
            vec![],
            &PreProcessorConfig::default(),
            &Device::Cpu,
            (max_batch_size, max_num_images),
        )?;

        let pixels_batch_size = res.pixel_values.dim(0)?;
        let pixels_max_batch_size = pixels_batch_size * max_batch_size;

        let (image_h, image_w) = (
            res.pixel_values.dim(D::Minus2)?,
            res.pixel_values.dim(D::Minus1)?,
        );
        let num_patches_per_chunk = (image_h / img_processor.patch_size)
            * (image_w / img_processor.patch_size)
            / img_processor.downsample_ratio;

        Ok((
            pixels_max_batch_size,
            num_patches_per_chunk * pixels_max_batch_size,
        ))
    }
}

impl DeviceMappedModelLoader for VLlama4Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: (height, width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Llama4Config = serde_json::from_str(config)?;

        let (_pixels_batch_size, num_text_image_toks) =
            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;

        let max_seq_len = max_seq_len + num_text_image_toks;

        Ok(max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: (height, width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Llama4Config = serde_json::from_str(config)?;

        let (pixels_batch_size, _num_text_image_toks) =
            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
        let max_seq_len = cfg.vision_config.num_patches();

        Ok((max_batch_size * pixels_batch_size)
            * cfg.vision_config.num_attention_heads
            * max_seq_len
            * max_seq_len)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let tcfg = &cfg.text_config;

        let text_elems = {
            let embed_tokens = tcfg.hidden_size * tcfg.vocab_size / weight_pack_factor;
            let lm_head = if !tcfg.tie_word_embeddings {
                tcfg.hidden_size * tcfg.vocab_size
            } else {
                0
            };
            let norm = tcfg.hidden_size;
            embed_tokens + lm_head + norm
        };

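        // Vision tower weights: patch unfold/projection, class embedding, learned
        // positional embeddings, pre/post layernorms, the pixel-shuffle projector,
        // and the encoder stack.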
        let vision_elems = {
            let cfg = &cfg.vision_config;

            let num_patches = cfg.num_patches();

            let unfold_elems =
                (cfg.num_channels * cfg.patch_size * cfg.patch_size) * cfg.hidden_size;
            let class_embedding_elems = cfg.hidden_size;
            let positional_embedding_vlm_elems = num_patches * cfg.hidden_size;
            let layernorm_pre_elems = cfg.hidden_size;
            let layernorm_post_elems = cfg.hidden_size;

            let pixel_shuffle_elems = cfg.intermediate_size * cfg.projector_input_dim
                / weight_pack_factor
                + cfg.projector_input_dim * cfg.projector_output_dim / weight_pack_factor;

            let encoder_layer = {
                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
                let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let k_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let v_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let o_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;

                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
                    + cfg.intermediate_size;
                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
                    + cfg.hidden_size;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + fc1
                    + fc2
            };

            unfold_elems
                + class_embedding_elems
                + positional_embedding_vlm_elems
                + layernorm_post_elems
                + layernorm_pre_elems
                + pixel_shuffle_elems
                + encoder_layer * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_elems;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let tcfg = &cfg.text_config;

        let mut per_layer_elems = Vec::new();

        for layer_idx in 0..tcfg.num_hidden_layers {
            let input_layernorm = tcfg.hidden_size;
            let post_attention_layernorm = tcfg.hidden_size;

            let size_in = tcfg.hidden_size;
            let size_q = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_attention_heads;
            let size_kv = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

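            // MoE layers hold num_local_experts copies of the gate/up/down
            // projections; dense layers use a single MLP with intermediate_size_mlp.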
            let use_moe = tcfg.moe_layers().contains(&layer_idx);
            let moe_block = if use_moe {
                let h_size = tcfg.hidden_size;
                let i_size = tcfg.intermediate_size;
                let gate_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
                let up_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
                let down_proj = tcfg.num_local_experts * i_size * h_size / weight_pack_factor;

                gate_proj + up_proj + down_proj
            } else {
                let h_size = tcfg.hidden_size;
                let i_size = tcfg.intermediate_size_mlp;
                let gate_proj = h_size * i_size / weight_pack_factor;
                let up_proj = h_size * i_size / weight_pack_factor;
                let down_proj = i_size * h_size / weight_pack_factor;

                gate_proj + up_proj + down_proj
            };

            per_layer_elems.push(
                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + moe_block,
            );
        }

        Ok(per_layer_elems
            .into_iter()
            .map(|x| x * dtype.size_in_bytes())
            .collect())
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}